aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/modelloader.py23
1 files changed, 13 insertions, 10 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 0b89d682..8bcee08c 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,8 +1,9 @@
from __future__ import annotations
+import importlib
import logging
import os
-import importlib
+from typing import TYPE_CHECKING
from urllib.parse import urlparse
import torch
@@ -10,6 +11,8 @@ import torch
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
+if TYPE_CHECKING:
+ import spandrel
logger = logging.getLogger(__name__)
@@ -142,17 +145,17 @@ def load_spandrel_model(
half: bool = False,
dtype: str | None = None,
expected_architecture: str | None = None,
-):
+) -> spandrel.ModelDescriptor:
import spandrel
- model = spandrel.ModelLoader(device=device).load_from_file(path)
- if expected_architecture and model.architecture != expected_architecture:
+ model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
+ if expected_architecture and model_descriptor.architecture != expected_architecture:
logger.warning(
- f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})",
+ f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
)
if half:
- model = model.model.half()
+ model_descriptor.model.half()
if dtype:
- model = model.model.to(dtype=dtype)
- model.eval()
- logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
- return model
+ model_descriptor.model.to(dtype=dtype)
+ model_descriptor.model.eval()
+ logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
+ return model_descriptor