aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ESRGAN/Put ESRGAN models here.txt0
-rw-r--r--README.md10
-rw-r--r--modules/esrgam_model_arch.py80
-rw-r--r--modules/esrgan_model.py134
-rw-r--r--modules/images.py38
-rw-r--r--modules/img2img.py6
-rw-r--r--modules/realesrgan_model.py15
-rw-r--r--modules/shared.py7
-rw-r--r--modules/ui.py25
-rw-r--r--webui.py43
10 files changed, 327 insertions, 31 deletions
diff --git a/ESRGAN/Put ESRGAN models here.txt b/ESRGAN/Put ESRGAN models here.txt
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/ESRGAN/Put ESRGAN models here.txt
diff --git a/README.md b/README.md
index 610826c2..6cf246d2 100644
--- a/README.md
+++ b/README.md
@@ -19,11 +19,14 @@ Original script with Gradio UI was written by a kind anonymous user. This is a m
- Loopback
- X/Y plot
- Textual Inversion
-- Resizing options
+- Extras tab with:
+ - GFPGAN, neural network that fixes faces
+ - RealESRGAN, neural network upscaler
+ - ESRGAN, neural network with a lot of third party models
+- Resizing aspect ratio options
- Sampling method selection
- Interrupt processing at any time
- 4GB videocard support
-- Option to use GFPGAN
- Correct seeds for batches
- Prompt length validation
- Generation parameters added as text to PNG
@@ -49,6 +52,9 @@ can obtain it from the following places:
You optionally can use GPFGAN to improve faces, then you'll need to download the model from [here](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth).
+To use ESRGAN models, put them into ESRGAN directory in the same location as webui.py. A file will be loaded
+as model if it has .pth extension. Grab models from the [Model Database](https://upscale.wiki/wiki/Model_Database).
+
### Automatic installation/launch
- install [Python 3.10.6](https://www.python.org/downloads/windows/)
diff --git a/modules/esrgam_model_arch.py b/modules/esrgam_model_arch.py
new file mode 100644
index 00000000..e413d36e
--- /dev/null
+++ b/modules/esrgam_model_arch.py
@@ -0,0 +1,80 @@
+# this file is taken from https://github.com/xinntao/ESRGAN
+
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def make_layer(block, n_layers):
+ layers = []
+ for _ in range(n_layers):
+ layers.append(block())
+ return nn.Sequential(*layers)
+
+
+class ResidualDenseBlock_5C(nn.Module):
+ def __init__(self, nf=64, gc=32, bias=True):
+ super(ResidualDenseBlock_5C, self).__init__()
+ # gc: growth channel, i.e. intermediate channels
+ self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+ self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+ self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+ self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
+ self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ # initialization
+ # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+ '''Residual in Residual Dense Block'''
+
+ def __init__(self, nf, gc=32):
+ super(RRDB, self).__init__()
+ self.RDB1 = ResidualDenseBlock_5C(nf, gc)
+ self.RDB2 = ResidualDenseBlock_5C(nf, gc)
+ self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+
+ def forward(self, x):
+ out = self.RDB1(x)
+ out = self.RDB2(out)
+ out = self.RDB3(out)
+ return out * 0.2 + x
+
+
+class RRDBNet(nn.Module):
+ def __init__(self, in_nc, out_nc, nf, nb, gc=32):
+ super(RRDBNet, self).__init__()
+ RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+
+ self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+ self.RRDB_trunk = make_layer(RRDB_block_f, nb)
+ self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ #### upsampling
+ self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+ self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ fea = self.conv_first(x)
+ trunk = self.trunk_conv(self.RRDB_trunk(fea))
+ fea = fea + trunk
+
+ fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
+ fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
+ out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+ return out
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
new file mode 100644
index 00000000..3dcef5a6
--- /dev/null
+++ b/modules/esrgan_model.py
@@ -0,0 +1,134 @@
+import os
+import sys
+import traceback
+
+import numpy as np
+import torch
+from PIL import Image
+
+import modules.esrgam_model_arch as arch
+from modules import shared
+from modules.shared import opts
+import modules.images
+
+
+def load_model(filename):
+ # this code is adapted from https://github.com/xinntao/ESRGAN
+
+ pretrained_net = torch.load(filename)
+ crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+
+ if 'conv_first.weight' in pretrained_net:
+ crt_model.load_state_dict(pretrained_net)
+ return crt_model
+
+ crt_net = crt_model.state_dict()
+ load_net_clean = {}
+ for k, v in pretrained_net.items():
+ if k.startswith('module.'):
+ load_net_clean[k[7:]] = v
+ else:
+ load_net_clean[k] = v
+ pretrained_net = load_net_clean
+
+ tbd = []
+ for k, v in crt_net.items():
+ tbd.append(k)
+
+ # directly copy
+ for k, v in crt_net.items():
+ if k in pretrained_net and pretrained_net[k].size() == v.size():
+ crt_net[k] = pretrained_net[k]
+ tbd.remove(k)
+
+ crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
+ crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
+
+ for k in tbd.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[k] = pretrained_net[ori_k]
+ tbd.remove(k)
+
+ crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
+ crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
+ crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
+ crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
+ crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
+ crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
+ crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
+ crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
+ crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
+ crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
+
+ crt_model.load_state_dict(crt_net)
+ crt_model.eval()
+ return crt_model
+
+def upscale_without_tiling(model, img):
+ img = np.array(img)
+ img = img[:, :, ::-1]
+ img = np.moveaxis(img, 2, 0) / 255
+ img = torch.from_numpy(img).float()
+ img = img.unsqueeze(0).to(shared.device)
+ with torch.no_grad():
+ output = model(img)
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output = 255. * np.moveaxis(output, 0, 2)
+ output = output.astype(np.uint8)
+ output = output[:, :, ::-1]
+ return Image.fromarray(output, 'RGB')
+
+
+def esrgan_upscale(model, img):
+ if opts.ESRGAN_tile == 0:
+ return upscale_without_tiling(model, img)
+
+ grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
+ newtiles = []
+ scale_factor = 1
+
+ for y, h, row in grid.tiles:
+ newrow = []
+ for tiledata in row:
+ x, w, tile = tiledata
+
+ output = upscale_without_tiling(model, tile)
+ scale_factor = output.width // tile.width
+
+ newrow.append([x * scale_factor, w * scale_factor, output])
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+ newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
+ output = modules.images.combine_grid(newgrid)
+ return output
+
+
+class UpscalerESRGAN(modules.images.Upscaler):
+ def __init__(self, filename, title):
+ self.name = title
+ self.model = load_model(filename)
+
+ def do_upscale(self, img):
+ model = self.model.to(shared.device)
+ img = esrgan_upscale(model, img)
+ return img
+
+
+def load_models(dirname):
+ for file in os.listdir(dirname):
+ path = os.path.join(dirname, file)
+ model_name, extension = os.path.splitext(file)
+
+ if extension != '.pt' and extension != '.pth':
+ continue
+
+ try:
+ modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
+ except Exception:
+ print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/images.py b/modules/images.py
index 4b9667d2..4226db00 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -6,6 +6,7 @@ import re
import numpy as np
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
+import modules.shared
from modules.shared import opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -45,20 +46,20 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
cols = math.ceil((w - overlap) / non_overlap_width)
rows = math.ceil((h - overlap) / non_overlap_height)
- dx = (w - tile_w) // (cols-1) if cols > 1 else 0
- dy = (h - tile_h) // (rows-1) if rows > 1 else 0
+ dx = (w - tile_w) / (cols-1) if cols > 1 else 0
+ dy = (h - tile_h) / (rows-1) if rows > 1 else 0
grid = Grid([], tile_w, tile_h, w, h, overlap)
for row in range(rows):
row_images = []
- y = row * dy
+ y = int(row * dy)
if y + tile_h >= h:
y = h - tile_h
for col in range(cols):
- x = col * dx
+ x = int(col * dx)
if x+tile_w >= w:
x = w - tile_w
@@ -291,3 +292,32 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
file.write(info + "\n")
+
+class Upscaler:
+ name = "Lanczos"
+
+ def do_upscale(self, img):
+ return img
+
+ def upscale(self, img, w, h):
+ for i in range(3):
+ if img.width >= w and img.height >= h:
+ break
+
+ img = self.do_upscale(img)
+
+ if img.width != w or img.height != h:
+ img = img.resize((w, h), resample=LANCZOS)
+
+ return img
+
+
+class UpscalerNone(Upscaler):
+ name = "None"
+
+ def upscale(self, img, w, h):
+ return img
+
+
+modules.shared.sd_upscalers.append(UpscalerNone())
+modules.shared.sd_upscalers.append(Upscaler())
diff --git a/modules/img2img.py b/modules/img2img.py
index d5787dd3..b1ef1326 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts
-def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_name: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
+def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
is_inpaint = mode == 1
is_loopback = mode == 2
is_upscale = mode == 3
@@ -81,8 +81,8 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
initial_seed = None
initial_info = None
- upscaler = shared.sd_upscalers.get(upscaler_name, next(iter(shared.sd_upscalers.values())))
- img = upscaler(init_img)
+ upscaler = shared.sd_upscalers[upscaler_index]
+ img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
processing.torch_gc()
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 5a6666a3..e480887f 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -4,6 +4,7 @@ from collections import namedtuple
import numpy as np
from PIL import Image
+import modules.images
from modules.shared import cmd_opts
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
@@ -12,6 +13,17 @@ realesrgan_models = []
have_realesrgan = False
RealESRGANer_constructor = None
+
+class UpscalerRealESRGAN(modules.images.Upscaler):
+ def __init__(self, upscaling, model_index):
+ self.upscaling = upscaling
+ self.model_index = model_index
+ self.name = realesrgan_models[model_index].name
+
+ def do_upscale(self, img):
+ return upscale_with_realesrgan(img, self.upscaling, self.model_index)
+
+
def setup_realesrgan():
global realesrgan_models
global have_realesrgan
@@ -42,6 +54,9 @@ def setup_realesrgan():
have_realesrgan = True
RealESRGANer_constructor = RealESRGANer
+ for i, model in enumerate(realesrgan_models):
+ modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
+
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/shared.py b/modules/shared.py
index c8c2749a..72e92eb9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -28,6 +28,7 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="a w
parser.add_argument("--unload-gfpgan", action='store_true', help="unload GFPGAN every time after processing images. Warning: seems to cause memory leaks")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
cmd_opts = parser.parse_args()
cpu = torch.device("cpu")
@@ -79,7 +80,8 @@ class Options:
"font": OptionInfo("arial.ttf", "Font for image grids that have text"),
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text text and [text] to make it pay less attention"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
-
+ "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscaling. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+ "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscaling. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
}
def __init__(self):
@@ -115,7 +117,6 @@ opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
-
-sd_upscalers = {}
+sd_upscalers = []
sd_model = None
diff --git a/modules/ui.py b/modules/ui.py
index d6b39c2f..4119369e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -256,10 +256,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Row():
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
+ sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
with gr.Row():
- sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=list(shared.sd_upscalers.keys()), value=list(shared.sd_upscalers.keys())[0], visible=False)
- sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
+ sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False)
with gr.Row():
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
@@ -401,9 +401,18 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Column(variant='panel'):
with gr.Group():
image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
- gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=gfpgan.have_gfpgan)
- realesrgan_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=realesrgan.have_realesrgan)
- realesrgan_model = gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan.realesrgan_models], value=realesrgan.realesrgan_models[0].name, type="index", interactive=realesrgan.have_realesrgan)
+
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
+
+ with gr.Group():
+ extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
+
+ with gr.Group():
+ extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
+ extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
+
+ with gr.Group():
+ gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=0, interactive=gfpgan.have_gfpgan)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
@@ -417,8 +426,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
inputs=[
image,
gfpgan_strength,
- realesrgan_resize,
- realesrgan_model,
+ upscaling_resize,
+ extras_upscaler_1,
+ extras_upscaler_2,
+ extras_upscaler_2_visibility,
],
outputs=[
result_image,
diff --git a/webui.py b/webui.py
index d79b5966..dbc9dd54 100644
--- a/webui.py
+++ b/webui.py
@@ -21,17 +21,14 @@ import modules.processing as processing
import modules.sd_hijack
import modules.gfpgan_model as gfpgan
import modules.realesrgan_model as realesrgan
+import modules.esrgan_model as esrgan
import modules.images as images
import modules.lowvram
import modules.txt2img
import modules.img2img
-shared.sd_upscalers = {
- "RealESRGAN": lambda img: realesrgan.upscale_with_realesrgan(img, 2, 0),
- "Lanczos": lambda img: img.resize((img.width*2, img.height*2), resample=images.LANCZOS),
- "None": lambda img: img
-}
+esrgan.load_models(cmd_opts.esrgan_models_path)
realesrgan.setup_realesrgan()
gfpgan.setup_gfpgan()
@@ -54,26 +51,48 @@ def load_model_from_config(config, ckpt, verbose=False):
model.eval()
return model
+cached_images = {}
-def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
+def run_extras(image, gfpgan_strength, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
processing.torch_gc()
image = image.convert("RGB")
outpath = opts.outdir_samples or opts.outdir_extras_samples
- if gfpgan.have_gfpgan is not None and GFPGAN_strength > 0:
-
+ if gfpgan.have_gfpgan is not None and gfpgan_strength > 0:
restored_img = gfpgan.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
- if GFPGAN_strength < 1.0:
- res = Image.blend(image, res, GFPGAN_strength)
+ if gfpgan_strength < 1.0:
+ res = Image.blend(image, res, gfpgan_strength)
+
+ image = res
+
+ if upscaling_resize != 1.0:
+ def upscale(image, scaler_index, resize):
+ small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
+ pixels = tuple(np.array(small).flatten().tolist())
+ key = (resize, scaler_index, image.width, image.height) + pixels
+
+ c = cached_images.get(key)
+ if c is None:
+ upscaler = shared.sd_upscalers[scaler_index]
+ c = upscaler.upscale(image, image.width * resize, image.height * resize)
+ cached_images[key] = c
+
+ return c
+
+ res = upscale(image, extras_upscaler_1, upscaling_resize)
+
+ if extras_upscaler_2 != 0 and extras_upscaler_2_visibility>0:
+ res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+ res = Image.blend(res, res2, extras_upscaler_2_visibility)
image = res
- if realesrgan.have_realesrgan and RealESRGAN_upscaling != 1.0:
- image = realesrgan.upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
+ while len(cached_images) > 2:
+ del cached_images[next(iter(cached_images.keys()))]
images.save_image(image, outpath, "", None, '', opts.samples_format, short_filename=True, no_prompt=True)