aboutsummaryrefslogtreecommitdiff
path: root/modules/modelloader.py
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-12-25 14:43:51 +0200
committerAarni Koskela <akx@iki.fi>2023-12-30 16:24:01 +0200
commitb0f59342346b1c8b405f97c0e0bb01c6ae05c601 (patch)
tree8f77ec512bf8c3352d03898cf9bf1c26df02c1a0 /modules/modelloader.py
parente472383acbb9e07dca311abe5fb16ee2675e410a (diff)
Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR)
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r--modules/modelloader.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 098bcb79..30116932 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import logging
import os
import shutil
import importlib
@@ -10,6 +11,9 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path
+logger = logging.getLogger(__name__)
+
+
def load_file_from_url(
url: str,
*,
@@ -177,3 +181,15 @@ def load_upscalers():
# Special case for UpscalerNone keeps it at the beginning of the list.
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
)
+
+
+def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
+ import spandrel
+ model = spandrel.ModelLoader(device=device).load_from_file(path)
+ if half:
+ model = model.model.half()
+ if dtype:
+ model = model.model.to(dtype=dtype)
+ model.eval()
+ logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
+ return model