aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py2
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py1
-rw-r--r--modules/codeformer_model.py1
-rw-r--r--modules/esrgan_model.py1
-rw-r--r--modules/gfpgan_model.py1
-rw-r--r--modules/hat_model.py1
-rw-r--r--modules/modelloader.py13
-rw-r--r--modules/realesrgan_model.py7
8 files changed, 22 insertions, 5 deletions
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 18cf8e1a..5f3dd08b 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
- return modelloader.load_spandrel_model(filename, device=device)
+ return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
def on_ui_settings():
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 85c18b9e..aae159af 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler):
filename,
device=self._get_device(),
dtype=devices.dtype,
+ expected_architecture="SwinIR",
)
if getattr(opts, 'SWIN_torch_compile', False):
try:
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index ceda4bab..44b84618 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
return modelloader.load_spandrel_model(
model_path,
device=devices.device_codeformer,
+ expected_architecture='CodeFormer',
).model
raise ValueError("No codeformer model found")
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index a7c7c9e3..70041ab0 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler):
return modelloader.load_spandrel_model(
filename,
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
+ expected_architecture='ESRGAN',
)
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index a356b56f..48f8ad5e 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
net = modelloader.load_spandrel_model(
model_path,
device=self.get_device(),
+ expected_architecture='GFPGAN',
).model
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
return net
diff --git a/modules/hat_model.py b/modules/hat_model.py
index 553e1941..7f2abb41 100644
--- a/modules/hat_model.py
+++ b/modules/hat_model.py
@@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler):
return modelloader.load_spandrel_model(
path,
device=devices.device_esrgan, # TODO: should probably be device_hat
+ expected_architecture='HAT',
)
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 30116932..f4182559 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -6,6 +6,8 @@ import shutil
import importlib
from urllib.parse import urlparse
+import torch
+
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@@ -183,9 +185,18 @@ def load_upscalers():
)
-def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
+def load_spandrel_model(
+ path: str,
+ *,
+ device: str | torch.device | None,
+ half: bool = False,
+ dtype: str | None = None,
+ expected_architecture: str | None = None,
+):
import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path)
+ if expected_architecture and model.architecture != expected_architecture:
+ raise TypeError(f"Model {path} is not a {expected_architecture} model")
if half:
model = model.model.half()
if dtype:
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 332d8f4b..2a2be5ad 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,9 +1,9 @@
import os
-from modules.upscaler_utils import upscale_with_model
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
class UpscalerRealESRGAN(Upscaler):
@@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path,
device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+ expected_architecture="RealESRGAN",
)
return upscale_with_model(
mod,