From cf14a6a7aaf8ccb40552990785d5c9e400d93610 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 16:11:18 +0200 Subject: Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right --- extensions-builtin/ScuNET/scripts/scunet_model.py | 48 +++--------- extensions-builtin/SwinIR/scripts/swinir_model.py | 62 ++-------------- modules/upscaler_utils.py | 89 ++++++++++++++++++----- 3 files changed, 87 insertions(+), 112 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index f799cb76..fe5e5a19 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,13 +1,9 @@ import sys import PIL.Image -import numpy as np -import torch import modules.upscaler -from modules import devices, modelloader, script_callbacks, errors -from modules.shared import opts -from modules.upscaler_utils import tiled_upscale_2 +from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,46 +36,23 @@ class UpscalerScuNET(modules.upscaler.Upscaler): self.scalers = scalers def do_upscale(self, img: PIL.Image.Image, selected_file): - devices.torch_gc() - try: model = self.load_model(selected_file) except Exception as e: print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) return img - device = devices.get_device_for('scunet') - tile = opts.SCUNET_tile - h, w = img.height, img.width - np_img = np.array(img) - np_img = np_img[:, :, ::-1] # RGB to BGR - np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW - torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore - - if tile > h or tile > w: - _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device) - _img[:, :, :h, :w] = torch_img # pad image - torch_img = _img - - with torch.no_grad(): - torch_output = tiled_upscale_2( - torch_img, - model, - tile_size=opts.SCUNET_tile, - tile_overlap=opts.SCUNET_tile_overlap, - scale=1, - device=devices.get_device_for('scunet'), - desc="ScuNET tiles", - ).squeeze(0) - torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any - np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() - del torch_img, torch_output + img = upscaler_utils.upscale_2( + img, + model, + tile_size=shared.opts.SCUNET_tile, + tile_overlap=shared.opts.SCUNET_tile_overlap, + scale=1, # ScuNET is a denoising model, not an upscaler + desc='ScuNET', + ) devices.torch_gc() - - output = np_output.transpose((1, 2, 0)) # CHW to HWC - output = output[:, :, ::-1] # BGR to RGB - return PIL.Image.fromarray((output * 255).astype(np.uint8)) + return img def load_model(self, path: str): device = devices.get_device_for('scunet') @@ -93,7 +66,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def on_ui_settings(): import gradio as gr - from modules import shared shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling")) shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam")) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 8a555c79..bc427fea 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,14 +1,10 @@ import logging import sys -import numpy as np -import torch from PIL import Image -from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts +from modules import devices, modelloader, script_callbacks, shared, upscaler_utils from modules.upscaler import Upscaler, UpscalerData -from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -36,9 +32,7 @@ class UpscalerSwinIR(Upscaler): self.scalers = scalers def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: - current_config = (model_file, opts.SWIN_tile) - - device = self._get_device() + current_config = (model_file, shared.opts.SWIN_tile) if self._cached_model_config == current_config: model = self._cached_model @@ -51,12 +45,13 @@ class UpscalerSwinIR(Upscaler): self._cached_model = model self._cached_model_config = current_config - img = upscale( + img = upscaler_utils.upscale_2( img, model, - tile=opts.SWIN_tile, - tile_overlap=opts.SWIN_tile_overlap, - device=device, + tile_size=shared.opts.SWIN_tile, + tile_overlap=shared.opts.SWIN_tile_overlap, + scale=4, # TODO: This was hard-coded before too... + desc="SwinIR", ) devices.torch_gc() return img @@ -77,7 +72,7 @@ class UpscalerSwinIR(Upscaler): dtype=devices.dtype, expected_architecture="SwinIR", ) - if getattr(opts, 'SWIN_torch_compile', False): + if getattr(shared.opts, 'SWIN_torch_compile', False): try: model_descriptor.model.compile() except Exception: @@ -88,47 +83,6 @@ class UpscalerSwinIR(Upscaler): return devices.get_device_for('swinir') -def upscale( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size=8, - scale=4, - device, -): - - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device, dtype=devices.dtype) - with torch.no_grad(), devices.autocast(): - _, _, h_old, w_old = img.size() - h_pad = (h_old // window_size + 1) * window_size - h_old - w_pad = (w_old // window_size + 1) * window_size - w_old - img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] - img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = tiled_upscale_2( - img, - model, - tile_size=tile, - tile_overlap=tile_overlap, - scale=scale, - device=device, - desc="SwinIR tiles", - ) - output = output[..., : h_old * scale, : w_old * scale] - output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - if output.ndim == 3: - output = np.transpose( - output[[2, 1, 0], :, :], (1, 2, 0) - ) # CHW-RGB to HCW-BGR - output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 - return Image.fromarray(output, "RGB") - - def on_ui_settings(): import gradio as gr diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 9379f512..e4c63f09 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -11,23 +11,40 @@ from modules import images, shared, torch_utils logger = logging.getLogger(__name__) -def upscale_without_tiling(model, img: Image.Image): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - +def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor: + img = np.array(img.convert("RGB")) + img = img[:, :, ::-1] # flip RGB to BGR + img = np.transpose(img, (2, 0, 1)) # HWC to CHW + img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1] + return torch.from_numpy(img) + + +def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: + if tensor.ndim == 4: + # If we're given a tensor with a batch dimension, squeeze it out + # (but only if it's a batch of size 1). + if tensor.shape[0] != 1: + raise ValueError(f"{tensor.shape} does not describe a BCHW tensor") + tensor = tensor.squeeze(0) + assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor" + # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? + arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp + arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale + arr = arr.astype(np.uint8) + arr = arr[:, :, ::-1] # flip BGR to RGB + return Image.fromarray(arr, "RGB") + + +def upscale_pil_patch(model, img: Image.Image) -> Image.Image: + """ + Upscale a given PIL image using the given model. + """ param = torch_utils.get_param(model) - img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): - output = model(img) - - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') + tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension + tensor = tensor.to(device=param.device, dtype=param.dtype) + return torch_bgr_to_pil_image(model(tensor)) def upscale_with_model( @@ -40,7 +57,7 @@ def upscale_with_model( ) -> Image.Image: if tile_size <= 0: logger.debug("Upscaling %s without tiling", img) - output = upscale_without_tiling(model, img) + output = upscale_pil_patch(model, img) logger.debug("=> %s", output) return output @@ -52,7 +69,7 @@ def upscale_with_model( newrow = [] for x, w, tile in row: logger.debug("Tile (%d, %d) %s...", x, y, tile) - output = upscale_without_tiling(model, tile) + output = upscale_pil_patch(model, tile) scale_factor = output.width // tile.width logger.debug("=> %s (scale factor %s)", output, scale_factor) newrow.append([x * scale_factor, w * scale_factor, output]) @@ -71,19 +88,22 @@ def upscale_with_model( def tiled_upscale_2( - img, + img: torch.Tensor, model, *, tile_size: int, tile_overlap: int, scale: int, - device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. + + # Grab the device the model is on, and use it. + device = torch_utils.get_param(model).device + b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -100,7 +120,8 @@ def tiled_upscale_2( h * scale, w * scale, device=device, - ).type_as(img) + dtype=img.dtype, + ) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: @@ -112,11 +133,13 @@ def tiled_upscale_2( if shared.state.interrupted or shared.state.skipped: break + # Only move this patch to the device if it's not already there. in_patch = img[ ..., h_idx : h_idx + tile_size, w_idx : w_idx + tile_size, - ] + ].to(device=device) + out_patch = model(in_patch) result[ @@ -138,3 +161,29 @@ def tiled_upscale_2( output = result.div_(weights) return output + + +def upscale_2( + img: Image.Image, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + desc: str, +): + """ + Convenience wrapper around `tiled_upscale_2` that handles PIL images. + """ + tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + + with torch.no_grad(): + output = tiled_upscale_2( + tensor, + model, + tile_size=tile_size, + tile_overlap=tile_overlap, + scale=scale, + desc=desc, + ) + return torch_bgr_to_pil_image(output) -- cgit v1.2.1