aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-12-31 19:52:32 +0200
committerAarni Koskela <akx@iki.fi>2024-01-02 10:44:38 +0200
commit2cacbc124c49f45da5b66b79d9b0a3ab943472eb (patch)
tree27ae05c022710aaa60e7791ca6bde3e7f60b7511
parent51f1cca8524d3ffa8930b32a571d239c60d65725 (diff)
load_spandrel_model: make `half` `prefer_half`
As discussed with the Spandrel folks, it's good to heed Spandrel's "supports half precision" flag to avoid e.g. black blotches and what-not.
-rw-r--r--modules/modelloader.py20
-rw-r--r--modules/realesrgan_model.py2
2 files changed, 15 insertions, 7 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index a7194137..e100bb24 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -139,23 +139,31 @@ def load_upscalers():
def load_spandrel_model(
- path: str,
+ path: str | os.PathLike,
*,
device: str | torch.device | None,
- half: bool = False,
+ prefer_half: bool = False,
dtype: str | torch.dtype | None = None,
expected_architecture: str | None = None,
) -> spandrel.ModelDescriptor:
import spandrel
- model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
+ model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
if expected_architecture and model_descriptor.architecture != expected_architecture:
logger.warning(
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
)
- if half:
- model_descriptor.model.half()
+ half = False
+ if prefer_half:
+ if model_descriptor.supports_half:
+ model_descriptor.model.half()
+ half = True
+ else:
+ logger.info("Model %s does not support half precision, ignoring --half", path)
if dtype:
model_descriptor.model.to(dtype=dtype)
model_descriptor.model.eval()
- logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
+ logger.debug(
+ "Loaded %s from %s (device=%s, half=%s, dtype=%s)",
+ model_descriptor, path, device, half, dtype,
+ )
return model_descriptor
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 4d35b695..ff9d8ac0 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -39,7 +39,7 @@ class UpscalerRealESRGAN(Upscaler):
model_descriptor = modelloader.load_spandrel_model(
info.local_data_path,
device=self.device,
- half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+ prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
)
return upscale_with_model(