aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r--modules/sd_samplers.py44
1 files changed, 28 insertions, 16 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index f5e81f34..df3a6fe8 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,6 +7,7 @@ from PIL import Image
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
+from modules import prompt_parser
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -53,20 +54,6 @@ def store_latent(decoded):
shared.state.current_image = sample_to_image(decoded)
-def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
- if sampler_wrapper.mask is not None:
- img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
- x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
-
- res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
-
- if sampler_wrapper.mask is not None:
- store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1])
- else:
- store_latent(res[1])
-
- return res
-
def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence)
@@ -93,6 +80,25 @@ class VanillaStableDiffusionSampler:
self.mask = None
self.nmask = None
self.init_latent = None
+ self.step = 0
+
+ def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+
+ if self.mask is not None:
+ img_orig = self.sampler.model.q_sample(self.init_latent, ts)
+ x_dec = img_orig * self.mask + self.nmask * x_dec
+
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+
+ if self.mask is not None:
+ store_latent(self.init_latent * self.mask + self.nmask * res[1])
+ else:
+ store_latent(res[1])
+
+ self.step += 1
+ return res
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
@@ -105,7 +111,7 @@ class VanillaStableDiffusionSampler:
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
- self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
+ self.sampler.p_sample_ddim = self.p_sample_ddim_hook
self.mask = p.mask
self.nmask = p.nmask
self.init_latent = p.init_latent
@@ -117,7 +123,7 @@ class VanillaStableDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning):
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
+ setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
self.mask = None
self.nmask = None
self.init_latent = None
@@ -138,8 +144,12 @@ class CFGDenoiser(torch.nn.Module):
self.mask = None
self.nmask = None
self.init_latent = None
+ self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale):
+ cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
if shared.batch_cond_uncond:
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
@@ -154,6 +164,8 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ self.step += 1
+
return denoised