aboutsummaryrefslogtreecommitdiff
path: root/modules/styles.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-14 17:56:21 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-09-14 17:56:21 +0300
commit9f267af3f7404d8d8a9123e8e1c07a6557eba54d (patch)
tree2e4a0c310f8f78377f295edf7867f2a70b39fa96 /modules/styles.py
parent6153d9d9e9d51708e8f96eb8aaecf168adfcf4b7 (diff)
added a second style field
added the ability to use {prompt} in styles added a button to apply style to textbox rearranged top row for UI
Diffstat (limited to 'modules/styles.py')
-rw-r--r--modules/styles.py96
1 files changed, 57 insertions, 39 deletions
diff --git a/modules/styles.py b/modules/styles.py
index bc7f070f..eeedcd08 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -20,49 +20,67 @@ class PromptStyle(typing.NamedTuple):
negative_prompt: str
-def load_styles(path: str) -> dict[str, PromptStyle]:
- styles = {"None": PromptStyle("None", "", "")}
+def merge_prompts(style_prompt: str, prompt: str) -> str:
+ if "{prompt}" in style_prompt:
+ res = style_prompt.replace("{prompt}", prompt)
+ else:
+ parts = filter(None, (prompt.strip(), style_prompt.strip()))
+ res = ", ".join(parts)
- if os.path.exists(path):
- with open(path, "r", encoding="utf8", newline='') as file:
- reader = csv.DictReader(file)
- for row in reader:
- # Support loading old CSV format with "name, text"-columns
- prompt = row["prompt"] if "prompt" in row else row["text"]
- negative_prompt = row.get("negative_prompt", "")
- styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
+ return res
- return styles
+def apply_styles_to_prompt(prompt, styles):
+ for style in styles:
+ prompt = merge_prompts(style, prompt)
-def merge_prompts(style_prompt: str, prompt: str) -> str:
- parts = filter(None, (prompt.strip(), style_prompt.strip()))
- return ", ".join(parts)
+ return prompt
-def apply_style(processing: StableDiffusionProcessing, style: PromptStyle) -> None:
- if isinstance(processing.prompt, list):
- processing.prompt = [merge_prompts(style.prompt, p) for p in processing.prompt]
- else:
- processing.prompt = merge_prompts(style.prompt, processing.prompt)
+class StyleDatabase:
+ def __init__(self, path: str):
+ self.no_style = PromptStyle("None", "", "")
+ self.styles = {"None": self.no_style}
- if isinstance(processing.negative_prompt, list):
- processing.negative_prompt = [merge_prompts(style.negative_prompt, p) for p in processing.negative_prompt]
- else:
- processing.negative_prompt = merge_prompts(style.negative_prompt, processing.negative_prompt)
-
-
-def save_styles(path: str, styles: abc.Iterable[PromptStyle]) -> None:
- # Write to temporary file first, so we don't nuke the file if something goes wrong
- fd, temp_path = tempfile.mkstemp(".csv")
- with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
- # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
- # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
- writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
- writer.writeheader()
- writer.writerows(style._asdict() for style in styles)
-
- # Always keep a backup file around
- if os.path.exists(path):
- shutil.move(path, path + ".bak")
- shutil.move(temp_path, path)
+ if not os.path.exists(path):
+ return
+
+ with open(path, "r", encoding="utf8", newline='') as file:
+ reader = csv.DictReader(file)
+ for row in reader:
+ # Support loading old CSV format with "name, text"-columns
+ prompt = row["prompt"] if "prompt" in row else row["text"]
+ negative_prompt = row.get("negative_prompt", "")
+ self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
+
+ def apply_styles_to_prompt(self, prompt, styles):
+ return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
+
+ def apply_negative_styles_to_prompt(self, prompt, styles):
+ return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
+
+ def apply_styles(self, p: StableDiffusionProcessing) -> None:
+ if isinstance(p.prompt, list):
+ p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
+ else:
+ p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
+
+ if isinstance(p.negative_prompt, list):
+ p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
+ else:
+ p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
+
+ def save_styles(self, path: str) -> None:
+ # Write to temporary file first, so we don't nuke the file if something goes wrong
+ fd, temp_path = tempfile.mkstemp(".csv")
+ with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
+ # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
+ # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+ writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
+ writer.writeheader()
+ writer.writerows(style._asdict() for k, style in self.styles.items())
+
+ # Always keep a backup file around
+ if os.path.exists(path):
+ shutil.move(path, path + ".bak")
+ shutil.move(temp_path, path)