diff options
Diffstat (limited to 'modules/realesrgan_model.py')
-rw-r--r-- | modules/realesrgan_model.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 5a6666a3..e480887f 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -4,6 +4,7 @@ from collections import namedtuple import numpy as np
from PIL import Image
+import modules.images
from modules.shared import cmd_opts
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
@@ -12,6 +13,17 @@ realesrgan_models = [] have_realesrgan = False
RealESRGANer_constructor = None
+
+class UpscalerRealESRGAN(modules.images.Upscaler):
+ def __init__(self, upscaling, model_index):
+ self.upscaling = upscaling
+ self.model_index = model_index
+ self.name = realesrgan_models[model_index].name
+
+ def do_upscale(self, img):
+ return upscale_with_realesrgan(img, self.upscaling, self.model_index)
+
+
def setup_realesrgan():
global realesrgan_models
global have_realesrgan
@@ -42,6 +54,9 @@ def setup_realesrgan(): have_realesrgan = True
RealESRGANer_constructor = RealESRGANer
+ for i, model in enumerate(realesrgan_models):
+ modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
+
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
|