Diffstat (limited to 'modules/processing.py')
-rw-r--r--  modules/processing.py  261
1 file changed, 149 insertions(+), 112 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index dd14c486..29a3743f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -169,6 +169,16 @@ class StableDiffusionProcessing:
self.is_hr_pass = False
self.sampler = None
+ self.prompts = None
+ self.negative_prompts = None
+ self.seeds = None
+ self.subseeds = None
+
+ self.step_multiplier = 1
+ self.cached_uc = [None, None]
+ self.cached_c = [None, None]
+ self.uc = None
+ self.c = None
@property
def sd_model(self):
@@ -271,11 +281,15 @@ class StableDiffusionProcessing:
def init(self, all_prompts, all_seeds, all_subseeds):
pass
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts, hr_conditioning=None, hr_unconditional_conditioning=None):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
def close(self):
self.sampler = None
+ self.c = None
+ self.uc = None
+ self.cached_c = [None, None]
+ self.cached_uc = [None, None]
def get_token_merging_ratio(self, for_hr=False):
if for_hr:
@@ -283,6 +297,52 @@ class StableDiffusionProcessing:
return self.token_merging_ratio or opts.token_merging_ratio
+ def setup_prompts(self):
+ if type(self.prompt) == list:
+ self.all_prompts = self.prompt
+ else:
+ self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
+
+ if type(self.negative_prompt) == list:
+ self.all_negative_prompts = self.negative_prompt
+ else:
+ self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
+
+ self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
+ self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
+
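The broadcasting rule above (a plain string is repeated once per generated image, a list is taken as-is) can be sanity-checked in isolation; this is a minimal sketch with hypothetical values and no webui imports:

batch_size, n_iter = 2, 2

# str input: one copy per image across all iterations
prompt = "a photo of a fox"
all_prompts = batch_size * n_iter * [prompt]
assert all_prompts == ["a photo of a fox"] * 4

# list input: the caller already supplied per-image prompts, used unchanged
prompts_list = ["fox", "owl", "elk", "hare"]
all_prompts = prompts_list
assert len(all_prompts) == batch_size * n_iter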
+ def get_conds_with_caching(self, function, required_prompts, steps, cache):
+ """
+ Returns the result of calling function(shared.sd_model, required_prompts, steps)
+ using a cache to store the result if the same arguments have been used before.
+
+ cache is an array containing two elements. The first element is a tuple
+ representing the previously used arguments, or None if no arguments
+ have been used before. The second element is where the previously
+ computed result is stored.
+ """
+
+ if cache[0] is not None and (required_prompts, steps) == cache[0]:
+ return cache[1]
+
+ with devices.autocast():
+ cache[1] = function(shared.sd_model, required_prompts, steps)
+
+ cache[0] = (required_prompts, steps)
+ return cache[1]
+
+ def setup_conds(self):
+ sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
+ self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
+
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c)
+
+ def parse_extra_network_prompts(self):
+ self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts)
+
+ return extra_network_data
+
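For readers skimming the diff, here is a minimal standalone sketch of how the two-element cache behaves. compute_cond is a hypothetical stand-in for shared.sd_model plus the prompt_parser functions, and devices.autocast() is omitted:

calls = []

def compute_cond(model, prompts, steps):
    # stand-in for the expensive text-encoder forward pass
    calls.append((tuple(prompts), steps))
    return (tuple(prompts), steps)

def get_conds_with_caching(function, required_prompts, steps, cache):
    # cache[0] holds the previous (prompts, steps) key, cache[1] the previous result
    if cache[0] is not None and (required_prompts, steps) == cache[0]:
        return cache[1]
    cache[1] = function("sd_model", required_prompts, steps)
    cache[0] = (required_prompts, steps)
    return cache[1]

cache = [None, None]
get_conds_with_caching(compute_cond, ["a cat"], 20, cache)  # computes
get_conds_with_caching(compute_cond, ["a cat"], 20, cache)  # same key: returns cached result
get_conds_with_caching(compute_cond, ["a dog"], 20, cache)  # new key: recomputes, overwrites the slot
assert len(calls) == 2

Note that steps is part of the cache key: second-order samplers set step_multiplier to 2, so conds encoded for steps * 2 are never confused with the same prompt encoded for steps.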
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
@@ -582,29 +642,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
comments = {}
- if type(p.prompt) == list:
- p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
- else:
- p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
-
- if type(p.negative_prompt) == list:
- p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
- else:
- p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
-
- if type(p) == StableDiffusionProcessingTxt2Img:
- if p.enable_hr and p.hr_prompt == '':
- p.all_hr_prompts, p.all_hr_negative_prompts = p.all_prompts, p.all_negative_prompts
- elif p.enable_hr and p.hr_prompt != '':
- if type(p.prompt) == list:
- p.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.hr_prompt]
- else:
- p.all_hr_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.hr_prompt, p.styles)]
-
- if type(p.negative_prompt) == list:
- p.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.hr_negative_prompt]
- else:
- p.all_hr_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.hr_negative_prompt, p.styles)]
+ p.setup_prompts()
if type(seed) == list:
p.all_seeds = seed
@@ -628,29 +666,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
- cached_uc = [None, None]
- cached_c = [None, None]
-
- def get_conds_with_caching(function, required_prompts, steps, cache):
- """
- Returns the result of calling function(shared.sd_model, required_prompts, steps)
- using a cache to store the result if the same arguments have been used before.
-
- cache is an array containing two elements. The first element is a tuple
- representing the previously used arguments, or None if no arguments
- have been used before. The second element is where the previously
- computed result is stored.
- """
-
- if cache[0] is not None and (required_prompts, steps) == cache[0]:
- return cache[1]
-
- with devices.autocast():
- cache[1] = function(shared.sd_model, required_prompts, steps)
-
- cache[0] = (required_prompts, steps)
- return cache[1]
-
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -672,40 +687,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
- prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
-
- if type(p) == StableDiffusionProcessingTxt2Img:
- if p.enable_hr:
- if p.hr_prompt == '':
- hr_prompts, hr_negative_prompts = prompts, negative_prompts
- else:
- hr_prompts = p.all_hr_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- hr_negative_prompts = p.all_hr_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
-
- seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if p.scripts is not None:
- p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+ p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
- if len(prompts) == 0:
+ if len(p.prompts) == 0:
break
- prompts, extra_network_data = extra_networks.parse_prompts(prompts)
-
- if type(p) == StableDiffusionProcessingTxt2Img:
- if p.enable_hr and hr_prompts != prompts:
- _, hr_extra_network_data = extra_networks.parse_prompts(hr_prompts)
- extra_network_data.update(hr_extra_network_data)
-
+ extra_network_data = p.parse_extra_network_prompts()
if not p.disable_extra_networks:
with devices.autocast():
extra_networks.activate(p, extra_network_data)
if p.scripts is not None:
- p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+ p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
@@ -716,18 +716,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
- sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
- step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
- uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
- c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
-
- if type(p) == StableDiffusionProcessingTxt2Img:
- if p.enable_hr:
- if prompts != hr_prompts:
- hr_uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, p.steps, cached_uc)
- hr_c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, p.steps, cached_c)
- else:
- hr_uc, hr_c = uc, c
+ p.setup_conds()
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -736,15 +725,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
-
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
- if type(p) == StableDiffusionProcessingTxt2Img:
- if p.enable_hr:
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, hr_conditioning=hr_c, hr_unconditional_conditioning=hr_uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
- else:
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
- else:
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
+ samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
@@ -771,7 +753,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
- images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
devices.torch_gc()
@@ -788,13 +770,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
- images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+ images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)
text = infotext(n, i)
infotexts.append(text)
@@ -807,10 +789,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
- images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+ images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
if opts.save_mask_composite:
- images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+ images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
if opts.return_mask:
output_images.append(image_mask)
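Taken together, these hunks move process_images_inner's local prompt/seed/cond variables onto the processing object. As orientation, a condensed restatement of the loop above (scripts, params.txt saving, interrupt handling and image output elided; not runnable on its own since it references the real p object):

p.setup_prompts()  # fills p.all_prompts / p.all_negative_prompts and applies styles
for n in range(p.n_iter):
    p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
    p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
    p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
    p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
    extra_network_data = p.parse_extra_network_prompts()  # strips extra-network tokens from p.prompts
    p.setup_conds()  # fills p.c / p.uc (plus p.hr_c / p.hr_uc in txt2img)
    samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc,
                            seeds=p.seeds, subseeds=p.subseeds,
                            subseed_strength=p.subseed_strength, prompts=p.prompts)

Because the hires pass now reads p.hr_c / p.hr_uc off the object, sample() no longer needs the hr_conditioning keyword arguments removed above.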
@@ -879,7 +861,7 @@ def old_hires_fix_first_pass_dimensions(width, height):
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler: str = '---', hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -890,9 +872,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
- self.hr_sampler = hr_sampler
- self.hr_prompt = hr_prompt if hr_prompt != '' else ''
- self.hr_negative_prompt = hr_negative_prompt if hr_negative_prompt != '' else ''
+ self.hr_sampler_name = hr_sampler_name
+ self.hr_prompt = hr_prompt
+ self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None
@@ -906,14 +888,23 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = 0
self.applied_old_hires_behavior_to = None
+ self.hr_prompts = None
+ self.hr_negative_prompts = None
+ self.hr_extra_network_data = None
+
+ self.hr_c = None
+ self.hr_uc = None
+
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
- if self.hr_sampler != '---':
- self.extra_generation_params["Hires sampler"] = self.hr_sampler
+ if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
+ self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

- if self.hr_prompt != '':
- self.extra_generation_params["Hires prompt"] = f'({self.hr_prompt.replace(",", ";")})'
- self.extra_generation_params["Hires negative prompt"] = f'({self.hr_negative_prompt.replace(",", ";")})'
+ if tuple(self.hr_prompt) != tuple(self.prompt):
+ self.extra_generation_params["Hires prompt"] = self.hr_prompt
+
+ if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
+ self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
self.hr_resize_x = self.width
@@ -975,7 +966,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.hr_upscaler is not None:
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts, hr_conditioning=None, hr_unconditional_conditioning=None):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
@@ -1044,16 +1035,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- img2img_sampler_name = self.sampler_name
+ img2img_sampler_name = self.hr_sampler_name or self.sampler_name
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
img2img_sampler_name = 'DDIM'
- if self.hr_sampler == '---':
- pass
- else:
- img2img_sampler_name = self.hr_sampler
-
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
@@ -1064,9 +1050,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
+ if not self.disable_extra_networks:
+ with devices.autocast():
+ extra_networks.activate(self, self.hr_extra_network_data)
+
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
- samples = self.sampler.sample_img2img(self, samples, noise, hr_conditioning, hr_unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+ samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
@@ -1074,6 +1064,53 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
+ def close(self):
+ super().close()  # keep base-class cleanup (sampler, c/uc, cond caches)
+ self.hr_c = None
+ self.hr_uc = None
+
+ def setup_prompts(self):
+ super().setup_prompts()
+
+ if not self.enable_hr:
+ return
+
+ if self.hr_prompt == '':
+ self.hr_prompt = self.prompt
+
+ if self.hr_negative_prompt == '':
+ self.hr_negative_prompt = self.negative_prompt
+
+ if type(self.hr_prompt) == list:
+ self.all_hr_prompts = self.hr_prompt
+ else:
+ self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
+
+ if type(self.hr_negative_prompt) == list:
+ self.all_hr_negative_prompts = self.hr_negative_prompt
+ else:
+ self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
+
+ self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
+ self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+
+ def setup_conds(self):
+ super().setup_conds()
+
+ if self.enable_hr:
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c)
+
+ def parse_extra_network_prompts(self):
+ res = super().parse_extra_network_prompts()
+
+ if self.enable_hr:
+ self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+ self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+
+ self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)
+
+ return res
+
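One subtlety: setup_conds() for the hires pass reuses the same self.cached_uc / self.cached_c slots as the first pass, so when the hires prompts (and step count) match the base ones, get_conds_with_caching returns the just-computed c/uc without re-encoding; this is what replaces the old explicit hr_uc, hr_c = uc, c branch. A small check using the sketch helper defined earlier (fake_cond is again a hypothetical stand-in for prompt_parser.get_learned_conditioning):

calls = []

def fake_cond(model, prompts, steps):
    calls.append((tuple(prompts), steps))
    return object()

cached_uc = [None, None]
uc = get_conds_with_caching(fake_cond, ["low quality"], 40, cached_uc)     # first pass: computes
hr_uc = get_conds_with_caching(fake_cond, ["low quality"], 40, cached_uc)  # hires pass: cache hit
assert hr_uc is uc and len(calls) == 1

The flip side is that a differing hires prompt overwrites the single cache slot, so the next iteration's first pass re-encodes; the two-element cache trades a little recomputation for simplicity.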
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None