aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/SwinIR
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-12-30 22:53:49 +0200
committerAarni Koskela <akx@iki.fi>2023-12-31 01:13:30 +0200
commit6f86b62a1be7993073ba3a789d522e0b8870605a (patch)
treeae0e5e9df369eb1cefa41ee76eb0e56fe945d192 /extensions-builtin/SwinIR
parentce21840a042b9454a136372ab2971c1f21ec51e0 (diff)
Deduplicate tiled inference code from SwinIR/ScuNET
Diffstat (limited to 'extensions-builtin/SwinIR')
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py57
1 files changed, 5 insertions, 52 deletions
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 95c7ec64..8a555c79 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -4,11 +4,11 @@ import sys
import numpy as np
import torch
from PIL import Image
-from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import opts, state
+from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import tiled_upscale_2
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
@@ -110,14 +110,14 @@ def upscale(
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
- output = inference(
+ output = tiled_upscale_2(
img,
model,
- tile=tile,
+ tile_size=tile,
tile_overlap=tile_overlap,
- window_size=window_size,
scale=scale,
device=device,
+ desc="SwinIR tiles",
)
output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
@@ -129,53 +129,6 @@ def upscale(
return Image.fromarray(output, "RGB")
-def inference(
- img,
- model,
- *,
- tile: int,
- tile_overlap: int,
- window_size: int,
- scale: int,
- device,
-):
- # test the image tile by tile
- b, c, h, w = img.size()
- tile = min(tile, h, w)
- assert tile % window_size == 0, "tile size should be a multiple of window_size"
- sf = scale
-
- stride = tile - tile_overlap
- h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
- w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
- W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-
- with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
- for h_idx in h_idx_list:
- if state.interrupted or state.skipped:
- break
-
- for w_idx in w_idx_list:
- if state.interrupted or state.skipped:
- break
-
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
- pbar.update(1)
- output = E.div_(W)
-
- return output
-
-
def on_ui_settings():
import gradio as gr