aboutsummaryrefslogtreecommitdiff
path: root/modules/prompt_parser.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r--modules/prompt_parser.py249
1 files changed, 196 insertions, 53 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 5d58c4ed..f00256f2 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,10 +1,7 @@
import re
from collections import namedtuple
-import torch
-from lark import Lark, Transformer, Visitor
-import functools
-
-import modules.shared as shared
+from typing import List
+import lark
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
@@ -14,25 +11,48 @@ import modules.shared as shared
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
+schedule_parser = lark.Lark(r"""
+!start: (prompt | /[][():]/+)*
+prompt: (emphasized | scheduled | plain | WHITESPACE)*
+!emphasized: "(" prompt ")"
+ | "(" prompt ":" prompt ")"
+ | "[" prompt "]"
+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
+WHITESPACE: /\s+/
+plain: /([^\\\[\]():]|\\.)+/
+%import common.SIGNED_NUMBER -> NUMBER
+""")
def get_learned_conditioning_prompt_schedules(prompts, steps):
- grammar = r"""
- start: prompt
- prompt: (emphasized | scheduled | weighted | plain)*
- !emphasized: "(" prompt ")"
- | "(" prompt ":" prompt ")"
- | "[" prompt "]"
- scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
- !weighted: "{" weighted_item ("|" weighted_item)* "}"
- !weighted_item: prompt (":" prompt)?
- plain: /([^\\\[\](){}:|]|\\.)+/
- %import common.SIGNED_NUMBER -> NUMBER
"""
- parser = Lark(grammar, parser='lalr')
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
+ >>> g("test")
+ [[10, 'test']]
+ >>> g("a [b:3]")
+ [[3, 'a '], [10, 'a b']]
+ >>> g("a [b: 3]")
+ [[3, 'a '], [10, 'a b']]
+ >>> g("a [[[b]]:2]")
+ [[2, 'a '], [10, 'a [[b]]']]
+ >>> g("[(a:2):3]")
+ [[3, ''], [10, '(a:2)']]
+ >>> g("a [b : c : 1] d")
+ [[1, 'a b d'], [10, 'a c d']]
+ >>> g("a[b:[c:d:2]:1]e")
+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
+ >>> g("a [unbalanced")
+ [[10, 'a [unbalanced']]
+ >>> g("a [b:.5] c")
+ [[5, 'a c'], [10, 'a b c']]
+ >>> g("a [{b|d{:.5] c") # not handling this right now
+ [[5, 'a c'], [10, 'a {b|d{ c']]
+ >>> g("((a][:b:c [d:3]")
+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
+ """
def collect_steps(steps, tree):
l = [steps]
- class CollectSteps(Visitor):
+ class CollectSteps(lark.Visitor):
def scheduled(self, tree):
tree.children[-1] = float(tree.children[-1])
if tree.children[-1] < 1:
@@ -43,13 +63,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
return sorted(set(l))
def at_step(step, tree):
- class AtStep(Transformer):
+ class AtStep(lark.Transformer):
def scheduled(self, args):
- if len(args) == 2:
- before, after, when = (), *args
- else:
- before, after, when = args
- yield before if step <= when else after
+ before, after, _, when = args
+ yield before or () if step <= when else after
def start(self, args):
def flatten(x):
if type(x) == str:
@@ -57,16 +74,22 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
else:
for gen in x:
yield from flatten(gen)
- return ''.join(flatten(args[0]))
+ return ''.join(flatten(args))
def plain(self, args):
yield args[0].value
def __default__(self, data, children, meta):
for child in children:
yield from child
return AtStep().transform(tree)
-
+
def get_schedule(prompt):
- tree = parser.parse(prompt)
+ try:
+ tree = schedule_parser.parse(prompt)
+ except lark.exceptions.LarkError as e:
+ if 0:
+ import traceback
+ traceback.print_exc()
+ return [[steps, prompt]]
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
@@ -74,11 +97,26 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
-def get_learned_conditioning(prompts, steps):
+def get_learned_conditioning(model, prompts, steps):
+ """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
+ and the sampling step at which this condition is to be replaced by the next one.
+ Input:
+ (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
+
+ Output:
+ [
+ [
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
+ ],
+ [
+ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
+ ]
+ ]
+ """
res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
@@ -92,7 +130,7 @@ def get_learned_conditioning(prompts, steps):
continue
texts = [x[1] for x in prompt_schedule]
- conds = shared.sd_model.get_learned_conditioning(texts)
+ conds = model.get_learned_conditioning(texts)
cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule):
@@ -101,22 +139,109 @@ def get_learned_conditioning(prompts, steps):
cache[prompt] = cond_schedule
res.append(cond_schedule)
- return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
+ return res
+
+
+re_AND = re.compile(r"\bAND\b")
+re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
+
+def get_multicond_prompt_list(prompts):
+ res_indexes = []
+
+ prompt_flat_list = []
+ prompt_indexes = {}
+
+ for prompt in prompts:
+ subprompts = re_AND.split(prompt)
+
+ indexes = []
+ for subprompt in subprompts:
+ match = re_weight.search(subprompt)
+
+ text, weight = match.groups() if match is not None else (subprompt, 1.0)
+
+ weight = float(weight) if weight is not None else 1.0
+
+ index = prompt_indexes.get(text, None)
+ if index is None:
+ index = len(prompt_flat_list)
+ prompt_flat_list.append(text)
+ prompt_indexes[text] = index
+
+ indexes.append((index, weight))
+
+ res_indexes.append(indexes)
+
+ return res_indexes, prompt_flat_list, prompt_indexes
+
+
+class ComposableScheduledPromptConditioning:
+ def __init__(self, schedules, weight=1.0):
+ self.schedules: List[ScheduledPromptConditioning] = schedules
+ self.weight: float = weight
+
+
+class MulticondLearnedConditioning:
+ def __init__(self, shape, batch):
+ self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
+ self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
+ """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
+ For each prompt, the list is obtained by splitting the prompt using the AND separator.
-def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
- res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype)
- for i, cond_schedule in enumerate(c.schedules):
+ https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
+ """
+
+ res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
+
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
+
+ res = []
+ for indexes in res_indexes:
+ res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
+
+ return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+
+
+def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
+ param = c[0][0].cond
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+ for i, cond_schedule in enumerate(c):
target_index = 0
- for curret_index, (end_at, cond) in enumerate(cond_schedule):
+ for current, (end_at, cond) in enumerate(cond_schedule):
if current_step <= end_at:
- target_index = curret_index
+ target_index = current
break
res[i] = cond_schedule[target_index].cond
return res
+def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
+ param = c.batch[0][0].schedules[0].cond
+
+ tensors = []
+ conds_list = []
+
+ for batch_no, composable_prompts in enumerate(c.batch):
+ conds_for_batch = []
+
+ for cond_index, composable_prompt in enumerate(composable_prompts):
+ target_index = 0
+ for current, (end_at, cond) in enumerate(composable_prompt.schedules):
+ if current_step <= end_at:
+ target_index = current
+ break
+
+ conds_for_batch.append((len(tensors), composable_prompt.weight))
+ tensors.append(composable_prompt.schedules[target_index].cond)
+
+ conds_list.append(conds_for_batch)
+
+ return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+
+
re_attention = re.compile(r"""
\\\(|
\\\)|
@@ -148,23 +273,26 @@ def parse_prompt_attention(text):
\\ - literal character '\'
anything else - just text
- Example:
-
- 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
-
- produces:
-
- [
- ['a ', 1.0],
- ['house', 1.5730000000000004],
- [' ', 1.1],
- ['on', 1.0],
- [' a ', 1.1],
- ['hill', 0.55],
- [', sun, ', 1.1],
- ['sky', 1.4641000000000006],
- ['.', 1.1]
- ]
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\(literal\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
"""
res = []
@@ -206,4 +334,19 @@ def parse_prompt_attention(text):
if len(res) == 0:
res = [["", 1.0]]
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
return res
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
+else:
+ import torch # doctest faster