aboutsummaryrefslogtreecommitdiff
path: root/modules/swinir_model.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-10-02 21:33:22 +0300
committerGitHub <noreply@github.com>2022-10-02 21:33:22 +0300
commit688c4a914a6cb152b5f5e4088bace709ed3dcc35 (patch)
tree39c3221e99ca90b7238057bc555c5b4d63fcf960 /modules/swinir_model.py
parenta634c3226fd69486ce96df56f95f3fd63172305c (diff)
parent852fd90c0dcda9cb5fbbfdf0c7308ce58034935c (diff)
Merge branch 'master' into 1404-script-reload-without-restart
Diffstat (limited to 'modules/swinir_model.py')
-rw-r--r--modules/swinir_model.py27
1 files changed, 15 insertions, 12 deletions
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index 41fda5a7..9bd454c6 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -5,6 +5,7 @@ import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
+from tqdm import tqdm
from modules import modelloader
from modules.paths import models_path
@@ -122,18 +123,20 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=device)
- for h_idx in h_idx_list:
- for w_idx in w_idx_list:
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
+ with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
+ for h_idx in h_idx_list:
+ for w_idx in w_idx_list:
+ in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
+ out_patch = model(in_patch)
+ out_patch_mask = torch.ones_like(out_patch)
+
+ E[
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+ ].add_(out_patch)
+ W[
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+ ].add_(out_patch_mask)
+ pbar.update(1)
output = E.div_(W)
return output