From f2693bec08d2c2e513cb35fa24402396505a01a9 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 15 Sep 2022 13:10:16 +0300
Subject: prompt editing

---
 modules/sd_samplers.py | 44 ++++++++++++++++++++++++++++----------------
 1 file changed, 28 insertions(+), 16 deletions(-)

(limited to 'modules/sd_samplers.py')

diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 7ef507f1..c042c5c3 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,6 +7,7 @@ from PIL import Image
 import k_diffusion.sampling
 import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
+from modules import prompt_parser
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
 
@@ -53,20 +54,6 @@ def store_latent(decoded):
             shared.state.current_image = sample_to_image(decoded)
 
 
-def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
-    if sampler_wrapper.mask is not None:
-        img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
-        x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
-
-    res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
-
-    if sampler_wrapper.mask is not None:
-        store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1])
-    else:
-        store_latent(res[1])
-
-    return res
-
 
 def extended_tdqm(sequence, *args, desc=None, **kwargs):
     state.sampling_steps = len(sequence)
@@ -93,6 +80,25 @@ class VanillaStableDiffusionSampler:
         self.mask = None
         self.nmask = None
         self.init_latent = None
+        self.step = 0
+
+    def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+        cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+        unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+
+        if self.mask is not None:
+            img_orig = self.sampler.model.q_sample(self.init_latent, ts)
+            x_dec = img_orig * self.mask + self.nmask * x_dec
+
+        res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+
+        if self.mask is not None:
+            store_latent(self.init_latent * self.mask + self.nmask * res[1])
+        else:
+            store_latent(res[1])
+
+        self.step += 1
+        return res
 
     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
         t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
@@ -105,7 +111,7 @@ class VanillaStableDiffusionSampler:
 
         x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
 
-        self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
+        self.sampler.p_sample_ddim = self.p_sample_ddim_hook
         self.mask = p.mask
         self.nmask = p.nmask
         self.init_latent = p.init_latent
@@ -117,7 +123,7 @@ class VanillaStableDiffusionSampler:
     def sample(self, p, x, conditioning, unconditional_conditioning):
         for fieldname in ['p_sample_ddim', 'p_sample_plms']:
             if hasattr(self.sampler, fieldname):
-                setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
+                setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
         self.mask = None
         self.nmask = None
         self.init_latent = None
@@ -138,8 +144,12 @@ class CFGDenoiser(torch.nn.Module):
         self.mask = None
         self.nmask = None
         self.init_latent = None
+        self.step = 0
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
+        cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
         if shared.batch_cond_uncond:
             x_in = torch.cat([x] * 2)
             sigma_in = torch.cat([sigma] * 2)
@@ -154,6 +164,8 @@ class CFGDenoiser(torch.nn.Module):
         if self.mask is not None:
             denoised = self.init_latent * self.mask + self.nmask * denoised
 
+        self.step += 1
+
         return denoised
 
 
-- 
cgit v1.2.1