aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-08 23:05:17 +0300
committerGitHub <noreply@github.com>2023-07-08 23:05:17 +0300
commitbcb6ad5fab6fb59fc79c8b6d94699cbabec34297 (patch)
treeb04a3a7f3799dd8218a3396760d848ba0cd1ccbe
parent7dcdf81b841a2df38ed5043408263dcc6a3426a3 (diff)
parent44d66daaad3dae283a85329020d1345d08189e32 (diff)
Merge pull request #11696 from WuSiYu/feat_SWIN_torch_compile
feat: add option SWIN_torch_compile to accelerate SwinIR upscale
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index c2c2a43c..ae0d0e6a 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,4 +1,5 @@
import sys
+import platform
import numpy as np
import torch
@@ -18,6 +19,8 @@ device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler):
def __init__(self, dirname):
+ self._cached_model = None # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
+ self._cached_model_config = None # to clear '_cached_model' when changing model (v1/v2) or settings
self.name = "SwinIR"
self.model_url = SWINIR_MODEL_URL
self.model_name = "SwinIR 4x"
@@ -35,12 +38,24 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers
def do_upscale(self, img, model_file):
- try:
- model = self.load_model(model_file)
- except Exception as e:
- print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
- return img
- model = model.to(device_swinir, dtype=devices.dtype)
+ use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
+ and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
+ current_config = (model_file, opts.SWIN_tile)
+
+ if use_compile and self._cached_model_config == current_config:
+ model = self._cached_model
+ else:
+ self._cached_model = None
+ try:
+ model = self.load_model(model_file)
+ except Exception as e:
+ print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
+ return img
+ model = model.to(device_swinir, dtype=devices.dtype)
+ if use_compile:
+ model = torch.compile(model)
+ self._cached_model = model
+ self._cached_model_config = current_config
img = upscale(img, model)
devices.torch_gc()
return img
@@ -170,6 +185,8 @@ def on_ui_settings():
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
+ if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
+ shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
script_callbacks.on_ui_settings(on_ui_settings)