aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-12-30 16:37:03 +0200
committerAarni Koskela <akx@iki.fi>2023-12-30 16:37:03 +0200
commit4ad0c0c0a805da4bac03cff86ea17c25a1291546 (patch)
tree9821621545c6989205074d7bd23137eacbbad0e2 /extensions-builtin
parentc756133541da478a35a74cda416d114a8973cf8e (diff)
Verify architecture for loaded Spandrel models
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py2
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py1
2 files changed, 2 insertions, 1 deletions
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 18cf8e1a..5f3dd08b 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
- return modelloader.load_spandrel_model(filename, device=device)
+ return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
def on_ui_settings():
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 85c18b9e..aae159af 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler):
filename,
device=self._get_device(),
dtype=devices.dtype,
+ expected_architecture="SwinIR",
)
if getattr(opts, 'SWIN_torch_compile', False):
try: