aboutsummaryrefslogtreecommitdiff
path: root/modules/esrgan_model.py
diff options
context:
space:
mode:
authord8ahazard <d8ahazard@gmail.com>2022-09-29 17:46:23 -0500
committerd8ahazard <d8ahazard@gmail.com>2022-09-29 17:46:23 -0500
commit0dce0df1ee63b2f158805c1a1f1a3743cc4a104b (patch)
treedfcec33656d06835e71961b117b63e510cb9bff2 /modules/esrgan_model.py
parent31ad536c331df14dd785bfd2a1f93f91a8f7839e (diff)
Holy $hit.
Yep. Fix gfpgan_model_arch requirement(s). Add Upscaler base class, move from images. Add a lot of methods to Upscaler. Re-work all the child upscalers to be proper classes. Add BSRGAN scaler. Add ldsr_model_arch class, removing the dependency for another repo that just uses regular latent-diffusion stuff. Add one universal method that will always find and load new upscaler models without having to add new "setup_model" calls. Still need to add command line params, but that could probably be automated. Add a "self.scale" property to all Upscalers so the scalers themselves can do "things" in response to the requested upscaling size. Ensure LDSR doesn't get stuck in a longer loop of "upscale/downscale/upscale" as we try to reach the target upscale size. Add typehints for IDE sanity. PEP-8 improvements. Moar.
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r--modules/esrgan_model.py227
1 files changed, 110 insertions, 117 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 5e10c49c..ce841aa4 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,6 +1,4 @@
import os
-import sys
-import traceback
import numpy as np
import torch
@@ -8,93 +6,119 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch
-import modules.images
-from modules import shared
-from modules import shared, modelloader
+from modules import shared, modelloader, images
from modules.devices import has_mps
from modules.paths import models_path
+from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-model_dir = "ESRGAN"
-model_path = os.path.join(models_path, model_dir)
-model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
-model_name = "ESRGAN_x4"
-
-
-def load_model(path: str, name: str):
- global model_path
- global model_url
- global model_dir
- global model_name
- if "http" in path:
- filename = load_file_from_url(url=model_url, model_dir=model_path, file_name="%s.pth" % model_name, progress=True)
- else:
- filename = path
- if not os.path.exists(filename) or filename is None:
- print("Unable to load %s from %s" % (model_dir, filename))
- return None
- print("Loading %s from %s" % (model_dir, filename))
- # this code is adapted from https://github.com/xinntao/ESRGAN
- pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
- crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
-
- if 'conv_first.weight' in pretrained_net:
- crt_model.load_state_dict(pretrained_net)
- return crt_model
- if 'model.0.weight' not in pretrained_net:
- is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
- if is_realesrgan:
- raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
- else:
- raise Exception("The file is not a ESRGAN model.")
+class UpscalerESRGAN(Upscaler):
+ def __init__(self, dirname):
+ self.name = "ESRGAN"
+ self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
+ self.model_name = "ESRGAN 4x"
+ self.scalers = []
+ self.user_path = dirname
+ self.model_path = os.path.join(models_path, self.name)
+ super().__init__()
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
+ scalers = []
+ if len(model_paths) == 0:
+ scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
+ scalers.append(scaler_data)
+ for file in model_paths:
+ print(f"File: {file}")
+ if "http" in file:
+ name = self.model_name
+ else:
+ name = modelloader.friendly_name(file)
+
+ scaler_data = UpscalerData(name, file, self, 4)
+ print(f"ESRGAN: Adding scaler {name}")
+ self.scalers.append(scaler_data)
+
+ def do_upscale(self, img, selected_model):
+ model = self.load_model(selected_model)
+ if model is None:
+ return img
+ model.to(shared.device)
+ img = esrgan_upscale(model, img)
+ return img
- crt_net = crt_model.state_dict()
- load_net_clean = {}
- for k, v in pretrained_net.items():
- if k.startswith('module.'):
- load_net_clean[k[7:]] = v
+ def load_model(self, path: str):
+ if "http" in path:
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="%s.pth" % self.model_name,
+ progress=True)
else:
- load_net_clean[k] = v
- pretrained_net = load_net_clean
-
- tbd = []
- for k, v in crt_net.items():
- tbd.append(k)
-
- # directly copy
- for k, v in crt_net.items():
- if k in pretrained_net and pretrained_net[k].size() == v.size():
- crt_net[k] = pretrained_net[k]
- tbd.remove(k)
-
- crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
- crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
-
- for k in tbd.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[k] = pretrained_net[ori_k]
- tbd.remove(k)
-
- crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
- crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
- crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
- crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
- crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
- crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
- crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
- crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
- crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
- crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
-
- crt_model.load_state_dict(crt_net)
- crt_model.eval()
- return crt_model
+ filename = path
+ if not os.path.exists(filename) or filename is None:
+ print("Unable to load %s from %s" % (self.model_path, filename))
+ return None
+ # this code is adapted from https://github.com/xinntao/ESRGAN
+ pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
+ crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+
+ if 'conv_first.weight' in pretrained_net:
+ crt_model.load_state_dict(pretrained_net)
+ return crt_model
+
+ if 'model.0.weight' not in pretrained_net:
+ is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
+ "params_ema"]
+ if is_realesrgan:
+ raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
+ else:
+ raise Exception("The file is not a ESRGAN model.")
+
+ crt_net = crt_model.state_dict()
+ load_net_clean = {}
+ for k, v in pretrained_net.items():
+ if k.startswith('module.'):
+ load_net_clean[k[7:]] = v
+ else:
+ load_net_clean[k] = v
+ pretrained_net = load_net_clean
+
+ tbd = []
+ for k, v in crt_net.items():
+ tbd.append(k)
+
+ # directly copy
+ for k, v in crt_net.items():
+ if k in pretrained_net and pretrained_net[k].size() == v.size():
+ crt_net[k] = pretrained_net[k]
+ tbd.remove(k)
+
+ crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
+ crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
+
+ for k in tbd.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[k] = pretrained_net[ori_k]
+ tbd.remove(k)
+
+ crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
+ crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
+ crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
+ crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
+ crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
+ crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
+ crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
+ crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
+ crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
+ crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
+
+ crt_model.load_state_dict(crt_net)
+ crt_model.eval()
+ return crt_model
+
def upscale_without_tiling(model, img):
img = np.array(img)
@@ -115,7 +139,7 @@ def esrgan_upscale(model, img):
if opts.ESRGAN_tile == 0:
return upscale_without_tiling(model, img)
- grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
+ grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
newtiles = []
scale_factor = 1
@@ -130,38 +154,7 @@ def esrgan_upscale(model, img):
newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow])
- newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
- output = modules.images.combine_grid(newgrid)
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor,
+ grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
+ output = images.combine_grid(newgrid)
return output
-
-
-class UpscalerESRGAN(modules.images.Upscaler):
- def __init__(self, filename, title):
- self.name = title
- self.filename = filename
-
- def do_upscale(self, img):
- model = load_model(self.filename, self.name)
- if model is None:
- return img
- model.to(shared.device)
- img = esrgan_upscale(model, img)
- return img
-
-
-def setup_model(dirname):
- global model_path
- global model_name
- if not os.path.exists(model_path):
- os.makedirs(model_path)
-
- model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
- if len(model_paths) == 0:
- modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
- for file in model_paths:
- name = modelloader.friendly_name(file)
- try:
- modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
- except Exception:
- print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)