aboutsummaryrefslogtreecommitdiff
path: root/modules/realesrgan_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/realesrgan_model.py')
-rw-r--r--modules/realesrgan_model.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 5a6666a3..e480887f 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -4,6 +4,7 @@ from collections import namedtuple
import numpy as np
from PIL import Image
+import modules.images
from modules.shared import cmd_opts
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
@@ -12,6 +13,17 @@ realesrgan_models = []
have_realesrgan = False
RealESRGANer_constructor = None
+
+class UpscalerRealESRGAN(modules.images.Upscaler):
+ def __init__(self, upscaling, model_index):
+ self.upscaling = upscaling
+ self.model_index = model_index
+ self.name = realesrgan_models[model_index].name
+
+ def do_upscale(self, img):
+ return upscale_with_realesrgan(img, self.upscaling, self.model_index)
+
+
def setup_realesrgan():
global realesrgan_models
global have_realesrgan
@@ -42,6 +54,9 @@ def setup_realesrgan():
have_realesrgan = True
RealESRGANer_constructor = RealESRGANer
+ for i, model in enumerate(realesrgan_models):
+ modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
+
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)