From 121ed7d36febe94995774973b5edc1ba2ba84aad Mon Sep 17 00:00:00 2001 From: Alexandre Simard Date: Sat, 1 Oct 2022 14:04:20 -0400 Subject: Add progress bar for SwinIR in cmd I do not know how to add them to the UI... --- modules/swinir_model.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) (limited to 'modules/swinir_model.py') diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 41fda5a7..9bd454c6 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -5,6 +5,7 @@ import numpy as np import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url +from tqdm import tqdm from modules import modelloader from modules.paths import models_path @@ -122,18 +123,20 @@ def inference(img, model, tile, tile_overlap, window_size, scale): E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) W = torch.zeros_like(E, dtype=torch.half, device=device) - for h_idx in h_idx_list: - for w_idx in w_idx_list: - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) + with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: + for h_idx in h_idx_list: + for w_idx in w_idx_list: + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] + out_patch = model(in_patch) + out_patch_mask = torch.ones_like(out_patch) + + E[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch) + W[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch_mask) + pbar.update(1) output = E.div_(W) return output -- cgit v1.2.1 From 0609ce06c0778536cb368ac3867292f87c6d9fc7 Mon Sep 17 00:00:00 2001 From: Milly Date: Fri, 7 Oct 2022 03:36:08 +0900 Subject: Removed duplicate definition model_path --- modules/swinir_model.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'modules/swinir_model.py') diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 9bd454c6..fbd11f84 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -8,7 +8,6 @@ from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader -from modules.paths import models_path from modules.shared import cmd_opts, opts, device from modules.swinir_model_arch import SwinIR as net from modules.upscaler import Upscaler, UpscalerData @@ -25,7 +24,6 @@ class UpscalerSwinIR(Upscaler): "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ "-L_x4_GAN.pth " self.model_name = "SwinIR 4x" - self.model_path = os.path.join(models_path, self.name) self.user_path = dirname super().__init__() scalers = [] -- cgit v1.2.1 From ed769977f0d0f201d8e361d365102f18775fc62c Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sun, 9 Oct 2022 18:56:59 +0300 Subject: add swinir v2 support --- modules/swinir_model.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) (limited to 'modules/swinir_model.py') diff --git a/modules/swinir_model.py b/modules/swinir_model.py index fbd11f84..baa02e3d 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -10,6 +10,7 @@ from tqdm import tqdm from modules import modelloader from modules.shared import cmd_opts, opts, device from modules.swinir_model_arch import SwinIR as net +from modules.swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData precision_scope = ( @@ -57,22 +58,42 @@ class UpscalerSwinIR(Upscaler): filename = path if filename is None or not os.path.exists(filename): return None - model = net( + if filename.endswith(".v2.pth"): + model = net2( upscale=scale, in_chans=3, img_size=64, window_size=8, img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2, upsampler="nearest+conv", - resi_connection="3conv", - ) + resi_connection="1conv", + ) + params = None + else: + model = net( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=240, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="3conv", + ) + params = "params_ema" pretrained_model = torch.load(filename) - model.load_state_dict(pretrained_model["params_ema"], strict=True) + if params is not None: + model.load_state_dict(pretrained_model[params], strict=True) + else: + model.load_state_dict(pretrained_model, strict=True) if not cmd_opts.no_half: model = model.half() return model -- cgit v1.2.1