aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-14 17:56:21 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-09-14 17:56:21 +0300
commit9f267af3f7404d8d8a9123e8e1c07a6557eba54d (patch)
tree2e4a0c310f8f78377f295edf7867f2a70b39fa96 /modules
parent6153d9d9e9d51708e8f96eb8aaecf168adfcf4b7 (diff)
added a second style field
added the ability to use {prompt} in styles added a button to apply style to textbox rearranged top row for UI
Diffstat (limited to 'modules')
-rw-r--r--modules/img2img.py4
-rw-r--r--modules/processing.py6
-rw-r--r--modules/shared.py2
-rw-r--r--modules/styles.py96
-rw-r--r--modules/txt2img.py4
-rw-r--r--modules/ui.py98
6 files changed, 139 insertions, 71 deletions
diff --git a/modules/img2img.py b/modules/img2img.py
index bfcd7598..40a3499c 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -11,7 +11,7 @@ from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts
-def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init_img_with_mask, init_mask, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
+def img2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_mask, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
is_inpaint = mode == 1
is_upscale = mode == 2
@@ -37,7 +37,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
prompt=prompt,
negative_prompt=negative_prompt,
- prompt_style=prompt_style,
+ styles=[prompt_style, prompt_style2],
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
diff --git a/modules/processing.py b/modules/processing.py
index ca32c610..38e74fe2 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -46,14 +46,14 @@ def apply_color_correction(correction, image):
class StableDiffusionProcessing:
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
- self.prompt_style: str = prompt_style
+ self.styles: str = styles
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
@@ -182,7 +182,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
comments = []
- modules.styles.apply_style(p, shared.prompt_styles[p.prompt_style])
+ shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list:
all_prompts = p.prompt
diff --git a/modules/shared.py b/modules/shared.py
index 1206cb4c..03269444 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -81,7 +81,7 @@ state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
styles_filename = cmd_opts.styles_file
-prompt_styles = modules.styles.load_styles(styles_filename)
+prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")
diff --git a/modules/styles.py b/modules/styles.py
index bc7f070f..eeedcd08 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -20,49 +20,67 @@ class PromptStyle(typing.NamedTuple):
negative_prompt: str
-def load_styles(path: str) -> dict[str, PromptStyle]:
- styles = {"None": PromptStyle("None", "", "")}
+def merge_prompts(style_prompt: str, prompt: str) -> str:
+ if "{prompt}" in style_prompt:
+ res = style_prompt.replace("{prompt}", prompt)
+ else:
+ parts = filter(None, (prompt.strip(), style_prompt.strip()))
+ res = ", ".join(parts)
- if os.path.exists(path):
- with open(path, "r", encoding="utf8", newline='') as file:
- reader = csv.DictReader(file)
- for row in reader:
- # Support loading old CSV format with "name, text"-columns
- prompt = row["prompt"] if "prompt" in row else row["text"]
- negative_prompt = row.get("negative_prompt", "")
- styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
+ return res
- return styles
+def apply_styles_to_prompt(prompt, styles):
+ for style in styles:
+ prompt = merge_prompts(style, prompt)
-def merge_prompts(style_prompt: str, prompt: str) -> str:
- parts = filter(None, (prompt.strip(), style_prompt.strip()))
- return ", ".join(parts)
+ return prompt
-def apply_style(processing: StableDiffusionProcessing, style: PromptStyle) -> None:
- if isinstance(processing.prompt, list):
- processing.prompt = [merge_prompts(style.prompt, p) for p in processing.prompt]
- else:
- processing.prompt = merge_prompts(style.prompt, processing.prompt)
+class StyleDatabase:
+ def __init__(self, path: str):
+ self.no_style = PromptStyle("None", "", "")
+ self.styles = {"None": self.no_style}
- if isinstance(processing.negative_prompt, list):
- processing.negative_prompt = [merge_prompts(style.negative_prompt, p) for p in processing.negative_prompt]
- else:
- processing.negative_prompt = merge_prompts(style.negative_prompt, processing.negative_prompt)
-
-
-def save_styles(path: str, styles: abc.Iterable[PromptStyle]) -> None:
- # Write to temporary file first, so we don't nuke the file if something goes wrong
- fd, temp_path = tempfile.mkstemp(".csv")
- with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
- # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
- # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
- writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
- writer.writeheader()
- writer.writerows(style._asdict() for style in styles)
-
- # Always keep a backup file around
- if os.path.exists(path):
- shutil.move(path, path + ".bak")
- shutil.move(temp_path, path)
+ if not os.path.exists(path):
+ return
+
+ with open(path, "r", encoding="utf8", newline='') as file:
+ reader = csv.DictReader(file)
+ for row in reader:
+ # Support loading old CSV format with "name, text"-columns
+ prompt = row["prompt"] if "prompt" in row else row["text"]
+ negative_prompt = row.get("negative_prompt", "")
+ self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
+
+ def apply_styles_to_prompt(self, prompt, styles):
+ return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
+
+ def apply_negative_styles_to_prompt(self, prompt, styles):
+ return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
+
+ def apply_styles(self, p: StableDiffusionProcessing) -> None:
+ if isinstance(p.prompt, list):
+ p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
+ else:
+ p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
+
+ if isinstance(p.negative_prompt, list):
+ p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
+ else:
+ p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
+
+ def save_styles(self, path: str) -> None:
+ # Write to temporary file first, so we don't nuke the file if something goes wrong
+ fd, temp_path = tempfile.mkstemp(".csv")
+ with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
+ # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
+ # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+ writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
+ writer.writeheader()
+ writer.writerows(style._asdict() for k, style in self.styles.items())
+
+ # Always keep a backup file around
+ if os.path.exists(path):
+ shutil.move(path, path + ".bak")
+ shutil.move(temp_path, path)
diff --git a/modules/txt2img.py b/modules/txt2img.py
index d60febfc..30d89849 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,13 +6,13 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
prompt=prompt,
- prompt_style=prompt_style,
+ styles=[prompt_style, prompt_style2],
negative_prompt=negative_prompt,
seed=seed,
subseed=subseed,
diff --git a/modules/ui.py b/modules/ui.py
index c88a1f22..efd57b2e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -237,13 +237,20 @@ def add_style(name: str, prompt: str, negative_prompt: str):
return [gr_show(), gr_show()]
style = modules.styles.PromptStyle(name, prompt, negative_prompt)
- shared.prompt_styles[style.name] = style
+ shared.prompt_styles.styles[style.name] = style
# Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
# reserialize all styles every time we save them
- modules.styles.save_styles(shared.styles_filename, shared.prompt_styles.values())
+ shared.prompt_styles.save_styles(shared.styles_filename)
- update = {"visible": True, "choices": list(shared.prompt_styles), "__type__": "update"}
- return [update, update]
+ update = {"visible": True, "choices": list(shared.prompt_styles.styles), "__type__": "update"}
+ return [update, update, update, update]
+
+
+def apply_styles(prompt, prompt_neg, style1_name, style2_name):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
+ prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
def interrogate(image):
@@ -251,15 +258,46 @@ def interrogate(image):
return gr_show(True) if prompt is None else prompt
+
+def create_toprow(is_img2img):
+ with gr.Row(elem_id="toprow"):
+ with gr.Column(scale=4):
+ with gr.Row():
+ with gr.Column(scale=8):
+ with gr.Row():
+ prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2)
+ roll = gr.Button('Roll', elem_id="roll", visible=len(shared.artist_db.artists) > 0)
+
+ with gr.Column(scale=1, elem_id="style_pos_col"):
+ prompt_style = gr.Dropdown(label="Style 1", elem_id="style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+
+ with gr.Row():
+ with gr.Column(scale=8):
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+
+ with gr.Column(scale=1, elem_id="style_neg_col"):
+ prompt_style2 = gr.Dropdown(label="Style 2", elem_id="style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+
+ with gr.Column(scale=1):
+ with gr.Row():
+ submit = gr.Button('Generate', elem_id="generate", variant='primary')
+
+ with gr.Row():
+ if is_img2img:
+ interrogate = gr.Button('Interrogate', elem_id="interrogate")
+ else:
+ interrogate = None
+ prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
+ save_style = gr.Button('Create style', elem_id="style_create")
+
+ check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)
+
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, check_progress
+
+
def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- with gr.Row(elem_id="toprow"):
- txt2img_prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1)
- txt2img_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="txt2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1)
- txt2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1)
- roll = gr.Button('Roll', elem_id="txt2img_roll", visible=len(shared.artist_db.artists) > 0)
- submit = gr.Button('Generate', elem_id="txt2img_generate", variant='primary')
- check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, check_progress = create_toprow(is_img2img=False)
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -290,7 +328,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery').style(grid=4)
-
with gr.Group():
with gr.Row():
save = gr.Button('Save')
@@ -298,7 +335,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
send_to_inpaint = gr.Button('Send to inpaint')
send_to_extras = gr.Button('Send to extras')
interrupt = gr.Button('Interrupt')
- txt2img_save_style = gr.Button('Save prompt as style')
progressbar = gr.HTML(elem_id="progressbar")
@@ -306,7 +342,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
html_info = gr.HTML()
generation_info = gr.Textbox(visible=False)
-
txt2img_args = dict(
fn=txt2img,
_js="submit",
@@ -314,6 +349,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
txt2img_prompt,
txt2img_negative_prompt,
txt2img_prompt_style,
+ txt2img_prompt_style2,
steps,
sampler_index,
restore_faces,
@@ -343,7 +379,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
outputs=[progressbar, txt2img_preview, txt2img_preview],
)
-
interrupt.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
@@ -376,13 +411,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- with gr.Row(elem_id="toprow"):
- img2img_prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1)
- img2img_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="img2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1)
- img2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1)
- img2img_interrogate = gr.Button('Interrogate', elem_id="img2img_interrogate", variant='primary')
- submit = gr.Button('Generate', elem_id="img2img_generate", variant='primary')
- check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, check_progress = create_toprow(is_img2img=True)
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -511,6 +540,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
img2img_prompt,
img2img_negative_prompt,
img2img_prompt_style,
+ img2img_prompt_style2,
init_img,
init_img_with_mask,
init_mask,
@@ -580,15 +610,35 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
]
)
+ roll.click(
+ fn=roll_artist,
+ inputs=[
+ img2img_prompt,
+ ],
+ outputs=[
+ img2img_prompt,
+ ]
+ )
+
+ prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
+ style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
+
dummy_component = gr.Label(visible=False)
- for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]):
+ for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
button.click(
fn=add_style,
_js="ask_for_style_name",
# Have to pass empty dummy component here, because the JavaScript and Python function have to accept
# the same number of parameters, but we only know the style-name after the JavaScript prompt
inputs=[dummy_component, prompt, negative_prompt],
- outputs=[txt2img_prompt_style, img2img_prompt_style],
+ outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
+ )
+
+ for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns):
+ button.click(
+ fn=apply_styles,
+ inputs=[prompt, negative_prompt, style1, style2],
+ outputs=[prompt, negative_prompt, style1, style2],
)
with gr.Blocks(analytics_enabled=False) as extras_interface: