From 617c5b486f42aa73062ee7699ee1147eb995c899 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 13:23:25 +0300 Subject: make it possible for StableDiffusionProcessing to accept multiple different negative prompts in a batch --- modules/processing.py | 46 ++++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 604d822a..bc7e5311 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -124,6 +124,7 @@ class StableDiffusionProcessing(): self.scripts = None self.script_args = None self.all_prompts = None + self.all_negative_prompts = None self.all_seeds = None self.all_subseeds = None @@ -202,7 +203,7 @@ class StableDiffusionProcessing(): class Processed: - def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): + def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): self.images = images_list self.prompt = p.prompt self.negative_prompt = p.negative_prompt @@ -241,16 +242,18 @@ class Processed: self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning - self.all_prompts = all_prompts or [self.prompt] - self.all_seeds = all_seeds or [self.seed] - self.all_subseeds = all_subseeds or [self.subseed] + self.all_prompts = all_prompts or p.all_prompts or [self.prompt] + self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt] + self.all_seeds = all_seeds or p.all_seeds or [self.seed] + self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed] self.infotexts = infotexts or [info] def js(self): obj = { - "prompt": self.prompt, + "prompt": self.all_prompts[0], "all_prompts": self.all_prompts, - "negative_prompt": self.negative_prompt, + "negative_prompt": self.all_negative_prompts[0], + "all_negative_prompts": self.all_negative_prompts, "seed": self.seed, "all_seeds": self.all_seeds, "subseed": self.subseed, @@ -411,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) - negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" + negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if p.all_negative_prompts[0] else "" return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() @@ -440,10 +443,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: assert p.prompt is not None - with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: - processed = Processed(p, [], p.seed, "") - file.write(processed.infotext(p, 0)) - devices.torch_gc() seed = get_fixed_seed(p.seed) @@ -453,15 +452,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: modules.sd_hijack.model_hijack.clear_comments() comments = {} - prompt_tmp = p.prompt - negative_prompt_tmp = p.negative_prompt - - shared.prompt_styles.apply_styles(p) if type(p.prompt) == list: - p.all_prompts = p.prompt + p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt] + else: + p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)] + + if type(p.negative_prompt) == list: + p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt] else: - p.all_prompts = p.batch_size * p.n_iter * [p.prompt] + p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)] if type(seed) == list: p.all_seeds = seed @@ -476,6 +476,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: def infotext(iteration=0, position_in_batch=0): return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + processed = Processed(p, [], p.seed, "") + file.write(processed.infotext(p, 0)) + if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() @@ -500,6 +504,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: break prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] + negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] @@ -510,7 +515,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) + uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: @@ -596,14 +601,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() - res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) if p.scripts is not None: p.scripts.postprocess(p, res) - p.prompt = prompt_tmp - p.negative_prompt = negative_prompt_tmp - return res -- cgit v1.2.1