From 740070ea9cdb254209f66417418f2a4af8b099d6 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Mon, 26 Sep 2022 09:29:50 -0500 Subject: Re-implement universal model loading --- modules/realesrgan_model.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) (limited to 'modules/realesrgan_model.py') diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index c32d6c4c..458bf678 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,14 +1,20 @@ +import os import sys import traceback from collections import namedtuple import numpy as np from PIL import Image +from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer import modules.images +from modules.paths import models_path from modules.shared import cmd_opts, opts +model_dir = "RealESRGAN" +model_path = os.path.join(models_path, model_dir) +cmd_dir = None RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"]) realesrgan_models = [] have_realesrgan = False @@ -17,7 +23,6 @@ have_realesrgan = False def get_realesrgan_models(): try: from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan import RealESRGANer from realesrgan.archs.srvgg_arch import SRVGGNetCompact models = [ RealesrganModelInfo( @@ -59,7 +64,7 @@ def get_realesrgan_models(): ] return models except Exception as e: - print("Error makeing Real-ESRGAN midels list:", file=sys.stderr) + print("Error making Real-ESRGAN models list:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) @@ -73,10 +78,15 @@ class UpscalerRealESRGAN(modules.images.Upscaler): return upscale_with_realesrgan(img, self.upscaling, self.model_index) -def setup_realesrgan(): +def setup_model(dirname): + global model_path + if not os.path.exists(model_path): + os.makedirs(model_path) + global realesrgan_models global have_realesrgan - + if model_path != dirname: + model_path = dirname try: from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer @@ -104,6 +114,11 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index) info = realesrgan_models[RealESRGAN_model_index] model = info.model() + model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True) + if not os.path.exists(model_file): + print("Unable to load RealESRGAN model: %s" % info.name) + return image + upsampler = RealESRGANer( scale=info.netscale, model_path=info.location, -- cgit v1.2.1