aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/swinir.py (renamed from swinir.py)28
1 files changed, 23 insertions, 5 deletions
diff --git a/swinir.py b/modules/swinir.py
index cb2bbe3d..6c7f0a2d 100644
--- a/swinir.py
+++ b/modules/swinir.py
@@ -12,7 +12,13 @@ import modules.images
from modules.shared import cmd_opts, opts, device
from modules.swinir_arch import SwinIR as net
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
-def load_model(task = "realsr", large_model = True, model_path=next(os.listdir(cmd_opts.esrgan_models_path))):
+def load_model(task = "realsr", large_model = True, model_path="C:/sd/ESRGANn/4x-large.pth", scale=4):
+
+ try:
+ modules.shared.sd_upscalers.append(UpscalerSwin("McSwinnySwin"))
+ except Exception:
+ print(f"Error loading ESRGAN model", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
if not large_model:
# use 'nearest+conv' to avoid block artifacts
model = net(upscale=scale, in_chans=3, img_size=64, window_size=8,
@@ -26,12 +32,16 @@ def load_model(task = "realsr", large_model = True, model_path=next(os.listdir(c
mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
pretrained_model = torch.load(model_path)
- model.load_state_dict(pretrained_model, strict=True)
+ model.load_state_dict(pretrained_model["params_ema"], strict=True)
return model.half().to(device)
def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, window_size = 8, scale = 4):
- img = cv2.imread(img, cv2.IMREAD_COLOR).astype(np.float16) / 255.
+ img = np.array(img)
+ img = img[:, :, ::-1]
+ img = np.moveaxis(img, 2, 0) / 255
+ img = torch.from_numpy(img).float()
+ img = img.unsqueeze(0).to(device)
model = load_model()
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
@@ -45,7 +55,7 @@ def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, w
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
- return output
+ return Image.fromarray(output, 'RGB')
def inference(img, model, tile, tile_overlap, window_size, scale):
@@ -71,4 +81,12 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask)
output = E.div_(W)
- return output \ No newline at end of file
+ return output
+
+class UpscalerSwin(modules.images.Upscaler):
+ def __init__(self, title):
+ self.name = title
+
+ def do_upscale(self, img):
+ img = upscale(img)
+ return img