aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/upscaler_utils.py89
1 files changed, 69 insertions, 20 deletions
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
index 9379f512..e4c63f09 100644
--- a/modules/upscaler_utils.py
+++ b/modules/upscaler_utils.py
@@ -11,23 +11,40 @@ from modules import images, shared, torch_utils
logger = logging.getLogger(__name__)
-def upscale_without_tiling(model, img: Image.Image):
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
- img = torch.from_numpy(img).float()
-
+def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
+ img = np.array(img.convert("RGB"))
+ img = img[:, :, ::-1] # flip RGB to BGR
+ img = np.transpose(img, (2, 0, 1)) # HWC to CHW
+ img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1]
+ return torch.from_numpy(img)
+
+
+def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
+ if tensor.ndim == 4:
+ # If we're given a tensor with a batch dimension, squeeze it out
+ # (but only if it's a batch of size 1).
+ if tensor.shape[0] != 1:
+ raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
+ tensor = tensor.squeeze(0)
+ assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor"
+ # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
+ arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp
+ arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale
+ arr = arr.astype(np.uint8)
+ arr = arr[:, :, ::-1] # flip BGR to RGB
+ return Image.fromarray(arr, "RGB")
+
+
+def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
+ """
+ Upscale a given PIL image using the given model.
+ """
param = torch_utils.get_param(model)
- img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
with torch.no_grad():
- output = model(img)
-
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- return Image.fromarray(output, 'RGB')
+ tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
+ tensor = tensor.to(device=param.device, dtype=param.dtype)
+ return torch_bgr_to_pil_image(model(tensor))
def upscale_with_model(
@@ -40,7 +57,7 @@ def upscale_with_model(
) -> Image.Image:
if tile_size <= 0:
logger.debug("Upscaling %s without tiling", img)
- output = upscale_without_tiling(model, img)
+ output = upscale_pil_patch(model, img)
logger.debug("=> %s", output)
return output
@@ -52,7 +69,7 @@ def upscale_with_model(
newrow = []
for x, w, tile in row:
logger.debug("Tile (%d, %d) %s...", x, y, tile)
- output = upscale_without_tiling(model, tile)
+ output = upscale_pil_patch(model, tile)
scale_factor = output.width // tile.width
logger.debug("=> %s (scale factor %s)", output, scale_factor)
newrow.append([x * scale_factor, w * scale_factor, output])
@@ -71,19 +88,22 @@ def upscale_with_model(
def tiled_upscale_2(
- img,
+ img: torch.Tensor,
model,
*,
tile_size: int,
tile_overlap: int,
scale: int,
- device,
desc="Tiled upscale",
):
# Alternative implementation of `upscale_with_model` originally used by
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
# Pillow space without weighting.
+
+ # Grab the device the model is on, and use it.
+ device = torch_utils.get_param(model).device
+
b, c, h, w = img.size()
tile_size = min(tile_size, h, w)
@@ -100,7 +120,8 @@ def tiled_upscale_2(
h * scale,
w * scale,
device=device,
- ).type_as(img)
+ dtype=img.dtype,
+ )
weights = torch.zeros_like(result)
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
@@ -112,11 +133,13 @@ def tiled_upscale_2(
if shared.state.interrupted or shared.state.skipped:
break
+ # Only move this patch to the device if it's not already there.
in_patch = img[
...,
h_idx : h_idx + tile_size,
w_idx : w_idx + tile_size,
- ]
+ ].to(device=device)
+
out_patch = model(in_patch)
result[
@@ -138,3 +161,29 @@ def tiled_upscale_2(
output = result.div_(weights)
return output
+
+
+def upscale_2(
+ img: Image.Image,
+ model,
+ *,
+ tile_size: int,
+ tile_overlap: int,
+ scale: int,
+ desc: str,
+):
+ """
+ Convenience wrapper around `tiled_upscale_2` that handles PIL images.
+ """
+ tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension
+
+ with torch.no_grad():
+ output = tiled_upscale_2(
+ tensor,
+ model,
+ tile_size=tile_size,
+ tile_overlap=tile_overlap,
+ scale=scale,
+ desc=desc,
+ )
+ return torch_bgr_to_pil_image(output)