aboutsummaryrefslogtreecommitdiff
path: root/modules/prompt_parser.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r--modules/prompt_parser.py23
1 files changed, 20 insertions, 3 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index d7f9e9a9..33810669 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import re
from collections import namedtuple
from typing import List
@@ -109,7 +111,19 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-def get_learned_conditioning(model, prompts, steps):
+class SdConditioning(list):
+ """
+ A list with prompts for stable diffusion's conditioner model.
+ Can also specify width and height of created image - SDXL needs it.
+ """
+ def __init__(self, prompts, width=None, height=None):
+ super().__init__()
+ self.extend(prompts)
+ self.width = width or getattr(prompts, 'width', None)
+ self.height = height or getattr(prompts, 'height', None)
+
+
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@@ -160,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps):
re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
-def get_multicond_prompt_list(prompts):
+
+def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
res_indexes = []
- prompt_flat_list = []
prompt_indexes = {}
+ prompt_flat_list = SdConditioning(prompts)
+ prompt_flat_list.clear()
for prompt in prompts:
subprompts = re_AND.split(prompt)
@@ -201,6 +217,7 @@ class MulticondLearnedConditioning:
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.