aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-05-22 00:13:53 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-05-22 00:13:53 +0300
commit3366e494a1147e570d8527eea19da88edb3a1e0c (patch)
tree62c196b5c13a337f0bf72aca3a07d1c51b7f925f /modules/sd_samplers_kdiffusion.py
parent8faac8b96313c6c4bf0a81bddecff4d6ba22ac25 (diff)
option to pad prompt/neg prompt to be same length
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 59982fc9..638e0ac9 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -125,6 +125,16 @@ class CFGDenoiser(torch.nn.Module):
x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size]
+ # TODO add infotext entry
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+ elif num_repeats > 0:
+ uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond])