aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorw-e-w <40751091+w-e-w@users.noreply.github.com>2024-01-21 02:21:36 +0900
committerw-e-w <40751091+w-e-w@users.noreply.github.com>2024-01-21 03:53:58 +0900
commit25e8273d2f6481deca221b29d35093f6d0c9da6a (patch)
tree8aa4651858aa851811341778a3e41092aab46f61 /modules
parentf939bce845ae07536b1c920618743af83e0b01ec (diff)
re-work multi --styles-file
--styles-file change to append str --styles-file is [] then defaults to [styles.csv] --styles-file accepts paths or paths with wildcard "*" the first `--styles-file` entry is use as the default styles file path if filename a wildcard then the first matching file is used if no match is found, create a new "styles.csv" in the same dir as the first path when saving a new style it will be save in the default styles file when saving a existing style, it will be saved to file it belongs to order of the styles files in the styles dropdown can be controlled to a certain degree by the order of --styles-file
Diffstat (limited to 'modules')
-rw-r--r--modules/cmd_args.py2
-rw-r--r--modules/shared.py3
-rw-r--r--modules/styles.py85
-rw-r--r--modules/ui_prompt_styles.py9
4 files changed, 52 insertions, 47 deletions
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index e58059a1..f1251b6c 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -88,7 +88,7 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anythin
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
-parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
+parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
diff --git a/modules/shared.py b/modules/shared.py
index 63661939..ccdca4e7 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,3 +1,4 @@
+import os
import sys
import gradio as gr
@@ -11,7 +12,7 @@ parser = shared_cmd_options.parser
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
parallel_processing_allowed = True
-styles_filename = cmd_opts.styles_file
+styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')]
config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
diff --git a/modules/styles.py b/modules/styles.py
index 026c4300..9edcc7e4 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,16 +1,15 @@
+from pathlib import Path
import csv
-import fnmatch
import os
-import os.path
import typing
import shutil
class PromptStyle(typing.NamedTuple):
name: str
- prompt: str
- negative_prompt: str
- path: str = None
+ prompt: str | None
+ negative_prompt: str | None
+ path: str | None = None
def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -79,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
class StyleDatabase:
- def __init__(self, path: str):
+ def __init__(self, paths: list[str | Path]):
self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
- self.path = path
-
- folder, file = os.path.split(self.path)
- filename, _, ext = file.partition('*')
- self.default_path = os.path.join(folder, filename + ext)
+ self.paths = paths
+ self.all_styles_files: list[Path] = []
+
+ folder, file = os.path.split(self.paths[0])
+ if '*' in file or '?' in file:
+ # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
+ self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
+ self.paths.insert(0, self.default_path)
+ else:
+ self.default_path = Path(self.paths[0])
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
@@ -99,33 +103,31 @@ class StyleDatabase:
"""
self.styles.clear()
- path, filename = os.path.split(self.path)
-
- if "*" in filename:
- fileglob = filename.split("*")[0] + "*.csv"
- filelist = []
- for file in os.listdir(path):
- if fnmatch.fnmatch(file, fileglob):
- filelist.append(file)
- # Add a visible divider to the style list
- half_len = round(len(file) / 2)
- divider = f"{'-' * (20 - half_len)} {file.upper()}"
- divider = f"{divider} {'-' * (40 - len(divider))}"
- self.styles[divider] = PromptStyle(
- f"{divider}", None, None, "do_not_save"
- )
- # Add styles from this CSV file
- self.load_from_csv(os.path.join(path, file))
- if len(filelist) == 0:
- print(f"No styles found in {path} matching {fileglob}")
- return
- elif not os.path.exists(self.path):
- print(f"Style database not found: {self.path}")
- return
- else:
- self.load_from_csv(self.path)
-
- def load_from_csv(self, path: str):
+ # scans for all styles files
+ all_styles_files = []
+ for pattern in self.paths:
+ folder, file = os.path.split(pattern)
+ if '*' in file or '?' in file:
+ found_files = Path(folder).glob(file)
+ [all_styles_files.append(file) for file in found_files]
+ else:
+ # if os.path.exists(pattern):
+ all_styles_files.append(Path(pattern))
+
+ # Remove any duplicate entries
+ seen = set()
+ self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
+
+ for styles_file in self.all_styles_files:
+ if len(all_styles_files) > 1:
+ # add divider when more than styles file
+ # '---------------- STYLES ----------------'
+ divider = f' {styles_file.stem.upper()} '.center(40, '-')
+ self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
+ if styles_file.is_file():
+ self.load_from_csv(styles_file)
+
+ def load_from_csv(self, path: str | Path):
with open(path, "r", encoding="utf-8-sig", newline="") as file:
reader = csv.DictReader(file, skipinitialspace=True)
for row in reader:
@@ -137,7 +139,7 @@ class StyleDatabase:
negative_prompt = row.get("negative_prompt", "")
# Add style to database
self.styles[row["name"]] = PromptStyle(
- row["name"], prompt, negative_prompt, path
+ row["name"], prompt, negative_prompt, str(path)
)
def get_style_paths(self) -> set:
@@ -145,11 +147,11 @@ class StyleDatabase:
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
- self.styles[style.name] = style._replace(path=self.default_path)
+ self.styles[style.name] = style._replace(path=str(self.default_path))
# Create a list of all distinct paths, including the default path
style_paths = set()
- style_paths.add(self.default_path)
+ style_paths.add(str(self.default_path))
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)
@@ -177,7 +179,6 @@ class StyleDatabase:
def save_styles(self, path: str = None) -> None:
# The path argument is deprecated, but kept for backwards compatibility
- _ = path
style_paths = self.get_style_paths()
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
index 0d74c23f..d67e3f17 100644
--- a/modules/ui_prompt_styles.py
+++ b/modules/ui_prompt_styles.py
@@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt):
if not name:
return gr.update(visible=False)
- style = styles.PromptStyle(name, prompt, negative_prompt)
+ existing_style = shared.prompt_styles.styles.get(name)
+ path = existing_style.path if existing_style is not None else None
+
+ style = styles.PromptStyle(name, prompt, negative_prompt, path)
shared.prompt_styles.styles[style.name] = style
- shared.prompt_styles.save_styles(shared.styles_filename)
+ shared.prompt_styles.save_styles()
return gr.update(visible=True)
@@ -34,7 +37,7 @@ def delete_style(name):
return
shared.prompt_styles.styles.pop(name, None)
- shared.prompt_styles.save_styles(shared.styles_filename)
+ shared.prompt_styles.save_styles()
return '', '', ''