aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
authorKyle <zerouex@gmail.com>2023-02-02 09:37:01 -0500
committerKyle <zerouex@gmail.com>2023-02-02 09:37:01 -0500
commit269833067de1e7d0b6a6bd65724743d6b88a133f (patch)
tree58f7a2db842febb741fe779d94451ed5be548dea /modules/sd_samplers_kdiffusion.py
parent226d840e84c5f306350b0681945989b86760e616 (diff)
instruct-pix2pix support
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index aa7f106b..31ee22d3 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -77,9 +77,9 @@ class CFGDenoiser(torch.nn.Module):
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [image_cond])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params)
@@ -88,7 +88,7 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]:
- cond_in = torch.cat([tensor, uncond])
+ cond_in = torch.cat([tensor, uncond, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})