aboutsummaryrefslogtreecommitdiff
path: root/modules/ldsr_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/ldsr_model.py')
-rw-r--r--modules/ldsr_model.py92
1 files changed, 35 insertions, 57 deletions
diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py
index 95e84659..969d1a0d 100644
--- a/modules/ldsr_model.py
+++ b/modules/ldsr_model.py
@@ -1,67 +1,45 @@
import os
import sys
import traceback
-from collections import namedtuple
from basicsr.utils.download_util import load_file_from_url
-import modules.images
+from modules.upscaler import Upscaler, UpscalerData
+from modules.ldsr_model_arch import LDSR
from modules import shared
-from modules.paths import script_path
+from modules.paths import models_path
-LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
-ldsr_models = []
-have_ldsr = False
-LDSR_obj = None
-
-
-class UpscalerLDSR(modules.images.Upscaler):
- def __init__(self, steps):
- self.steps = steps
+class UpscalerLDSR(Upscaler):
+ def __init__(self, user_path):
self.name = "LDSR"
-
- def do_upscale(self, img):
- return upscale_with_ldsr(img)
-
-
-def add_lsdr():
- modules.shared.sd_upscalers.append(UpscalerLDSR(100))
-
-
-def setup_ldsr():
- path = modules.paths.paths.get("LDSR", None)
- if path is None:
- return
- global have_ldsr
- global LDSR_obj
- try:
- from LDSR import LDSR
- model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
- yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
- repo_path = 'latent-diffusion/experiments/pretrained_models/'
- model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path),
- progress=True, file_name="model.chkpt")
- yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path),
- progress=True, file_name="project.yaml")
- have_ldsr = True
- LDSR_obj = LDSR(model_path, yaml_path)
-
-
- except Exception:
- print("Error importing LDSR:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- have_ldsr = False
-
-
-def upscale_with_ldsr(image):
- setup_ldsr()
- if not have_ldsr or LDSR_obj is None:
- return image
-
- ddim_steps = shared.opts.ldsr_steps
- pre_scale = shared.opts.ldsr_pre_down
- post_scale = shared.opts.ldsr_post_down
-
- image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale)
- return image
+ self.model_path = os.path.join(models_path, self.name)
+ self.user_path = user_path
+ self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
+ self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
+ super().__init__()
+ scaler_data = UpscalerData("LDSR", None, self)
+ self.scalers = [scaler_data]
+
+ def load_model(self, path: str):
+ model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="model.pth", progress=True)
+ yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="project.yaml", progress=True)
+
+ try:
+ return LDSR(model, yaml)
+
+ except Exception:
+ print("Error importing LDSR:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
+
+ def do_upscale(self, img, path):
+ ldsr = self.load_model(path)
+ if ldsr is None:
+ print("NO LDSR!")
+ return img
+ ddim_steps = shared.opts.ldsr_steps
+ pre_scale = shared.opts.ldsr_pre_down
+ return ldsr.super_resolution(img, ddim_steps, self.scale)