aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py48
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py62
-rw-r--r--modules/upscaler_utils.py89
3 files changed, 87 insertions, 112 deletions
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index f799cb76..fe5e5a19 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,13 +1,9 @@
import sys
import PIL.Image
-import numpy as np
-import torch
import modules.upscaler
-from modules import devices, modelloader, script_callbacks, errors
-from modules.shared import opts
-from modules.upscaler_utils import tiled_upscale_2
+from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils
class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -40,46 +36,23 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
self.scalers = scalers
def do_upscale(self, img: PIL.Image.Image, selected_file):
-
devices.torch_gc()
-
try:
model = self.load_model(selected_file)
except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img
- device = devices.get_device_for('scunet')
- tile = opts.SCUNET_tile
- h, w = img.height, img.width
- np_img = np.array(img)
- np_img = np_img[:, :, ::-1] # RGB to BGR
- np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
- torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
-
- if tile > h or tile > w:
- _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
- _img[:, :, :h, :w] = torch_img # pad image
- torch_img = _img
-
- with torch.no_grad():
- torch_output = tiled_upscale_2(
- torch_img,
- model,
- tile_size=opts.SCUNET_tile,
- tile_overlap=opts.SCUNET_tile_overlap,
- scale=1,
- device=devices.get_device_for('scunet'),
- desc="ScuNET tiles",
- ).squeeze(0)
- torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
- np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
- del torch_img, torch_output
+ img = upscaler_utils.upscale_2(
+ img,
+ model,
+ tile_size=shared.opts.SCUNET_tile,
+ tile_overlap=shared.opts.SCUNET_tile_overlap,
+ scale=1, # ScuNET is a denoising model, not an upscaler
+ desc='ScuNET',
+ )
devices.torch_gc()
-
- output = np_output.transpose((1, 2, 0)) # CHW to HWC
- output = output[:, :, ::-1] # BGR to RGB
- return PIL.Image.fromarray((output * 255).astype(np.uint8))
+ return img
def load_model(self, path: str):
device = devices.get_device_for('scunet')
@@ -93,7 +66,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def on_ui_settings():
import gradio as gr
- from modules import shared
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 8a555c79..bc427fea 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,14 +1,10 @@
import logging
import sys
-import numpy as np
-import torch
from PIL import Image
-from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import opts
+from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
from modules.upscaler import Upscaler, UpscalerData
-from modules.upscaler_utils import tiled_upscale_2
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
@@ -36,9 +32,7 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers
def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
- current_config = (model_file, opts.SWIN_tile)
-
- device = self._get_device()
+ current_config = (model_file, shared.opts.SWIN_tile)
if self._cached_model_config == current_config:
model = self._cached_model
@@ -51,12 +45,13 @@ class UpscalerSwinIR(Upscaler):
self._cached_model = model
self._cached_model_config = current_config
- img = upscale(
+ img = upscaler_utils.upscale_2(
img,
model,
- tile=opts.SWIN_tile,
- tile_overlap=opts.SWIN_tile_overlap,
- device=device,
+ tile_size=shared.opts.SWIN_tile,
+ tile_overlap=shared.opts.SWIN_tile_overlap,
+ scale=4, # TODO: This was hard-coded before too...
+ desc="SwinIR",
)
devices.torch_gc()
return img
@@ -77,7 +72,7 @@ class UpscalerSwinIR(Upscaler):
dtype=devices.dtype,
expected_architecture="SwinIR",
)
- if getattr(opts, 'SWIN_torch_compile', False):
+ if getattr(shared.opts, 'SWIN_torch_compile', False):
try:
model_descriptor.model.compile()
except Exception:
@@ -88,47 +83,6 @@ class UpscalerSwinIR(Upscaler):
return devices.get_device_for('swinir')
-def upscale(
- img,
- model,
- *,
- tile: int,
- tile_overlap: int,
- window_size=8,
- scale=4,
- device,
-):
-
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device, dtype=devices.dtype)
- with torch.no_grad(), devices.autocast():
- _, _, h_old, w_old = img.size()
- h_pad = (h_old // window_size + 1) * window_size - h_old
- w_pad = (w_old // window_size + 1) * window_size - w_old
- img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
- img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
- output = tiled_upscale_2(
- img,
- model,
- tile_size=tile,
- tile_overlap=tile_overlap,
- scale=scale,
- device=device,
- desc="SwinIR tiles",
- )
- output = output[..., : h_old * scale, : w_old * scale]
- output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
- if output.ndim == 3:
- output = np.transpose(
- output[[2, 1, 0], :, :], (1, 2, 0)
- ) # CHW-RGB to HCW-BGR
- output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
- return Image.fromarray(output, "RGB")
-
-
def on_ui_settings():
import gradio as gr
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
index 9379f512..e4c63f09 100644
--- a/modules/upscaler_utils.py
+++ b/modules/upscaler_utils.py
@@ -11,23 +11,40 @@ from modules import images, shared, torch_utils
logger = logging.getLogger(__name__)
-def upscale_without_tiling(model, img: Image.Image):
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
- img = torch.from_numpy(img).float()
-
+def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
+ img = np.array(img.convert("RGB"))
+ img = img[:, :, ::-1] # flip RGB to BGR
+ img = np.transpose(img, (2, 0, 1)) # HWC to CHW
+ img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1]
+ return torch.from_numpy(img)
+
+
+def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
+ if tensor.ndim == 4:
+ # If we're given a tensor with a batch dimension, squeeze it out
+ # (but only if it's a batch of size 1).
+ if tensor.shape[0] != 1:
+ raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
+ tensor = tensor.squeeze(0)
+ assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor"
+ # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
+ arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp
+ arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale
+ arr = arr.astype(np.uint8)
+ arr = arr[:, :, ::-1] # flip BGR to RGB
+ return Image.fromarray(arr, "RGB")
+
+
+def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
+ """
+ Upscale a given PIL image using the given model.
+ """
param = torch_utils.get_param(model)
- img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
with torch.no_grad():
- output = model(img)
-
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- return Image.fromarray(output, 'RGB')
+ tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
+ tensor = tensor.to(device=param.device, dtype=param.dtype)
+ return torch_bgr_to_pil_image(model(tensor))
def upscale_with_model(
@@ -40,7 +57,7 @@ def upscale_with_model(
) -> Image.Image:
if tile_size <= 0:
logger.debug("Upscaling %s without tiling", img)
- output = upscale_without_tiling(model, img)
+ output = upscale_pil_patch(model, img)
logger.debug("=> %s", output)
return output
@@ -52,7 +69,7 @@ def upscale_with_model(
newrow = []
for x, w, tile in row:
logger.debug("Tile (%d, %d) %s...", x, y, tile)
- output = upscale_without_tiling(model, tile)
+ output = upscale_pil_patch(model, tile)
scale_factor = output.width // tile.width
logger.debug("=> %s (scale factor %s)", output, scale_factor)
newrow.append([x * scale_factor, w * scale_factor, output])
@@ -71,19 +88,22 @@ def upscale_with_model(
def tiled_upscale_2(
- img,
+ img: torch.Tensor,
model,
*,
tile_size: int,
tile_overlap: int,
scale: int,
- device,
desc="Tiled upscale",
):
# Alternative implementation of `upscale_with_model` originally used by
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
# Pillow space without weighting.
+
+ # Grab the device the model is on, and use it.
+ device = torch_utils.get_param(model).device
+
b, c, h, w = img.size()
tile_size = min(tile_size, h, w)
@@ -100,7 +120,8 @@ def tiled_upscale_2(
h * scale,
w * scale,
device=device,
- ).type_as(img)
+ dtype=img.dtype,
+ )
weights = torch.zeros_like(result)
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
@@ -112,11 +133,13 @@ def tiled_upscale_2(
if shared.state.interrupted or shared.state.skipped:
break
+ # Only move this patch to the device if it's not already there.
in_patch = img[
...,
h_idx : h_idx + tile_size,
w_idx : w_idx + tile_size,
- ]
+ ].to(device=device)
+
out_patch = model(in_patch)
result[
@@ -138,3 +161,29 @@ def tiled_upscale_2(
output = result.div_(weights)
return output
+
+
+def upscale_2(
+ img: Image.Image,
+ model,
+ *,
+ tile_size: int,
+ tile_overlap: int,
+ scale: int,
+ desc: str,
+):
+ """
+ Convenience wrapper around `tiled_upscale_2` that handles PIL images.
+ """
+ tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension
+
+ with torch.no_grad():
+ output = tiled_upscale_2(
+ tensor,
+ model,
+ tile_size=tile_size,
+ tile_overlap=tile_overlap,
+ scale=scale,
+ desc=desc,
+ )
+ return torch_bgr_to_pil_image(output)