aboutsummaryrefslogtreecommitdiff
path: root/modules/realesrgan_model.py
diff options
context:
space:
mode:
authord8ahazard <d8ahazard@gmail.com>2022-09-26 09:29:50 -0500
committerd8ahazard <d8ahazard@gmail.com>2022-09-26 09:29:50 -0500
commit740070ea9cdb254209f66417418f2a4af8b099d6 (patch)
tree52896a6159b706024af9520c855c10091162372c /modules/realesrgan_model.py
parentbfb7f15d46048f27338eeac3a591a5943d03c5f1 (diff)
Re-implement universal model loading
Diffstat (limited to 'modules/realesrgan_model.py')
-rw-r--r--modules/realesrgan_model.py23
1 files changed, 19 insertions, 4 deletions
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index c32d6c4c..458bf678 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,14 +1,20 @@
+import os
import sys
import traceback
from collections import namedtuple
import numpy as np
from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
import modules.images
+from modules.paths import models_path
from modules.shared import cmd_opts, opts
+model_dir = "RealESRGAN"
+model_path = os.path.join(models_path, model_dir)
+cmd_dir = None
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
realesrgan_models = []
have_realesrgan = False
@@ -17,7 +23,6 @@ have_realesrgan = False
def get_realesrgan_models():
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
models = [
RealesrganModelInfo(
@@ -59,7 +64,7 @@ def get_realesrgan_models():
]
return models
except Exception as e:
- print("Error makeing Real-ESRGAN midels list:", file=sys.stderr)
+ print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@@ -73,10 +78,15 @@ class UpscalerRealESRGAN(modules.images.Upscaler):
return upscale_with_realesrgan(img, self.upscaling, self.model_index)
-def setup_realesrgan():
+def setup_model(dirname):
+ global model_path
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+
global realesrgan_models
global have_realesrgan
-
+ if model_path != dirname:
+ model_path = dirname
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
@@ -104,6 +114,11 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
info = realesrgan_models[RealESRGAN_model_index]
model = info.model()
+ model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True)
+ if not os.path.exists(model_file):
+ print("Unable to load RealESRGAN model: %s" % info.name)
+ return image
+
upsampler = RealESRGANer(
scale=info.netscale,
model_path=info.location,