aboutsummaryrefslogtreecommitdiff
path: root/modules/styles.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/styles.py')
-rw-r--r--modules/styles.py158
1 files changed, 75 insertions, 83 deletions
diff --git a/modules/styles.py b/modules/styles.py
index 81d9800d..60bd8a7f 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,16 +1,16 @@
+from pathlib import Path
+from modules import errors
import csv
-import fnmatch
import os
-import os.path
import typing
import shutil
class PromptStyle(typing.NamedTuple):
name: str
- prompt: str
- negative_prompt: str
- path: str = None
+ prompt: str | None
+ negative_prompt: str | None
+ path: str | None = None
def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -30,38 +30,29 @@ def apply_styles_to_prompt(prompt, styles):
return prompt
-def unwrap_style_text_from_prompt(style_text, prompt):
- """
- Checks the prompt to see if the style text is wrapped around it. If so,
- returns True plus the prompt text without the style text. Otherwise, returns
- False with the original prompt.
+def extract_style_text_from_prompt(style_text, prompt):
+ """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.
- Note that the "cleaned" version of the style text is only used for matching
- purposes here. It isn't returned; the original style text is not modified.
+ extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
+ extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
+ extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
"""
- stripped_prompt = prompt
- stripped_style_text = style_text
+
+ stripped_prompt = prompt.strip()
+ stripped_style_text = style_text.strip()
+
if "{prompt}" in stripped_style_text:
- # Work out whether the prompt is wrapped in the style text. If so, we
- # return True and the "inner" prompt text that isn't part of the style.
- try:
- left, right = stripped_style_text.split("{prompt}", 2)
- except ValueError as e:
- # If the style text has multple "{prompt}"s, we can't split it into
- # two parts. This is an error, but we can't do anything about it.
- print(f"Unable to compare style text to prompt:\n{style_text}")
- print(f"Error: {e}")
- return False, prompt
+ left, right = stripped_style_text.split("{prompt}", 2)
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
- prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
+ prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
return True, prompt
else:
- # Work out whether the given prompt ends with the style text. If so, we
- # return True and the prompt text up to where the style text starts.
if stripped_prompt.endswith(stripped_style_text):
- prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
- if prompt.endswith(", "):
+ prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
+
+ if prompt.endswith(', '):
prompt = prompt[:-2]
+
return True, prompt
return False, prompt
@@ -76,15 +67,11 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt
- match_positive, extracted_positive = unwrap_style_text_from_prompt(
- style.prompt, prompt
- )
+ match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
if not match_positive:
return False, prompt, negative_prompt
- match_negative, extracted_negative = unwrap_style_text_from_prompt(
- style.negative_prompt, negative_prompt
- )
+ match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
if not match_negative:
return False, prompt, negative_prompt
@@ -92,14 +79,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
class StyleDatabase:
- def __init__(self, path: str):
+ def __init__(self, paths: list[str | Path]):
self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
- self.path = path
-
- folder, file = os.path.split(self.path)
- filename, _, ext = file.partition('*')
- self.default_path = os.path.join(folder, filename + ext)
+ self.paths = paths
+ self.all_styles_files: list[Path] = []
+
+ folder, file = os.path.split(self.paths[0])
+ if '*' in file or '?' in file:
+ # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
+ self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
+ self.paths.insert(0, self.default_path)
+ else:
+ self.default_path = Path(self.paths[0])
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
@@ -112,57 +104,58 @@ class StyleDatabase:
"""
self.styles.clear()
- path, filename = os.path.split(self.path)
-
- if "*" in filename:
- fileglob = filename.split("*")[0] + "*.csv"
- filelist = []
- for file in os.listdir(path):
- if fnmatch.fnmatch(file, fileglob):
- filelist.append(file)
- # Add a visible divider to the style list
- half_len = round(len(file) / 2)
- divider = f"{'-' * (20 - half_len)} {file.upper()}"
- divider = f"{divider} {'-' * (40 - len(divider))}"
- self.styles[divider] = PromptStyle(
- f"{divider}", None, None, "do_not_save"
+ # scans for all styles files
+ all_styles_files = []
+ for pattern in self.paths:
+ folder, file = os.path.split(pattern)
+ if '*' in file or '?' in file:
+ found_files = Path(folder).glob(file)
+ [all_styles_files.append(file) for file in found_files]
+ else:
+ # if os.path.exists(pattern):
+ all_styles_files.append(Path(pattern))
+
+ # Remove any duplicate entries
+ seen = set()
+ self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
+
+ for styles_file in self.all_styles_files:
+ if len(all_styles_files) > 1:
+ # add divider when more than styles file
+ # '---------------- STYLES ----------------'
+ divider = f' {styles_file.stem.upper()} '.center(40, '-')
+ self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
+ if styles_file.is_file():
+ self.load_from_csv(styles_file)
+
+ def load_from_csv(self, path: str | Path):
+ try:
+ with open(path, "r", encoding="utf-8-sig", newline="") as file:
+ reader = csv.DictReader(file, skipinitialspace=True)
+ for row in reader:
+ # Ignore empty rows or rows starting with a comment
+ if not row or row["name"].startswith("#"):
+ continue
+ # Support loading old CSV format with "name, text"-columns
+ prompt = row["prompt"] if "prompt" in row else row["text"]
+ negative_prompt = row.get("negative_prompt", "")
+ # Add style to database
+ self.styles[row["name"]] = PromptStyle(
+ row["name"], prompt, negative_prompt, str(path)
)
- # Add styles from this CSV file
- self.load_from_csv(os.path.join(path, file))
- if len(filelist) == 0:
- print(f"No styles found in {path} matching {fileglob}")
- return
- elif not os.path.exists(self.path):
- print(f"Style database not found: {self.path}")
- return
- else:
- self.load_from_csv(self.path)
-
- def load_from_csv(self, path: str):
- with open(path, "r", encoding="utf-8-sig", newline="") as file:
- reader = csv.DictReader(file, skipinitialspace=True)
- for row in reader:
- # Ignore empty rows or rows starting with a comment
- if not row or row["name"].startswith("#"):
- continue
- # Support loading old CSV format with "name, text"-columns
- prompt = row["prompt"] if "prompt" in row else row["text"]
- negative_prompt = row.get("negative_prompt", "")
- # Add style to database
- self.styles[row["name"]] = PromptStyle(
- row["name"], prompt, negative_prompt, path
- )
+ except Exception:
+ errors.report(f'Error loading styles from {path}: ', exc_info=True)
def get_style_paths(self) -> set:
"""Returns a set of all distinct paths of files that styles are loaded from."""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
- self.styles[style.name] = style._replace(path=self.default_path)
+ self.styles[style.name] = style._replace(path=str(self.default_path))
# Create a list of all distinct paths, including the default path
style_paths = set()
- style_paths.add(self.default_path)
+ style_paths.add(str(self.default_path))
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)
@@ -190,7 +183,6 @@ class StyleDatabase:
def save_styles(self, path: str = None) -> None:
# The path argument is deprecated, but kept for backwards compatibility
- _ = path
style_paths = self.get_style_paths()