aboutsummaryrefslogtreecommitdiff
path: root/modules/modelloader.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-12-30 18:06:31 +0300
committerGitHub <noreply@github.com>2023-12-30 18:06:31 +0300
commitcd12c0e15c4dc1545cac18ba902ca17488812953 (patch)
tree9c70df74d3e426341d1189b1ceadbd8afffeae91 /modules/modelloader.py
parent05230c02606080527b65ace9eacb6fb835239877 (diff)
parent4ad0c0c0a805da4bac03cff86ea17c25a1291546 (diff)
Merge pull request #14425 from akx/spandrel
Use Spandrel for upscaling and face restoration architectures
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r--modules/modelloader.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 098bcb79..f4182559 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,15 +1,21 @@
from __future__ import annotations
+import logging
import os
import shutil
import importlib
from urllib.parse import urlparse
+import torch
+
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
+logger = logging.getLogger(__name__)
+
+
def load_file_from_url(
url: str,
*,
@@ -177,3 +183,24 @@ def load_upscalers():
# Special case for UpscalerNone keeps it at the beginning of the list.
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
)
+
+
+def load_spandrel_model(
+ path: str,
+ *,
+ device: str | torch.device | None,
+ half: bool = False,
+ dtype: str | None = None,
+ expected_architecture: str | None = None,
+):
+ import spandrel
+ model = spandrel.ModelLoader(device=device).load_from_file(path)
+ if expected_architecture and model.architecture != expected_architecture:
+ raise TypeError(f"Model {path} is not a {expected_architecture} model")
+ if half:
+ model = model.model.half()
+ if dtype:
+ model = model.model.to(dtype=dtype)
+ model.eval()
+ logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
+ return model