aboutsummaryrefslogtreecommitdiff
path: root/modules/prompt_parser.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-05 08:01:38 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-05 08:01:38 +0300
commitef1698fd6dbd6387341a1eeeded068ff1476ee50 (patch)
treeddaa0cf76e8cf95b93f63909a026ae3d5eab460a /modules/prompt_parser.py
parent0fae47e97445df4e7de4d85538a80917fc2a2457 (diff)
parentc613416af375092f55b9bc8649c949e95d250c44 (diff)
Merge branch 'dev' into extra-networks-always-visible
Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r--modules/prompt_parser.py122
1 files changed, 95 insertions, 27 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 0069d8b0..32d214e3 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import re
from collections import namedtuple
from typing import List
@@ -17,8 +19,8 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
| "(" prompt ":" prompt ")"
| "[" prompt "]"
-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
-alternate: "[" prompt ("|" prompt)+ "]"
+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
+alternate: "[" prompt ("|" [prompt])+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
@@ -51,6 +53,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
>>> g("[a|(b:1.1)]")
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
+ >>> g("[fe|]male")
+ [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
+ >>> g("[fe|||]male")
+ [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
"""
def collect_steps(steps, tree):
@@ -58,11 +64,11 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
- tree.children[-1] = float(tree.children[-1])
- if tree.children[-1] < 1:
- tree.children[-1] *= steps
- tree.children[-1] = min(steps, int(tree.children[-1]))
- res.append(tree.children[-1])
+ tree.children[-2] = float(tree.children[-2])
+ if tree.children[-2] < 1:
+ tree.children[-2] *= steps
+ tree.children[-2] = min(steps, int(tree.children[-2]))
+ res.append(tree.children[-2])
def alternate(self, tree):
res.extend(range(1, steps+1))
@@ -73,10 +79,11 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
def at_step(step, tree):
class AtStep(lark.Transformer):
def scheduled(self, args):
- before, after, _, when = args
+ before, after, _, when, _ = args
yield before or () if step <= when else after
def alternate(self, args):
- yield next(args[(step - 1)%len(args)])
+ args = ["" if not arg else arg for arg in args]
+ yield args[(step - 1) % len(args)]
def start(self, args):
def flatten(x):
if type(x) == str:
@@ -109,7 +116,25 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-def get_learned_conditioning(model, prompts, steps):
+class SdConditioning(list):
+ """
+ A list with prompts for stable diffusion's conditioner model.
+ Can also specify width and height of created image - SDXL needs it.
+ """
+ def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
+ super().__init__()
+ self.extend(prompts)
+
+ if copy_from is None:
+ copy_from = prompts
+
+ self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
+ self.width = width or getattr(copy_from, 'width', None)
+ self.height = height or getattr(copy_from, 'height', None)
+
+
+
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@@ -139,12 +164,17 @@ def get_learned_conditioning(model, prompts, steps):
res.append(cached)
continue
- texts = [x[1] for x in prompt_schedule]
+ texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
conds = model.get_learned_conditioning(texts)
cond_schedule = []
for i, (end_at_step, _) in enumerate(prompt_schedule):
- cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+ if isinstance(conds, dict):
+ cond = {k: v[i] for k, v in conds.items()}
+ else:
+ cond = conds[i]
+
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
cache[prompt] = cond_schedule
res.append(cond_schedule)
@@ -153,13 +183,15 @@ def get_learned_conditioning(model, prompts, steps):
re_AND = re.compile(r"\bAND\b")
-re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
+re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
-def get_multicond_prompt_list(prompts):
+
+def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
res_indexes = []
- prompt_flat_list = []
prompt_indexes = {}
+ prompt_flat_list = SdConditioning(prompts)
+ prompt_flat_list.clear()
for prompt in prompts:
subprompts = re_AND.split(prompt)
@@ -196,6 +228,7 @@ class MulticondLearnedConditioning:
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
@@ -214,20 +247,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+class DictWithShape(dict):
+ def __init__(self, x, shape):
+ super().__init__()
+ self.update(x)
+
+ @property
+ def shape(self):
+ return self["crossattn"].shape
+
+
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
- res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+ is_dict = isinstance(param, dict)
+
+ if is_dict:
+ dict_cond = param
+ res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+ res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+ else:
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
for i, cond_schedule in enumerate(c):
target_index = 0
for current, entry in enumerate(cond_schedule):
if current_step <= entry.end_at_step:
target_index = current
break
- res[i] = cond_schedule[target_index].cond
+
+ if is_dict:
+ for k, param in cond_schedule[target_index].cond.items():
+ res[k][i] = param
+ else:
+ res[i] = cond_schedule[target_index].cond
return res
+def stack_conds(tensors):
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+ return torch.stack(tensors)
+
+
+
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
param = c.batch[0][0].schedules[0].cond
@@ -249,16 +319,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
conds_list.append(conds_for_batch)
- # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
- # and won't be able to torch.stack them. So this fixes that.
- token_count = max([x.shape[0] for x in tensors])
- for i in range(len(tensors)):
- if tensors[i].shape[0] != token_count:
- last_vector = tensors[i][-1:]
- last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
- tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+ if isinstance(tensors[0], dict):
+ keys = list(tensors[0].keys())
+ stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+ stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+ else:
+ stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
- return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+ return conds_list, stacked
re_attention = re.compile(r"""
@@ -270,7 +338,7 @@ re_attention = re.compile(r"""
\\|
\(|
\[|
-:([+-]?[.\d]+)\)|
+:\s*([+-]?[.\d]+)\s*\)|
\)|
]|
[^\\()\[\]:]+|