aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-10-08 15:43:25 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-10-08 15:43:25 +0300
commit7001bffe0247804793dfabb69ac96d832572ccd0 (patch)
tree74db8920aff51b8dbb01cb801267c84bca2af162
parent77f4237d1c3af1756e7dab2699e3dcebad5619d6 (diff)
fix AND prompt syntax being broken for long prompts
-rw-r--r--modules/prompt_parser.py9
1 file changed, 9 insertions, 0 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index f00256f2..15666073 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -239,6 +239,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
conds_list.append(conds_for_batch)
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)