aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/SwinIR
diff options
context:
space:
mode:
authorcatboxanon <122327233+catboxanon@users.noreply.github.com>2023-01-23 21:50:59 -0500
committerGitHub <noreply@github.com>2023-01-23 21:50:59 -0500
commitf99352582084890b9167c1bf8699865bea0cef5f (patch)
tree786d85d6e725a1844996fcd01cc9a7ca6f60fc88 /extensions-builtin/SwinIR
parent5c1cb9263f980641007088a37360fcab01761d37 (diff)
Make SwinIR interruptible
Diffstat (limited to 'extensions-builtin/SwinIR')
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 9a74b253..3479760a 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts
+from modules.shared import cmd_opts, opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
+ if state.interrupted:
+ break
+
for w_idx in w_idx_list:
+ if state.interrupted:
+ break
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)