aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/SwinIR/scripts/swinir_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/SwinIR/scripts/swinir_model.py')
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py83
1 files changed, 49 insertions, 34 deletions
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 1c7bf325..ae0d0e6a 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,34 +1,35 @@
-import os
+import sys
+import platform
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
-from swinir_model_arch import SwinIR as net
-from swinir_model_arch_v2 import Swin2SR as net2
+from swinir_model_arch import SwinIR
+from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData
+SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler):
def __init__(self, dirname):
+ self._cached_model = None # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
+ self._cached_model_config = None # to clear '_cached_model' when changing model (v1/v2) or settings
self.name = "SwinIR"
- self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
- "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
- "-L_x4_GAN.pth "
+ self.model_url = SWINIR_MODEL_URL
self.model_name = "SwinIR 4x"
self.user_path = dirname
super().__init__()
scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files:
- if "http" in model:
+ if model.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(model)
@@ -37,42 +38,54 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers
def do_upscale(self, img, model_file):
- model = self.load_model(model_file)
- if model is None:
- return img
- model = model.to(device_swinir, dtype=devices.dtype)
+ use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
+ and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
+ current_config = (model_file, opts.SWIN_tile)
+
+ if use_compile and self._cached_model_config == current_config:
+ model = self._cached_model
+ else:
+ self._cached_model = None
+ try:
+ model = self.load_model(model_file)
+ except Exception as e:
+ print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
+ return img
+ model = model.to(device_swinir, dtype=devices.dtype)
+ if use_compile:
+ model = torch.compile(model)
+ self._cached_model = model
+ self._cached_model_config = current_config
img = upscale(img, model)
- try:
- torch.cuda.empty_cache()
- except Exception:
- pass
+ devices.torch_gc()
return img
def load_model(self, path, scale=4):
- if "http" in path:
- dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
- filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
+ if path.startswith("http"):
+ filename = modelloader.load_file_from_url(
+ url=path,
+ model_dir=self.model_download_path,
+ file_name=f"{self.model_name.replace(' ', '_')}.pth",
+ )
else:
filename = path
- if filename is None or not os.path.exists(filename):
- return None
if filename.endswith(".v2.pth"):
- model = net2(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6],
- embed_dim=180,
- num_heads=[6, 6, 6, 6, 6, 6],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="1conv",
+ model = Swin2SR(
+ upscale=scale,
+ in_chans=3,
+ img_size=64,
+ window_size=8,
+ img_range=1.0,
+ depths=[6, 6, 6, 6, 6, 6],
+ embed_dim=180,
+ num_heads=[6, 6, 6, 6, 6, 6],
+ mlp_ratio=2,
+ upsampler="nearest+conv",
+ resi_connection="1conv",
)
params = None
else:
- model = net(
+ model = SwinIR(
upscale=scale,
in_chans=3,
img_size=64,
@@ -172,6 +185,8 @@ def on_ui_settings():
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
+ if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
+ shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
script_callbacks.on_ui_settings(on_ui_settings)