aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/LDSR/scripts/ldsr_model.py7
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py5
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py8
-rw-r--r--modules/esrgan_model.py4
-rw-r--r--modules/modelloader.py29
-rw-r--r--modules/realesrgan_model.py4
6 files changed, 39 insertions, 18 deletions
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index dbd6d331..bf9b6de2 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -1,7 +1,6 @@
import os
-from basicsr.utils.download_util import load_file_from_url
-
+from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors
@@ -43,9 +42,9 @@ class UpscalerLDSR(Upscaler):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path
else:
- model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
+ model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
- yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
+ yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
try:
return LDSR(model, yaml)
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 85b4505f..2785b551 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -6,12 +6,11 @@ import numpy as np
import torch
from tqdm import tqdm
-from basicsr.utils.download_util import load_file_from_url
-
import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net
+from modules.modelloader import load_file_from_url
from modules.shared import opts
@@ -120,7 +119,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def load_model(self, path: str):
device = devices.get_device_for('scunet')
if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
+ filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 1c7bf325..a5b0e2eb 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -3,7 +3,6 @@ import os
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
@@ -50,8 +49,11 @@ class UpscalerSwinIR(Upscaler):
def load_model(self, path, scale=4):
if "http" in path:
- dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
- filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
+ filename = modelloader.load_file_from_url(
+ url=path,
+ model_dir=self.model_download_path,
+ file_name=f"{self.model_name.replace(' ', '_')}.pth",
+ )
else:
filename = path
if filename is None or not os.path.exists(filename):
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 2fced999..f1a98c07 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -3,7 +3,6 @@ import os
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
@@ -152,11 +151,10 @@ class UpscalerESRGAN(Upscaler):
def load_model(self, path: str):
if "http" in path:
- filename = load_file_from_url(
+ filename = modelloader.load_file_from_url(
url=self.model_url,
model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
- progress=True,
)
else:
filename = path
diff --git a/modules/modelloader.py b/modules/modelloader.py
index be23071a..a69c8a4f 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import shutil
import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path
+def load_file_from_url(
+ url: str,
+ *,
+ model_dir: str,
+ progress: bool = True,
+ file_name: str | None = None,
+) -> str:
+ """Download a file from `url` into `model_dir`, using the file present if possible.
+
+ Returns the path to the downloaded file.
+ """
+ os.makedirs(model_dir, exist_ok=True)
+ if not file_name:
+ parts = urlparse(url)
+ file_name = os.path.basename(parts.path)
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ from torch.hub import download_url_to_file
+ download_url_to_file(url, cached_file, progress=progress)
+ return cached_file
+
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
- from basicsr.utils.download_util import load_file_from_url
- dl = load_file_from_url(model_url, places[0], True, download_name)
- output.append(dl)
+ output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
else:
output.append(model_url)
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 2d27b321..0d9c2e48 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -2,7 +2,6 @@ import os
import numpy as np
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
@@ -10,6 +9,7 @@ from modules.shared import cmd_opts, opts
from modules import modelloader, errors
+
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
self.name = "RealESRGAN"
@@ -71,7 +71,7 @@ class UpscalerRealESRGAN(Upscaler):
return None
if info.local_data_path.startswith("http"):
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
+ info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path)
return info
except Exception: