aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2024-01-03 22:38:13 +0200
committerAarni Koskela <akx@iki.fi>2024-01-03 22:38:13 +0200
commitdfdc51246c678b585e1bdfdb7d2f202b0ca0e362 (patch)
tree1aec5ae1bda1a0118bd1f332b562f31b7830e7a3 /extensions-builtin
parente4dcdcc9554d7ff56993f5019eb90fe4ddf1e2e7 (diff)
SwinIR: use prefer_half
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index bc427fea..6a8e21b0 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,6 +1,7 @@
import logging
import sys
+import torch
from PIL import Image
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
@@ -69,7 +70,7 @@ class UpscalerSwinIR(Upscaler):
model_descriptor = modelloader.load_spandrel_model(
filename,
device=self._get_device(),
- dtype=devices.dtype,
+ prefer_half=(devices.dtype == torch.float16),
expected_architecture="SwinIR",
)
if getattr(shared.opts, 'SWIN_torch_compile', False):