aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/LDSR/scripts/ldsr_model.py20
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py26
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py57
-rw-r--r--modules/esrgan_model.py23
-rw-r--r--modules/gfpgan_model.py2
-rw-r--r--modules/modelloader.py31
-rw-r--r--modules/realesrgan_model.py33
7 files changed, 101 insertions, 91 deletions
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index dbd6d331..bd78dece 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -1,7 +1,6 @@
import os
-from basicsr.utils.download_util import load_file_from_url
-
+from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors
@@ -43,20 +42,17 @@ class UpscalerLDSR(Upscaler):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path
else:
- model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
+ model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
- yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
+ yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
- try:
- return LDSR(model, yaml)
- except Exception:
- errors.report("Error importing LDSR", exc_info=True)
- return None
+ return LDSR(model, yaml)
def do_upscale(self, img, path):
- ldsr = self.load_model(path)
- if ldsr is None:
- print("NO LDSR!")
+ try:
+ ldsr = self.load_model(path)
+ except Exception:
+ errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img
ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale)
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 85b4505f..ffef26b2 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,4 +1,3 @@
-import os.path
import sys
import PIL.Image
@@ -6,12 +5,11 @@ import numpy as np
import torch
from tqdm import tqdm
-from basicsr.utils.download_util import load_file_from_url
-
import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet as net
+from scunet_model_arch import SCUNet
+from modules.modelloader import load_file_from_url
from modules.shared import opts
@@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers = []
add_model2 = True
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -89,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch.cuda.empty_cache()
- model = self.load_model(selected_file)
- if model is None:
- print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
+ try:
+ model = self.load_model(selected_file)
+ except Exception as e:
+ print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img
device = devices.get_device_for('scunet')
@@ -119,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def load_model(self, path: str):
device = devices.get_device_for('scunet')
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
- if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
- print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
- return None
-
- model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
+ model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for _, v in model.named_parameters():
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 1c7bf325..c6bc53a8 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,17 +1,17 @@
-import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
-from swinir_model_arch import SwinIR as net
-from swinir_model_arch_v2 import Swin2SR as net2
+from swinir_model_arch import SwinIR
+from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData
+SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
device_swinir = devices.get_device_for('swinir')
@@ -19,16 +19,14 @@ device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler):
def __init__(self, dirname):
self.name = "SwinIR"
- self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
- "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
- "-L_x4_GAN.pth "
+ self.model_url = SWINIR_MODEL_URL
self.model_name = "SwinIR 4x"
self.user_path = dirname
super().__init__()
scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files:
- if "http" in model:
+ if model.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(model)
@@ -37,8 +35,10 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers
def do_upscale(self, img, model_file):
- model = self.load_model(model_file)
- if model is None:
+ try:
+ model = self.load_model(model_file)
+ except Exception as e:
+ print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img
model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model)
@@ -49,30 +49,31 @@ class UpscalerSwinIR(Upscaler):
return img
def load_model(self, path, scale=4):
- if "http" in path:
- dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
- filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
+ if path.startswith("http"):
+ filename = modelloader.load_file_from_url(
+ url=path,
+ model_dir=self.model_download_path,
+ file_name=f"{self.model_name.replace(' ', '_')}.pth",
+ )
else:
filename = path
- if filename is None or not os.path.exists(filename):
- return None
if filename.endswith(".v2.pth"):
- model = net2(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6],
- embed_dim=180,
- num_heads=[6, 6, 6, 6, 6, 6],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="1conv",
+ model = Swin2SR(
+ upscale=scale,
+ in_chans=3,
+ img_size=64,
+ window_size=8,
+ img_range=1.0,
+ depths=[6, 6, 6, 6, 6, 6],
+ embed_dim=180,
+ num_heads=[6, 6, 6, 6, 6, 6],
+ mlp_ratio=2,
+ upsampler="nearest+conv",
+ resi_connection="1conv",
)
params = None
else:
- model = net(
+ model = SwinIR(
upscale=scale,
in_chans=3,
img_size=64,
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 2fced999..02a1727d 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,15 +1,13 @@
-import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = modelloader.load_file_from_url(
url=self.model_url,
model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
- progress=True,
)
else:
filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"Unable to load {self.model_path} from {filename}")
- return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 6ecd295c..8e0f13bd 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -25,7 +25,7 @@ def gfpgann():
return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
- if len(models) == 1 and "http" in models[0]:
+ if len(models) == 1 and models[0].startswith("http"):
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 75f01247..098bcb79 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import shutil
import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path
+def load_file_from_url(
+ url: str,
+ *,
+ model_dir: str,
+ progress: bool = True,
+ file_name: str | None = None,
+) -> str:
+ """Download a file from `url` into `model_dir`, using the file present if possible.
+
+ Returns the path to the downloaded file.
+ """
+ os.makedirs(model_dir, exist_ok=True)
+ if not file_name:
+ parts = urlparse(url)
+ file_name = os.path.basename(parts.path)
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ from torch.hub import download_url_to_file
+ download_url_to_file(url, cached_file, progress=progress)
+ return cached_file
+
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
- from basicsr.utils.download_util import load_file_from_url
- dl = load_file_from_url(model_url, places[0], True, download_name)
- output.append(dl)
+ output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
else:
output.append(model_url)
@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
def friendly_name(file: str):
- if "http" in file:
+ if file.startswith("http"):
file = urlparse(file).path
file = os.path.basename(file)
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 2d27b321..0700b853 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -2,7 +2,6 @@ import os
import numpy as np
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if not self.enable:
return img
- info = self.load_model(path)
- if not os.path.exists(info.local_data_path):
- print(f"Unable to load RealESRGAN model: {info.name}")
+ try:
+ info = self.load_model(path)
+ except Exception:
+ errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
upsampler = RealESRGANer(
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return image
def load_model(self, path):
- try:
- info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
- if info is None:
- print(f"Unable to find model info: {path}")
- return None
-
- if info.local_data_path.startswith("http"):
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
-
- return info
- except Exception:
- errors.report("Error making Real-ESRGAN models list", exc_info=True)
- return None
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ if scaler.local_data_path.startswith("http"):
+ scaler.local_data_path = modelloader.load_file_from_url(
+ scaler.data_path,
+ model_dir=self.model_download_path,
+ )
+ if not os.path.exists(scaler.local_data_path):
+ raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
+ return scaler
+ raise ValueError(f"Unable to find model info: {path}")
def load_models(self, _):
return get_realesrgan_models(self)