aboutsummaryrefslogtreecommitdiff
path: root/modules/modelloader.py
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-12-30 16:37:03 +0200
committerAarni Koskela <akx@iki.fi>2023-12-30 16:37:03 +0200
commit4ad0c0c0a805da4bac03cff86ea17c25a1291546 (patch)
tree9821621545c6989205074d7bd23137eacbbad0e2 /modules/modelloader.py
parentc756133541da478a35a74cda416d114a8973cf8e (diff)
Verify architecture for loaded Spandrel models
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r--modules/modelloader.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 30116932..f4182559 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -6,6 +6,8 @@ import shutil
import importlib
from urllib.parse import urlparse
+import torch
+
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@@ -183,9 +185,18 @@ def load_upscalers():
)
-def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
+def load_spandrel_model(
+ path: str,
+ *,
+ device: str | torch.device | None,
+ half: bool = False,
+ dtype: str | None = None,
+ expected_architecture: str | None = None,
+):
import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path)
+ if expected_architecture and model.architecture != expected_architecture:
+ raise TypeError(f"Model {path} is not a {expected_architecture} model")
if half:
model = model.model.half()
if dtype: