From a9f0e7d53611cf11331e2befd34f0351b47795ee Mon Sep 17 00:00:00 2001 From: invincibledude <> Date: Sun, 22 Jan 2023 15:12:00 +0300 Subject: hr conditioning --- modules/processing.py | 72 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 1133619f..21886bb5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -235,7 +235,7 @@ class StableDiffusionProcessing: def init(self, all_prompts, all_seeds, all_subseeds): pass - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): + def sample(self, conditioning, unconditional_conditioning, hr_conditioning, hr_uconditional_conditioning, seeds, subseeds, subseed_strength, prompts): raise NotImplementedError() def close(self): @@ -516,25 +516,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)] - # if type(p) == StableDiffusionProcessingTxt2Img: - # if p.enable_hr and p.is_hr_pass: - # logging.info("Running hr pass with custom prompt") - # if p.hr_prompt: - # if type(p.prompt) == list: - # p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.hr_prompt] - # else: - # p.all_prompts = p.batch_size * p.n_iter * [ - # shared.prompt_styles.apply_styles_to_prompt(p.hr_prompt, p.styles)] - # logging.info(p.all_prompts) - # - # if p.hr_negative_prompt: - # if type(p.negative_prompt) == list: - # p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in - # p.hr_negative_prompt] - # else: - # p.all_negative_prompts = p.batch_size * p.n_iter * [ - # shared.prompt_styles.apply_negative_styles_to_prompt(p.hr_negative_prompt, p.styles)] - # logging.info(p.all_negative_prompts) + if type(p) == StableDiffusionProcessingTxt2Img: + if p.enable_hr and p.is_hr_pass: + logging.info("Running hr pass with custom prompt") + if p.hr_prompt: + if type(p.prompt) == list: + p.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.hr_prompt] + else: + p.all_hr_prompts = p.batch_size * p.n_iter * [ + shared.prompt_styles.apply_styles_to_prompt(p.hr_prompt, p.styles)] + logging.info(p.all_prompts) + + if p.hr_negative_prompt: + if type(p.negative_prompt) == list: + p.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in + p.hr_negative_prompt] + else: + p.all_hr_negative_prompts = p.batch_size * p.n_iter * [ + shared.prompt_styles.apply_negative_styles_to_prompt(p.hr_negative_prompt, p.styles)] + logging.info(p.all_negative_prompts) if type(seed) == list: p.all_seeds = seed @@ -607,6 +607,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] + + if type(p) == StableDiffusionProcessingTxt2Img: + if p.enable_hr: + hr_prompts = p.all_hr_prompts[n * p.batch_size:(n + 1) * p.batch_size] + hr_negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] + seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] @@ -620,6 +626,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc) c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c) + if type(p) == StableDiffusionProcessingTxt2Img: + if p.enable_hr: + hr_uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, + cached_uc) + hr_c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, + cached_c) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: @@ -629,7 +641,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.autocast(): - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) + if type(p) == StableDiffusionProcessingTxt2Img: + if p.enable_hr: + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, hr_conditioning=hr_c, hr_uconditional_conditioning=hr_uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, + subseeds=subseeds, + subseed_strength=p.subseed_strength, prompts=prompts) + else: + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, + subseeds=subseeds, + subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] for x in x_samples_ddim: @@ -744,6 +765,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_sampler = hr_sampler self.hr_prompt = hr_prompt if hr_prompt != '' else self.prompt self.hr_negative_prompt = hr_negative_prompt if hr_negative_prompt != '' else self.negative_prompt + self.all_hr_prompts = None + self.all_hr_negative_prompts = None if firstphase_width != 0 or firstphase_height != 0: self.hr_upscale_to_x = self.width @@ -817,7 +840,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.hr_upscaler is not None: self.extra_generation_params["Hires upscaler"] = self.hr_upscaler - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): + def sample(self, conditioning, unconditional_conditioning, hr_conditioning, hr_uconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") @@ -830,9 +853,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not self.enable_hr: return samples - self.prompt = self.hr_prompt - self.negative_prompt = self.hr_negative_prompt - target_width = self.hr_upscale_to_x target_height = self.hr_upscale_to_y @@ -904,7 +924,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) + samples = self.sampler.sample_img2img(self, samples, noise, hr_conditioning, hr_unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) return samples -- cgit v1.2.1