From e472383acbb9e07dca311abe5fb16ee2675e410a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Dec 2023 11:04:33 +0200 Subject: Refactor esrgan_upscale to more generic upscale_with_model --- modules/esrgan_model.py | 47 ++++++--------------------------- modules/upscaler_utils.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 39 deletions(-) create mode 100644 modules/upscaler_utils.py diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 02a1727d..c0d22a99 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,13 +1,12 @@ import sys -import numpy as np import torch -from PIL import Image import modules.esrgan_model_arch as arch -from modules import modelloader, images, devices +from modules import modelloader, devices from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model def mod2normal(state_dict): @@ -190,40 +189,10 @@ class UpscalerESRGAN(Upscaler): return model -def upscale_without_tiling(model, img): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') - - def esrgan_upscale(model, img): - if opts.ESRGAN_tile == 0: - return upscale_without_tiling(model, img) - - grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) - newtiles = [] - scale_factor = 1 - - for y, h, row in grid.tiles: - newrow = [] - for tiledata in row: - x, w, tile = tiledata - - output = upscale_without_tiling(model, tile) - scale_factor = output.width // tile.width - - newrow.append([x * scale_factor, w * scale_factor, output]) - newtiles.append([y * scale_factor, h * scale_factor, newrow]) - - newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) - output = images.combine_grid(newgrid) - return output + return upscale_with_model( + model, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + ) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py new file mode 100644 index 00000000..8bdda51c --- /dev/null +++ b/modules/upscaler_utils.py @@ -0,0 +1,66 @@ +import logging +from typing import Callable + +import numpy as np +import torch +import tqdm +from PIL import Image + +from modules import devices, images + +logger = logging.getLogger(__name__) + + +def upscale_without_tiling(model, img: Image.Image): + img = np.array(img) + img = img[:, :, ::-1] + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(devices.device_esrgan) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + return Image.fromarray(output, 'RGB') + + +def upscale_with_model( + model: Callable[[torch.Tensor], torch.Tensor], + img: Image.Image, + *, + tile_size: int, + tile_overlap: int = 0, + desc="tiled upscale", +) -> Image.Image: + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img) + output = upscale_without_tiling(model, img) + logger.debug("=> %s", output) + return output + + grid = images.split_grid(img, tile_size, tile_size, tile_overlap) + newtiles = [] + + with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + for y, h, row in grid.tiles: + newrow = [] + for x, w, tile in row: + logger.debug("Tile (%d, %d) %s...", x, y, tile) + output = upscale_without_tiling(model, tile) + scale_factor = output.width // tile.width + logger.debug("=> %s (scale factor %s)", output, scale_factor) + newrow.append([x * scale_factor, w * scale_factor, output]) + p.update(1) + newtiles.append([y * scale_factor, h * scale_factor, newrow]) + + newgrid = images.Grid( + newtiles, + tile_w=grid.tile_w * scale_factor, + tile_h=grid.tile_h * scale_factor, + image_w=grid.image_w * scale_factor, + image_h=grid.image_h * scale_factor, + overlap=grid.overlap * scale_factor, + ) + return images.combine_grid(newgrid) -- cgit v1.2.1