Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r--  modules/prompt_parser.py | 58
1 file changed, 42 insertions(+), 16 deletions(-)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 203ae1ac..334efeef 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -19,14 +19,14 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
| "(" prompt ":" prompt ")"
| "[" prompt "]"
-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
-alternate: "[" prompt ("|" prompt)+ "]"
+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
+alternate: "[" prompt ("|" [prompt])+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
-def get_learned_conditioning_prompt_schedules(prompts, steps):
+def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
"""
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
>>> g("test")
@@ -53,18 +53,43 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
>>> g("[a|(b:1.1)]")
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
+ >>> g("[fe|]male")
+ [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
+ >>> g("[fe|||]male")
+ [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
+ >>> g("a [b:.5] c")
+ [[10, 'a b c']]
+ >>> g("a [b:1.5] c")
+ [[5, 'a c'], [10, 'a b c']]
"""
+ if hires_steps is None or use_old_scheduling:
+ int_offset = 0
+ flt_offset = 0
+ steps = base_steps
+ else:
+ int_offset = base_steps
+ flt_offset = 1.0
+ steps = hires_steps
+
def collect_steps(steps, tree):
res = [steps]
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
- tree.children[-1] = float(tree.children[-1])
- if tree.children[-1] < 1:
- tree.children[-1] *= steps
- tree.children[-1] = min(steps, int(tree.children[-1]))
- res.append(tree.children[-1])
+ s = tree.children[-2]
+ v = float(s)
+ if use_old_scheduling:
+ v = v*steps if v<1 else v
+ else:
+ if "." in s:
+ v = (v - flt_offset) * steps
+ else:
+ v = (v - int_offset)
+ tree.children[-2] = min(steps, int(v))
+ if tree.children[-2] >= 1:
+ res.append(tree.children[-2])
def alternate(self, tree):
res.extend(range(1, steps+1))
@@ -75,13 +100,14 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
def at_step(step, tree):
class AtStep(lark.Transformer):
def scheduled(self, args):
- before, after, _, when = args
+ before, after, _, when, _ = args
yield before or () if step <= when else after
def alternate(self, args):
- yield next(args[(step - 1)%len(args)])
+ args = ["" if not arg else arg for arg in args]
+ yield args[(step - 1) % len(args)]
def start(self, args):
def flatten(x):
- if type(x) == str:
+ if isinstance(x, str):
yield x
else:
for gen in x:
@@ -129,7 +155,7 @@ class SdConditioning(list):
-def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@@ -149,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
"""
res = []
- prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
cache = {}
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
@@ -224,7 +250,7 @@ class MulticondLearnedConditioning:
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
-def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
+def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
@@ -233,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
- learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
res = []
for indexes in res_indexes:
@@ -333,7 +359,7 @@ re_attention = re.compile(r"""
\\|
\(|
\[|
-:([+-]?[.\d]+)\)|
+:\s*([+-]?[.\d]+)\s*\)|
\)|
]|
[^\\()\[\]:]+|
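The new doctests in the second hunk capture the behavioural change this patch makes: when a hires pass is scheduled, fractional schedule numbers ("[b:.5]") are interpreted relative to the current pass, whole numbers count across both passes, and alternate options may now be empty ("[fe|]male"). Below is a minimal, stand-alone sketch of just that step-resolution rule, written against the logic visible in the collect_steps hunk; the function name resolve_step is made up for illustration and does not exist in modules/prompt_parser.py.

def resolve_step(number_text, base_steps, hires_steps=None, use_old_scheduling=False):
    """Map the NUMBER token of a scheduled prompt (e.g. ".5" in "[b:.5]") to an absolute step."""
    if hires_steps is None or use_old_scheduling:
        int_offset, flt_offset, steps = 0, 0, base_steps
    else:
        int_offset, flt_offset, steps = base_steps, 1.0, hires_steps

    v = float(number_text)
    if use_old_scheduling:
        v = v * steps if v < 1 else v           # old rule: fractions scale by total steps
    elif "." in number_text:
        v = (v - flt_offset) * steps            # new rule: floats are relative to the current pass
    else:
        v = v - int_offset                      # new rule: integers count across both passes
    return min(steps, int(v))                   # values < 1 mean the switch already happened before this pass


# First (base) pass, 10 steps:
print(resolve_step(".5", base_steps=10))                    # 5  -> "[b:.5]" switches halfway through
print(resolve_step("1.5", base_steps=10))                   # 10 -> clamped; no visible switch in this pass
# Hires pass, 10 base + 10 hires steps:
print(resolve_step(".5", base_steps=10, hires_steps=10))    # -5 -> already past, "b" is active from step 1
print(resolve_step("1.5", base_steps=10, hires_steps=10))   # 5  -> switches halfway through the hires pass

These values line up with the added doctests: with (10, 10) steps, "a [b:.5] c" collapses to [[10, 'a b c']] while "a [b:1.5] c" becomes [[5, 'a c'], [10, 'a b c']].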