aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/LDSR/scripts/ldsr_model.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-01-04 17:40:19 +0300
committerGitHub <noreply@github.com>2023-01-04 17:40:19 +0300
commitda5c1e8a732c173ed8ccda9fa32f9a194ff91ab6 (patch)
treea2eec9c47e820e7ab351337f73c99d874b4b904f /extensions-builtin/LDSR/scripts/ldsr_model.py
parentcffc240a7327ae60671ff533469fc4ed4bf605de (diff)
parent47df0849019abac6722c49512f4dd2285bff5b7d (diff)
Merge branch 'master' into inpaint_textual_inversion
Diffstat (limited to 'extensions-builtin/LDSR/scripts/ldsr_model.py')
-rw-r--r--extensions-builtin/LDSR/scripts/ldsr_model.py69
1 files changed, 69 insertions, 0 deletions
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
new file mode 100644
index 00000000..b8cff29b
--- /dev/null
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -0,0 +1,69 @@
+import os
+import sys
+import traceback
+
+from basicsr.utils.download_util import load_file_from_url
+
+from modules.upscaler import Upscaler, UpscalerData
+from ldsr_model_arch import LDSR
+from modules import shared, script_callbacks
+import sd_hijack_autoencoder, sd_hijack_ddpm_v1
+
+
+class UpscalerLDSR(Upscaler):
+ def __init__(self, user_path):
+ self.name = "LDSR"
+ self.user_path = user_path
+ self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
+ self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
+ super().__init__()
+ scaler_data = UpscalerData("LDSR", None, self)
+ self.scalers = [scaler_data]
+
+ def load_model(self, path: str):
+ # Remove incorrect project.yaml file if too big
+ yaml_path = os.path.join(self.model_path, "project.yaml")
+ old_model_path = os.path.join(self.model_path, "model.pth")
+ new_model_path = os.path.join(self.model_path, "model.ckpt")
+ safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
+ if os.path.exists(yaml_path):
+ statinfo = os.stat(yaml_path)
+ if statinfo.st_size >= 10485760:
+ print("Removing invalid LDSR YAML file.")
+ os.remove(yaml_path)
+ if os.path.exists(old_model_path):
+ print("Renaming model from model.pth to model.ckpt")
+ os.rename(old_model_path, new_model_path)
+ if os.path.exists(safetensors_model_path):
+ model = safetensors_model_path
+ else:
+ model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="model.ckpt", progress=True)
+ yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
+ file_name="project.yaml", progress=True)
+
+ try:
+ return LDSR(model, yaml)
+
+ except Exception:
+ print("Error importing LDSR:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
+
+ def do_upscale(self, img, path):
+ ldsr = self.load_model(path)
+ if ldsr is None:
+ print("NO LDSR!")
+ return img
+ ddim_steps = shared.opts.ldsr_steps
+ return ldsr.super_resolution(img, ddim_steps, self.scale)
+
+
+def on_ui_settings():
+ import gradio as gr
+
+ shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
+ shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)