aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/extras.py67
-rw-r--r--modules/ui.py5
2 files changed, 43 insertions, 29 deletions
diff --git a/modules/extras.py b/modules/extras.py
index cffe0381..72cc6d1d 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
import math
import os
@@ -7,7 +8,7 @@ from PIL import Image
import torch
import tqdm
-from typing import Callable, Dict, List, Tuple
+from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass
@@ -21,18 +22,34 @@ import piexif.helper
import gradio as gr
-@dataclass(frozen=True)
-class CacheKey:
- image_hash: int
- info_hash: int
- args_hash: int
+class LruCache(OrderedDict):
+ @dataclass(frozen=True)
+ class Key:
+ image_hash: int
+ info_hash: int
+ args_hash: int
-@dataclass
-class CacheEntry:
- image: Image.Image
- info: str
+ @dataclass
+ class Value:
+ image: Image.Image
+ info: str
+
+ def __init__(self, max_size:int = 5, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._max_size = max_size
+
+ def get(self, key: LruCache.Key) -> LruCache.Value:
+ ret = super().get(key)
+ if ret is not None:
+ self.move_to_end(key) # Move to end of eviction list
+ return ret
+
+ def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
+ self[key] = value
+ while len(self) > self._max_size:
+ self.popitem(last=False)
-cached_images: Dict[CacheKey, CacheEntry] = {}
+cached_images: LruCache = LruCache(max_size = 5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ):
@@ -121,14 +138,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
blended_result: Image.Image = None
for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- cache_key = CacheKey( image_hash = hash(np.array(image.getdata()).tobytes()),
+ cache_key = LruCache.Key( image_hash = hash(np.array(image.getdata()).tobytes()),
info_hash = hash(info),
- args_hash = hash(upscale_args) )
+ args_hash = hash(upscale_args + (upscaler.blend_alpha,)) )
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
- cached_images[cache_key] = CacheEntry(image=res, info=info)
+ cached_images.put(cache_key, LruCache.Value(image=res, info=info))
else:
res, info = cached_entry.image, cached_entry.info
@@ -140,14 +157,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Build a list of operations to run
facefix_ops: List[Callable] = []
- if gfpgan_visibility > 0:
- facefix_ops.append(run_gfpgan)
- if codeformer_visibility > 0:
- facefix_ops.append(run_codeformer)
+ facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
+ facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
upscale_ops: List[Callable] = []
- if resize_mode == 1:
- upscale_ops.append(run_prepare_crop)
+ upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
if upscaling_resize != 0:
step_params: List[UpscaleParams] = []
@@ -157,12 +171,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
upscale_ops.append( partial(run_upscalers_blend, step_params) )
-
- extras_ops: List[Callable] = []
- if upscale_first:
- extras_ops = upscale_ops + facefix_ops
- else:
- extras_ops = facefix_ops + upscale_ops
+ extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
for image, image_name in zip(imageArr, imageNameArr):
@@ -176,9 +185,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for op in extras_ops:
image, info = op(image, info)
- while len(cached_images) > 2:
- del cached_images[next(iter(cached_images.keys()))]
-
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
@@ -198,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return outputs, plaintext_to_html(info), ''
+def clear_cache():
+ cached_images.clear()
+
def run_pnginfo(image):
if image is None:
diff --git a/modules/ui.py b/modules/ui.py
index 16b6ac49..b7c36c55 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1178,6 +1178,11 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[init_img_with_mask],
)
+ extras_image.change(
+ fn=modules.extras.clear_cache,
+ inputs=[], outputs=[]
+ )
+
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):