aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhentailord85ez <112723046+hentailord85ez@users.noreply.github.com>2022-10-11 19:48:53 +0100
committerAUTOMATIC1111 <16777216c@gmail.com>2022-10-12 11:38:41 +0300
commit80f3cf2bb2ce3f00d801cae2c3a8c20a8d4167d8 (patch)
tree26d2265074723e9611a3102c84d95949c0d2a326
parentee015a1af66a94a75c914659fa0d321e702a0a87 (diff)
Account when lines are mismatched
-rw-r--r--modules/sd_hijack.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index ac70f876..2753d4fa 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -321,7 +321,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
- z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
+ tokens = []
+ multipliers = []
+ for i in range(len(remade_batch_tokens)):
+ if len(remade_batch_tokens[i]) > 0:
+ tokens.append(remade_batch_tokens[i][:75])
+ multipliers.append(batch_multipliers[i][:75])
+ else:
+ tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
+ multipliers.append([1.0] * 75)
+
+ z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens