aboutsummaryrefslogtreecommitdiff
path: root/modules/prompt_parser.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-10-11 11:14:36 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-10-11 11:14:36 +0300
commit5de806184f6687e46cf936b92055146dc6cf2994 (patch)
treed84c2daa8798c3d2f8e99e17234a40065491182d /modules/prompt_parser.py
parent12c4d5c6b5bf9dd50d0601c36af4f99b65316d58 (diff)
parent948533950c9db5069a874d925fadd50bac00fdb5 (diff)
Merge branch 'master' into hypernetwork-training
Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r--modules/prompt_parser.py18
1 files changed, 16 insertions, 2 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index f00256f2..919d5d31 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -13,13 +13,14 @@ import lark
schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
-prompt: (emphasized | scheduled | plain | WHITESPACE)*
+prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
| "(" prompt ":" prompt ")"
| "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
+alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
-plain: /([^\\\[\]():]|\\.)+/
+plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
@@ -59,6 +60,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
tree.children[-1] *= steps
tree.children[-1] = min(steps, int(tree.children[-1]))
l.append(tree.children[-1])
+ def alternate(self, tree):
+ l.extend(range(1, steps+1))
CollectSteps().visit(tree)
return sorted(set(l))
@@ -67,6 +70,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
def scheduled(self, args):
before, after, _, when = args
yield before or () if step <= when else after
+ def alternate(self, args):
+ yield next(args[(step - 1)%len(args)])
def start(self, args):
def flatten(x):
if type(x) == str:
@@ -239,6 +244,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
conds_list.append(conds_for_batch)
+ # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)