From cf14a6a7aaf8ccb40552990785d5c9e400d93610 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 16:11:18 +0200 Subject: Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right --- extensions-builtin/SwinIR/scripts/swinir_model.py | 62 +++-------------------- 1 file changed, 8 insertions(+), 54 deletions(-) (limited to 'extensions-builtin/SwinIR') diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 8a555c79..bc427fea 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,14 +1,10 @@ import logging import sys -import numpy as np -import torch from PIL import Image -from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts +from modules import devices, modelloader, script_callbacks, shared, upscaler_utils from modules.upscaler import Upscaler, UpscalerData -from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -36,9 +32,7 @@ class UpscalerSwinIR(Upscaler): self.scalers = scalers def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: - current_config = (model_file, opts.SWIN_tile) - - device = self._get_device() + current_config = (model_file, shared.opts.SWIN_tile) if self._cached_model_config == current_config: model = self._cached_model @@ -51,12 +45,13 @@ class UpscalerSwinIR(Upscaler): self._cached_model = model self._cached_model_config = current_config - img = upscale( + img = upscaler_utils.upscale_2( img, model, - tile=opts.SWIN_tile, - tile_overlap=opts.SWIN_tile_overlap, - device=device, + tile_size=shared.opts.SWIN_tile, + tile_overlap=shared.opts.SWIN_tile_overlap, + scale=4, # TODO: This was hard-coded before too... + desc="SwinIR", ) devices.torch_gc() return img @@ -77,7 +72,7 @@ class UpscalerSwinIR(Upscaler): dtype=devices.dtype, expected_architecture="SwinIR", ) - if getattr(opts, 'SWIN_torch_compile', False): + if getattr(shared.opts, 'SWIN_torch_compile', False): try: model_descriptor.model.compile() except Exception: @@ -88,47 +83,6 @@ class UpscalerSwinIR(Upscaler): return devices.get_device_for('swinir') -def upscale( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size=8, - scale=4, - device, -): - - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device, dtype=devices.dtype) - with torch.no_grad(), devices.autocast(): - _, _, h_old, w_old = img.size() - h_pad = (h_old // window_size + 1) * window_size - h_old - w_pad = (w_old // window_size + 1) * window_size - w_old - img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] - img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = tiled_upscale_2( - img, - model, - tile_size=tile, - tile_overlap=tile_overlap, - scale=scale, - device=device, - desc="SwinIR tiles", - ) - output = output[..., : h_old * scale, : w_old * scale] - output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - if output.ndim == 3: - output = np.transpose( - output[[2, 1, 0], :, :], (1, 2, 0) - ) # CHW-RGB to HCW-BGR - output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 - return Image.fromarray(output, "RGB") - - def on_ui_settings(): import gradio as gr -- cgit v1.2.1