aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py45
1 files changed, 36 insertions, 9 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 7e853287..82157bc9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -50,9 +50,9 @@ def apply_color_correction(correction, original_image):
correction,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
-
+
image = blendLayers(image, original_image, BlendType.LUMINOSITY)
-
+
return image
@@ -466,9 +466,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model
- if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE
+ if k == 'sd_hypernetwork':
+ shared.reload_hypernetworks() # make onchange call for changing hypernet
+
+ if k == 'sd_model_checkpoint':
+ sd_models.reload_model_weights() # make onchange call for changing SD model
+ p.sd_model = shared.sd_model
+
+ if k == 'sd_vae':
+ sd_vae.reload_vae_weights() # make onchange call for changing VAE
res = process_images_inner(p)
@@ -538,6 +544,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
+ cached_uc = [None, None]
+ cached_c = [None, None]
+
+ def get_conds_with_caching(function, required_prompts, steps, cache):
+ """
+ Returns the result of calling function(shared.sd_model, required_prompts, steps)
+ using a cache to store the result if the same arguments have been used before.
+
+ cache is an array containing two elements. The first element is a tuple
+ representing the previously used arguments, or None if no arguments
+ have been used before. The second element is where the previously
+ computed result is stored.
+ """
+
+ if cache[0] is not None and (required_prompts, steps) == cache[0]:
+ return cache[1]
+
+ with devices.autocast():
+ cache[1] = function(shared.sd_model, required_prompts, steps)
+
+ cache[0] = (required_prompts, steps)
+ return cache[1]
+
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -565,9 +594,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
- with devices.autocast():
- uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
- c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
+ uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
+ c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -683,7 +711,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = 0
self.truncate_y = 0
-
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
if self.hr_resize_x == 0 and self.hr_resize_y == 0: