aboutsummaryrefslogtreecommitdiff
path: root/modules/esrgan_model.py
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-12-27 11:04:33 +0200
committerAarni Koskela <akx@iki.fi>2023-12-30 16:24:01 +0200
commite472383acbb9e07dca311abe5fb16ee2675e410a (patch)
tree69591965d87134116235daa785d31f60b70791b4 /modules/esrgan_model.py
parent12c6f37f8e4b1d1d643c9d8d5dfc763c3203c728 (diff)
Refactor esrgan_upscale to more generic upscale_with_model
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r--modules/esrgan_model.py47
1 files changed, 8 insertions, 39 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 02a1727d..c0d22a99 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,13 +1,12 @@
import sys
-import numpy as np
import torch
-from PIL import Image
import modules.esrgan_model_arch as arch
-from modules import modelloader, images, devices
+from modules import modelloader, devices
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
def mod2normal(state_dict):
@@ -190,40 +189,10 @@ class UpscalerESRGAN(Upscaler):
return model
-def upscale_without_tiling(model, img):
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- return Image.fromarray(output, 'RGB')
-
-
def esrgan_upscale(model, img):
- if opts.ESRGAN_tile == 0:
- return upscale_without_tiling(model, img)
-
- grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
- newtiles = []
- scale_factor = 1
-
- for y, h, row in grid.tiles:
- newrow = []
- for tiledata in row:
- x, w, tile = tiledata
-
- output = upscale_without_tiling(model, tile)
- scale_factor = output.width // tile.width
-
- newrow.append([x * scale_factor, w * scale_factor, output])
- newtiles.append([y * scale_factor, h * scale_factor, newrow])
-
- newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
- output = images.combine_grid(newgrid)
- return output
+ return upscale_with_model(
+ model,
+ img,
+ tile_size=opts.ESRGAN_tile,
+ tile_overlap=opts.ESRGAN_tile_overlap,
+ )