aboutsummaryrefslogtreecommitdiff
path: root/modules/upscaler_utils.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-12-31 09:41:49 +0300
committerGitHub <noreply@github.com>2023-12-31 09:41:49 +0300
commita84e842189f5599fd354147f72d1a9b9ed0716c8 (patch)
treeae0e5e9df369eb1cefa41ee76eb0e56fe945d192 /modules/upscaler_utils.py
parentce21840a042b9454a136372ab2971c1f21ec51e0 (diff)
parent6f86b62a1be7993073ba3a789d522e0b8870605a (diff)
Merge pull request #14476 from akx/dedupe-tiled-weighted-inference
Deduplicate tiled inference code from SwinIR/ScuNET
Diffstat (limited to 'modules/upscaler_utils.py')
-rw-r--r--modules/upscaler_utils.py72
1 files changed, 71 insertions, 1 deletions
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
index 174c9bc3..8e413854 100644
--- a/modules/upscaler_utils.py
+++ b/modules/upscaler_utils.py
@@ -6,7 +6,7 @@ import torch
import tqdm
from PIL import Image
-from modules import images
+from modules import images, shared
logger = logging.getLogger(__name__)
@@ -68,3 +68,73 @@ def upscale_with_model(
overlap=grid.overlap * scale_factor,
)
return images.combine_grid(newgrid)
+
+
+def tiled_upscale_2(
+ img,
+ model,
+ *,
+ tile_size: int,
+ tile_overlap: int,
+ scale: int,
+ device,
+ desc="Tiled upscale",
+):
+ # Alternative implementation of `upscale_with_model` originally used by
+ # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
+ # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
+ # Pillow space without weighting.
+ b, c, h, w = img.size()
+ tile_size = min(tile_size, h, w)
+
+ if tile_size <= 0:
+ logger.debug("Upscaling %s without tiling", img.shape)
+ return model(img)
+
+ stride = tile_size - tile_overlap
+ h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
+ w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
+ result = torch.zeros(
+ b,
+ c,
+ h * scale,
+ w * scale,
+ device=device,
+ ).type_as(img)
+ weights = torch.zeros_like(result)
+ logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
+ with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
+ for h_idx in h_idx_list:
+ if shared.state.interrupted or shared.state.skipped:
+ break
+
+ for w_idx in w_idx_list:
+ if shared.state.interrupted or shared.state.skipped:
+ break
+
+ in_patch = img[
+ ...,
+ h_idx : h_idx + tile_size,
+ w_idx : w_idx + tile_size,
+ ]
+ out_patch = model(in_patch)
+
+ result[
+ ...,
+ h_idx * scale : (h_idx + tile_size) * scale,
+ w_idx * scale : (w_idx + tile_size) * scale,
+ ].add_(out_patch)
+
+ out_patch_mask = torch.ones_like(out_patch)
+
+ weights[
+ ...,
+ h_idx * scale : (h_idx + tile_size) * scale,
+ w_idx * scale : (w_idx + tile_size) * scale,
+ ].add_(out_patch_mask)
+
+ pbar.update(1)
+
+ output = result.div_(weights)
+
+ return output