aboutsummaryrefslogtreecommitdiff
path: root/modules/prompt_parser.py
diff options
context:
space:
mode:
authorRobert Barron <rubberbaron@robustspread.com>2023-08-09 07:46:30 -0700
committerRobert Barron <rubberbaron@robustspread.com>2023-08-09 10:38:47 -0700
commitd1ba46b6e15d2f2917e7d392955c8c6f988527ba (patch)
tree947c25c3d802aae5693af960f35728c32cea714b /modules/prompt_parser.py
parent5a38a9c0eea7f8c77585fcb97c51bf0e103e706e (diff)
allow first pass and hires pass to use a single prompt to do different prompt editing, hires is 1.0..2.0:
relative time range is [1..2] absolute time range is [steps+1..steps+hire_steps], e.g. with 30 steps and 20 hires steps, '20' is 2/3rds through first pass, and 40 is halfway through hires pass
Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r--modules/prompt_parser.py41
1 files changed, 31 insertions, 10 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 32d214e3..e8c41f38 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -26,7 +26,7 @@ plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
-def get_learned_conditioning_prompt_schedules(prompts, steps):
+def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
"""
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
>>> g("test")
@@ -57,18 +57,39 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
>>> g("[fe|||]male")
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
+ >>> g("a [b:.5] c")
+ [[10, 'a b c']]
+ >>> g("a [b:1.5] c")
+ [[5, 'a c'], [10, 'a b c']]
"""
+ if hires_steps is None or use_old_scheduling:
+ int_offset = 0
+ flt_offset = 0
+ steps = base_steps
+ else:
+ int_offset = base_steps
+ flt_offset = 1.0
+ steps = hires_steps
+
def collect_steps(steps, tree):
res = [steps]
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
- tree.children[-2] = float(tree.children[-2])
- if tree.children[-2] < 1:
- tree.children[-2] *= steps
- tree.children[-2] = min(steps, int(tree.children[-2]))
- res.append(tree.children[-2])
+ s = tree.children[-2]
+ v = float(s)
+ if use_old_scheduling:
+ v = v*steps if v<1 else v
+ else:
+ if "." in s:
+ v = (v - flt_offset) * steps
+ else:
+ v = (v - int_offset)
+ tree.children[-2] = min(steps, int(v))
+ if tree.children[-2] >= 1:
+ res.append(tree.children[-2])
def alternate(self, tree):
res.extend(range(1, steps+1))
@@ -134,7 +155,7 @@ class SdConditioning(list):
-def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
"""
res = []
- prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
cache = {}
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
@@ -229,7 +250,7 @@ class MulticondLearnedConditioning:
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
-def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
+def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
@@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
- learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
res = []
for indexes in res_indexes: