aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2024-01-04 11:17:20 +0300
committerGitHub <noreply@github.com>2024-01-04 11:17:20 +0300
commit3f7f61e5411320a802a2a4b1afd38c8affee94b3 (patch)
tree5473de8de4a7ffa71762b33b36d38f3ca92b821f /modules
parent1e7a8ce5e403de4bef7e09f522a48ce5c1b1d845 (diff)
parent62470ee23443cb2ad3943a152ccae26a689c86e1 (diff)
Merge pull request #14524 from akx/fix-swinir-issues
Fix SwinIR issues
Diffstat (limited to 'modules')
-rw-r--r--modules/upscaler_utils.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
index 4f1417cf..afed8b40 100644
--- a/modules/upscaler_utils.py
+++ b/modules/upscaler_utils.py
@@ -94,6 +94,7 @@ def tiled_upscale_2(
tile_size: int,
tile_overlap: int,
scale: int,
+ device: torch.device,
desc="Tiled upscale",
):
# Alternative implementation of `upscale_with_model` originally used by
@@ -101,9 +102,6 @@ def tiled_upscale_2(
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
# Pillow space without weighting.
- # Grab the device the model is on, and use it.
- device = torch_utils.get_param(model).device
-
b, c, h, w = img.size()
tile_size = min(tile_size, h, w)
@@ -175,7 +173,8 @@ def upscale_2(
"""
Convenience wrapper around `tiled_upscale_2` that handles PIL images.
"""
- tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension
+ param = torch_utils.get_param(model)
+ tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension
with torch.no_grad():
output = tiled_upscale_2(
@@ -185,5 +184,6 @@ def upscale_2(
tile_overlap=tile_overlap,
scale=scale,
desc=desc,
+ device=param.device,
)
return torch_bgr_to_pil_image(output)