aboutsummaryrefslogtreecommitdiff
path: root/modules/modelloader.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r--modules/modelloader.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index a7194137..e100bb24 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -139,23 +139,31 @@ def load_upscalers():
def load_spandrel_model(
- path: str,
+ path: str | os.PathLike,
*,
device: str | torch.device | None,
- half: bool = False,
+ prefer_half: bool = False,
dtype: str | torch.dtype | None = None,
expected_architecture: str | None = None,
) -> spandrel.ModelDescriptor:
import spandrel
- model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
+ model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
if expected_architecture and model_descriptor.architecture != expected_architecture:
logger.warning(
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
)
- if half:
- model_descriptor.model.half()
+ half = False
+ if prefer_half:
+ if model_descriptor.supports_half:
+ model_descriptor.model.half()
+ half = True
+ else:
+ logger.info("Model %s does not support half precision, ignoring --half", path)
if dtype:
model_descriptor.model.to(dtype=dtype)
model_descriptor.model.eval()
- logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
+ logger.debug(
+ "Loaded %s from %s (device=%s, half=%s, dtype=%s)",
+ model_descriptor, path, device, half, dtype,
+ )
return model_descriptor