From 7001bffe0247804793dfabb69ac96d832572ccd0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 15:43:25 +0300 Subject: fix AND broken for long prompts --- modules/prompt_parser.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules/prompt_parser.py') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f00256f2..15666073 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -239,6 +239,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch) + # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes + # and won't be able to torch.stack them. So this fixes that. + token_count = max([x.shape[0] for x in tensors]) + for i in range(len(tensors)): + if tensors[i].shape[0] != token_count: + last_vector = tensors[i][-1:] + last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) + tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype) -- cgit v1.2.1 From a5550f0213c3f145b1c984816ebcef92c48853ee Mon Sep 17 00:00:00 2001 From: Artem Zagidulin Date: Wed, 5 Oct 2022 19:10:39 +0300 Subject: alternate prompt --- modules/prompt_parser.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'modules/prompt_parser.py') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 15666073..919d5d31 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -13,13 +13,14 @@ import lark schedule_parser = lark.Lark(r""" !start: (prompt | /[][():]/+)* -prompt: (emphasized | scheduled | plain | WHITESPACE)* +prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)* !emphasized: "(" prompt ")" | "(" prompt ":" prompt ")" | "[" prompt "]" scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]" +alternate: "[" prompt ("|" prompt)+ "]" WHITESPACE: /\s+/ -plain: /([^\\\[\]():]|\\.)+/ +plain: /([^\\\[\]():|]|\\.)+/ %import common.SIGNED_NUMBER -> NUMBER """) @@ -59,6 +60,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): tree.children[-1] *= steps tree.children[-1] = min(steps, int(tree.children[-1])) l.append(tree.children[-1]) + def alternate(self, tree): + l.extend(range(1, steps+1)) CollectSteps().visit(tree) return sorted(set(l)) @@ -67,6 +70,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): def scheduled(self, args): before, after, _, when = args yield before or () if step <= when else after + def alternate(self, args): + yield next(args[(step - 1)%len(args)]) def start(self, args): def flatten(x): if type(x) == str: -- cgit v1.2.1