From dec5cdd9b89dd683f04fb904ebd8a56dfce860ae Mon Sep 17 00:00:00 2001 From: AdjointOperator Date: Wed, 19 Apr 2023 15:35:50 +0800 Subject: add tiled inference support for ScuNET --- extensions-builtin/ScuNET/scripts/scunet_model.py | 83 +++++++++++++++++++---- 1 file changed, 68 insertions(+), 15 deletions(-) (limited to 'extensions-builtin/ScuNET') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index e0fbf3a3..c7fd5739 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -5,11 +5,15 @@ import traceback import PIL.Image import numpy as np import torch +from tqdm import tqdm + from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader from scunet_model_arch import SCUNet as net +from modules.shared import opts +from modules import images class UpscalerScuNET(modules.upscaler.Upscaler): @@ -42,28 +46,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers.append(scaler_data2) self.scalers = scalers - def do_upscale(self, img: PIL.Image, selected_file): + @staticmethod + @torch.no_grad() + def tiled_inference(img, model): + # test the image tile by tile + h, w = img.shape[2:] + tile = opts.SCUNET_tile + tile_overlap = opts.SCUNET_tile_overlap + if tile == 0: + return model(img) + + device = devices.get_device_for('scunet') + assert tile % 8 == 0, "tile size should be a multiple of window_size" + sf = 1 + + stride = tile - tile_overlap + h_idx_list = list(range(0, h - tile, stride)) + [h - tile] + w_idx_list = list(range(0, w - tile, stride)) + [w - tile] + E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) + W = torch.zeros_like(E, dtype=devices.dtype, device=device) + + with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: + for h_idx in h_idx_list: + + for w_idx in w_idx_list: + + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] + + out_patch = model(in_patch) + out_patch_mask = torch.ones_like(out_patch) + + E[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch) + W[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch_mask) + pbar.update(1) + output = E.div_(W) + + return output + + def do_upscale(self, img: PIL.Image.Image, selected_file): + torch.cuda.empty_cache() model = self.load_model(selected_file) if model is None: + print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr) return img device = devices.get_device_for('scunet') - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device) - - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] + tile = opts.SCUNET_tile + h, w = img.height, img.width + np_img = np.array(img) + np_img = np_img[:, :, ::-1] # RGB to BGR + np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW + torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore + + if tile > h or tile > w: + _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device) + _img[:, :, :h, :w] = torch_img # pad image + torch_img = _img + + torch_output = self.tiled_inference(torch_img, model).squeeze(0) + torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any + np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() + del torch_img, torch_output torch.cuda.empty_cache() - return PIL.Image.fromarray(output, 'RGB') + + output = np_output.transpose((1, 2, 0)) # CHW to HWC + output = output[:, :, ::-1] # BGR to RGB + return PIL.Image.fromarray((output * 255).astype(np.uint8)) def load_model(self, path: str): device = devices.get_device_for('scunet') @@ -84,4 +138,3 @@ class UpscalerScuNET(modules.upscaler.Upscaler): model = model.to(device) return model - -- cgit v1.2.1