aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/SwinIR/scripts/swinir_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/SwinIR/scripts/swinir_model.py')
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py127
1 files changed, 61 insertions, 66 deletions
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index ae0d0e6a..aae159af 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,5 +1,5 @@
+import logging
import sys
-import platform
import numpy as np
import torch
@@ -8,13 +8,11 @@ from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
-from swinir_model_arch import SwinIR
-from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
-device_swinir = devices.get_device_for('swinir')
+logger = logging.getLogger(__name__)
class UpscalerSwinIR(Upscaler):
@@ -37,26 +35,29 @@ class UpscalerSwinIR(Upscaler):
scalers.append(model_data)
self.scalers = scalers
- def do_upscale(self, img, model_file):
- use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
- and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
+ def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
current_config = (model_file, opts.SWIN_tile)
- if use_compile and self._cached_model_config == current_config:
+ device = self._get_device()
+
+ if self._cached_model_config == current_config:
model = self._cached_model
else:
- self._cached_model = None
try:
model = self.load_model(model_file)
except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img
- model = model.to(device_swinir, dtype=devices.dtype)
- if use_compile:
- model = torch.compile(model)
- self._cached_model = model
- self._cached_model_config = current_config
- img = upscale(img, model)
+ self._cached_model = model
+ self._cached_model_config = current_config
+
+ img = upscale(
+ img,
+ model,
+ tile=opts.SWIN_tile,
+ tile_overlap=opts.SWIN_tile_overlap,
+ device=device,
+ )
devices.torch_gc()
return img
@@ -69,69 +70,55 @@ class UpscalerSwinIR(Upscaler):
)
else:
filename = path
- if filename.endswith(".v2.pth"):
- model = Swin2SR(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6],
- embed_dim=180,
- num_heads=[6, 6, 6, 6, 6, 6],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="1conv",
- )
- params = None
- else:
- model = SwinIR(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="3conv",
- )
- params = "params_ema"
- pretrained_model = torch.load(filename)
- if params is not None:
- model.load_state_dict(pretrained_model[params], strict=True)
- else:
- model.load_state_dict(pretrained_model, strict=True)
+ model = modelloader.load_spandrel_model(
+ filename,
+ device=self._get_device(),
+ dtype=devices.dtype,
+ expected_architecture="SwinIR",
+ )
+ if getattr(opts, 'SWIN_torch_compile', False):
+ try:
+ model = torch.compile(model)
+ except Exception:
+ logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
return model
+ def _get_device(self):
+ return devices.get_device_for('swinir')
+
def upscale(
- img,
- model,
- tile=None,
- tile_overlap=None,
- window_size=8,
- scale=4,
+ img,
+ model,
+ *,
+ tile: int,
+ tile_overlap: int,
+ window_size=8,
+ scale=4,
+ device,
):
- tile = tile or opts.SWIN_tile
- tile_overlap = tile_overlap or opts.SWIN_tile_overlap
-
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
+ img = img.unsqueeze(0).to(device, dtype=devices.dtype)
with torch.no_grad(), devices.autocast():
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
- output = inference(img, model, tile, tile_overlap, window_size, scale)
+ output = inference(
+ img,
+ model,
+ tile=tile,
+ tile_overlap=tile_overlap,
+ window_size=window_size,
+ scale=scale,
+ device=device,
+ )
output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
@@ -142,7 +129,16 @@ def upscale(
return Image.fromarray(output, "RGB")
-def inference(img, model, tile, tile_overlap, window_size, scale):
+def inference(
+ img,
+ model,
+ *,
+ tile: int,
+ tile_overlap: int,
+ window_size: int,
+ scale: int,
+ device,
+):
# test the image tile by tile
b, c, h, w = img.size()
tile = min(tile, h, w)
@@ -152,8 +148,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
- W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
+ E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
+ W = torch.zeros_like(E, dtype=devices.dtype, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
@@ -185,8 +181,7 @@ def on_ui_settings():
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
- if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
- shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
+ shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
script_callbacks.on_ui_settings(on_ui_settings)