aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-12-02 09:53:27 +0300
committerGitHub <noreply@github.com>2023-12-02 09:53:27 +0300
commitef1723ef41f90f615f5b5def651db3a683452157 (patch)
tree3dfd488cab60cec37aaa8d2f302e3f16f22c6815
parent7547d7c7910ac8cd4eb0b1543a7f408ed0947dbe (diff)
parent0cd5b0ed541db496ecb5042d1340303a182fce7b (diff)
Merge pull request #14125 from cjj1977/dev
Allow use of mutiple styles csv files
-rw-r--r--modules/styles.py203
1 files changed, 171 insertions, 32 deletions
diff --git a/modules/styles.py b/modules/styles.py
index 0740fe1b..4d218cd7 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,4 +1,5 @@
import csv
+import fnmatch
import os
import os.path
import re
@@ -10,6 +11,23 @@ class PromptStyle(typing.NamedTuple):
name: str
prompt: str
negative_prompt: str
+ path: str = None
+
+
+def clean_text(text: str) -> str:
+ """
+ Iterating through a list of regular expressions and replacement strings, we
+ clean up the prompt and style text to make it easier to match against each
+ other.
+ """
+ re_list = [
+ ("multiple commas", re.compile("(,+\s+)+,?"), ", "),
+ ("multiple spaces", re.compile("\s{2,}"), " "),
+ ]
+ for _, regex, replace in re_list:
+ text = regex.sub(replace, text)
+
+ return text.strip(", ")
def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -26,41 +44,64 @@ def apply_styles_to_prompt(prompt, styles):
for style in styles:
prompt = merge_prompts(style, prompt)
- return prompt
+ return clean_text(prompt)
-re_spaces = re.compile(" +")
+def unwrap_style_text_from_prompt(style_text, prompt):
+ """
+ Checks the prompt to see if the style text is wrapped around it. If so,
+ returns True plus the prompt text without the style text. Otherwise, returns
+ False with the original prompt.
-
-def extract_style_text_from_prompt(style_text, prompt):
- stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
- stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
+ Note that the "cleaned" version of the style text is only used for matching
+ purposes here. It isn't returned; the original style text is not modified.
+ """
+ stripped_prompt = clean_text(prompt)
+ stripped_style_text = clean_text(style_text)
if "{prompt}" in stripped_style_text:
- left, right = stripped_style_text.split("{prompt}", 2)
+ # Work out whether the prompt is wrapped in the style text. If so, we
+ # return True and the "inner" prompt text that isn't part of the style.
+ try:
+ left, right = stripped_style_text.split("{prompt}", 2)
+ except ValueError as e:
+ # If the style text has multple "{prompt}"s, we can't split it into
+ # two parts. This is an error, but we can't do anything about it.
+ print(f"Unable to compare style text to prompt:\n{style_text}")
+ print(f"Error: {e}")
+ return False, prompt
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
- prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
+ prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
return True, prompt
else:
+ # Work out whether the given prompt ends with the style text. If so, we
+ # return True and the prompt text up to where the style text starts.
if stripped_prompt.endswith(stripped_style_text):
- prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
-
- if prompt.endswith(', '):
+ prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
+ if prompt.endswith(", "):
prompt = prompt[:-2]
-
return True, prompt
return False, prompt
-def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
+def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
+ """
+ Takes a style and compares it to the prompt and negative prompt. If the style
+ matches, returns True plus the prompt and negative prompt with the style text
+ removed. Otherwise, returns False with the original prompt and negative prompt.
+ """
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt
- match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
+ match_positive, extracted_positive = unwrap_style_text_from_prompt(
+ style.prompt, prompt
+ )
if not match_positive:
return False, prompt, negative_prompt
- match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
+ match_negative, extracted_negative = unwrap_style_text_from_prompt(
+ style.negative_prompt, negative_prompt
+ )
if not match_negative:
return False, prompt, negative_prompt
@@ -69,25 +110,88 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
class StyleDatabase:
def __init__(self, path: str):
- self.no_style = PromptStyle("None", "", "")
+ self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
self.path = path
+ folder, file = os.path.split(self.path)
+ self.default_file = file.split("*")[0] + ".csv"
+ if self.default_file == ".csv":
+ self.default_file = "styles.csv"
+ self.default_path = os.path.join(folder, self.default_file)
+
+ self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
+
self.reload()
def reload(self):
+ """
+ Clears the style database and reloads the styles from the CSV file(s)
+ matching the path used to initialize the database.
+ """
self.styles.clear()
- if not os.path.exists(self.path):
+ path, filename = os.path.split(self.path)
+
+ if "*" in filename:
+ fileglob = filename.split("*")[0] + "*.csv"
+ filelist = []
+ for file in os.listdir(path):
+ if fnmatch.fnmatch(file, fileglob):
+ filelist.append(file)
+ # Add a visible divider to the style list
+ half_len = round(len(file) / 2)
+ divider = f"{'-' * (20 - half_len)} {file.upper()}"
+ divider = f"{divider} {'-' * (40 - len(divider))}"
+ self.styles[divider] = PromptStyle(
+ f"{divider}", None, None, "do_not_save"
+ )
+ # Add styles from this CSV file
+ self.load_from_csv(os.path.join(path, file))
+ if len(filelist) == 0:
+ print(f"No styles found in {path} matching {fileglob}")
+ return
+ elif not os.path.exists(self.path):
+ print(f"Style database not found: {self.path}")
return
+ else:
+ self.load_from_csv(self.path)
- with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
+ def load_from_csv(self, path: str):
+ with open(path, "r", encoding="utf-8-sig", newline="") as file:
reader = csv.DictReader(file, skipinitialspace=True)
for row in reader:
+ # Ignore empty rows or rows starting with a comment
+ if not row or row["name"].startswith("#"):
+ continue
# Support loading old CSV format with "name, text"-columns
prompt = row["prompt"] if "prompt" in row else row["text"]
negative_prompt = row.get("negative_prompt", "")
- self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
+ # Add style to database
+ self.styles[row["name"]] = PromptStyle(
+ row["name"], prompt, negative_prompt, path
+ )
+
+ def get_style_paths(self) -> list():
+ """
+ Returns a list of all distinct paths, including the default path, of
+ files that styles are loaded from."""
+ # Update any styles without a path to the default path
+ for style in list(self.styles.values()):
+ if not style.path:
+ self.styles[style.name] = style._replace(path=self.default_path)
+
+ # Create a list of all distinct paths, including the default path
+ style_paths = set()
+ style_paths.add(self.default_path)
+ for _, style in self.styles.items():
+ if style.path:
+ style_paths.add(style.path)
+
+ # Remove any paths for styles that are just list dividers
+ style_paths.remove("do_not_save")
+
+ return list(style_paths)
def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles]
@@ -96,20 +200,53 @@ class StyleDatabase:
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
def apply_styles_to_prompt(self, prompt, styles):
- return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
+ return apply_styles_to_prompt(
+ prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
+ )
def apply_negative_styles_to_prompt(self, prompt, styles):
- return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
-
- def save_styles(self, path: str) -> None:
- # Always keep a backup file around
- if os.path.exists(path):
- shutil.copy(path, f"{path}.bak")
-
- with open(path, "w", encoding="utf-8-sig", newline='') as file:
- writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
- writer.writeheader()
- writer.writerows(style._asdict() for k, style in self.styles.items())
+ return apply_styles_to_prompt(
+ prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
+ )
+
+ def save_styles(self, path: str = None) -> None:
+ # The path argument is deprecated, but kept for backwards compatibility
+ _ = path
+
+ # Update any styles without a path to the default path
+ for style in list(self.styles.values()):
+ if not style.path:
+ self.styles[style.name] = style._replace(path=self.default_path)
+
+ # Create a list of all distinct paths, including the default path
+ style_paths = set()
+ style_paths.add(self.default_path)
+ for _, style in self.styles.items():
+ if style.path:
+ style_paths.add(style.path)
+
+ # Remove any paths for styles that are just list dividers
+ style_paths.remove("do_not_save")
+
+ csv_names = [os.path.split(path)[1].lower() for path in style_paths]
+
+ for style_path in style_paths:
+ # Always keep a backup file around
+ if os.path.exists(style_path):
+ shutil.copy(style_path, f"{style_path}.bak")
+
+ # Write the styles to the CSV file
+ with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
+ writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
+ writer.writeheader()
+ for style in (s for s in self.styles.values() if s.path == style_path):
+ # Skip style list dividers, e.g. "STYLES.CSV"
+ if style.name.lower().strip("# ") in csv_names:
+ continue
+ # Write style fields, ignoring the path field
+ writer.writerow(
+ {k: v for k, v in style._asdict().items() if k != "path"}
+ )
def extract_styles_from_prompt(self, prompt, negative_prompt):
extracted = []
@@ -120,7 +257,9 @@ class StyleDatabase:
found_style = None
for style in applicable_styles:
- is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
+ is_match, new_prompt, new_neg_prompt = extract_original_prompts(
+ style, prompt, negative_prompt
+ )
if is_match:
found_style = style
prompt = new_prompt