aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-12-30 16:37:03 +0200
committerAarni Koskela <akx@iki.fi>2023-12-30 16:37:03 +0200
commit4ad0c0c0a805da4bac03cff86ea17c25a1291546 (patch)
tree9821621545c6989205074d7bd23137eacbbad0e2 /modules
parentc756133541da478a35a74cda416d114a8973cf8e (diff)
Verify architecture for loaded Spandrel models
Diffstat (limited to 'modules')
-rw-r--r--modules/codeformer_model.py1
-rw-r--r--modules/esrgan_model.py1
-rw-r--r--modules/gfpgan_model.py1
-rw-r--r--modules/hat_model.py1
-rw-r--r--modules/modelloader.py13
-rw-r--r--modules/realesrgan_model.py7
6 files changed, 20 insertions, 4 deletions
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index ceda4bab..44b84618 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
return modelloader.load_spandrel_model(
model_path,
device=devices.device_codeformer,
+ expected_architecture='CodeFormer',
).model
raise ValueError("No codeformer model found")
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index a7c7c9e3..70041ab0 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler):
return modelloader.load_spandrel_model(
filename,
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
+ expected_architecture='ESRGAN',
)
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index a356b56f..48f8ad5e 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
net = modelloader.load_spandrel_model(
model_path,
device=self.get_device(),
+ expected_architecture='GFPGAN',
).model
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
return net
diff --git a/modules/hat_model.py b/modules/hat_model.py
index 553e1941..7f2abb41 100644
--- a/modules/hat_model.py
+++ b/modules/hat_model.py
@@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler):
return modelloader.load_spandrel_model(
path,
device=devices.device_esrgan, # TODO: should probably be device_hat
+ expected_architecture='HAT',
)
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 30116932..f4182559 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -6,6 +6,8 @@ import shutil
import importlib
from urllib.parse import urlparse
+import torch
+
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@@ -183,9 +185,18 @@ def load_upscalers():
)
-def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
+def load_spandrel_model(
+ path: str,
+ *,
+ device: str | torch.device | None,
+ half: bool = False,
+ dtype: str | None = None,
+ expected_architecture: str | None = None,
+):
import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path)
+ if expected_architecture and model.architecture != expected_architecture:
+ raise TypeError(f"Model {path} is not a {expected_architecture} model")
if half:
model = model.model.half()
if dtype:
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 332d8f4b..2a2be5ad 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,9 +1,9 @@
import os
-from modules.upscaler_utils import upscale_with_model
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
class UpscalerRealESRGAN(Upscaler):
@@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path,
device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+ expected_architecture="RealESRGAN",
)
return upscale_with_model(
mod,