aboutsummaryrefslogtreecommitdiff
path: root/modules/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/ui.py')
-rw-r--r--modules/ui.py127
1 files changed, 84 insertions, 43 deletions
diff --git a/modules/ui.py b/modules/ui.py
index d4b32c05..cd118552 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -22,9 +22,14 @@ import piexif
import torch
from PIL import Image, PngImagePlugin
-from modules import localization, sd_hijack, sd_models
+import gradio as gr
+import gradio.utils
+import gradio.routes
+
+from modules import sd_hijack, sd_models, localization, script_callbacks
from modules.paths import script_path
-from modules.shared import cmd_opts, opts, restricted_opts
+
+from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
@@ -43,6 +48,11 @@ from modules import prompt_parser
from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
+import modules.textual_inversion.ui
+import modules.hypernetworks.ui
+
+import modules.images_history as img_his
+
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -593,27 +603,29 @@ def apply_setting(key, value):
return value
-def create_ui(wrap_gradio_gpu_call):
- import modules.img2img
- import modules.txt2img
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
- def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
- for k, v in args.items():
- setattr(refresh_component, k, v)
+ return gr.update(**(args or {}))
- return gr.update(**(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
+
+def create_ui(wrap_gradio_gpu_call):
+ import modules.img2img
+ import modules.txt2img
- refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn = refresh,
- inputs = [],
- outputs = [refresh_component]
- )
- return refresh_button
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
@@ -711,6 +723,7 @@ def create_ui(wrap_gradio_gpu_call):
firstphase_width,
firstphase_height,
] + custom_inputs,
+
outputs=[
txt2img_gallery,
generation_info,
@@ -787,6 +800,7 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
+ *modules.scripts.scripts_txt2img.infotext_fields
]
txt2img_preview_params = [
@@ -854,8 +868,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@@ -1052,6 +1066,7 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
+ *modules.scripts.scripts_img2img.infotext_fields
]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
@@ -1174,12 +1189,12 @@ def create_ui(wrap_gradio_gpu_call):
)
#images history
images_history_switch_dict = {
- "fn":modules.generation_parameters_copypaste.connect_paste,
- "t2i":txt2img_paste_fields,
- "i2i":img2img_paste_fields
+ "fn": modules.generation_parameters_copypaste.connect_paste,
+ "t2i": txt2img_paste_fields,
+ "i2i": img2img_paste_fields
}
- images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
+ images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
@@ -1212,6 +1227,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
with gr.Row():
with gr.Column(scale=3):
@@ -1227,6 +1243,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):
@@ -1240,13 +1258,18 @@ def create_ui(wrap_gradio_gpu_call):
process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
- process_split = gr.Checkbox(label='Split oversized images into two')
+ process_split = gr.Checkbox(label='Split oversized images')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
+ with gr.Row(visible=False) as process_split_extra_row:
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1254,15 +1277,24 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
run_preprocess = gr.Button(value="Preprocess", variant='primary')
+ process_split.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_split],
+ outputs=[process_split_extra_row],
+ )
+
with gr.Tab(label="Train"):
- gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
- learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ with gr.Row():
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
+
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -1296,6 +1328,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name,
initialization_text,
nvpt,
+ overwrite_old_embedding,
],
outputs=[
train_embedding_name,
@@ -1309,6 +1342,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
+ overwrite_old_hypernetwork,
new_hypernetwork_layer_structure,
new_hypernetwork_activation_func,
new_hypernetwork_add_layer_norm,
@@ -1329,10 +1363,13 @@ def create_ui(wrap_gradio_gpu_call):
process_dst,
process_width,
process_height,
+ preprocess_txt_action,
process_flip,
process_split,
process_caption,
- process_caption_deepbooru
+ process_caption_deepbooru,
+ process_split_threshold,
+ process_overlap_ratio,
],
outputs=[
ti_output,
@@ -1345,7 +1382,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
- learn_rate,
+ embedding_learn_rate,
batch_size,
dataset_directory,
log_directory,
@@ -1370,7 +1407,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
- learn_rate,
+ hypernetwork_learn_rate,
batch_size,
dataset_directory,
log_directory,
@@ -1491,10 +1528,10 @@ Requested path was: {f}
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
+ oldval = opts.data.get(key, None)
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
return gr.update(value=oldval), opts.dumpjson()
- oldval = opts.data.get(key, None)
opts.data[key] = value
if oldval != value:
@@ -1600,19 +1637,24 @@ Requested path was: {f}
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
- (images_history, "History", "images_history"),
+ (images_history, "Image Browser", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
- (settings_interface, "Settings", "settings"),
]
- with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
- css = file.read()
+ interfaces += script_callbacks.ui_tabs_callback()
+
+ interfaces += [(settings_interface, "Settings", "settings")]
+
+ css = ""
+
+ for cssfile in modules.scripts.list_files_with_name("style.css"):
+ with open(cssfile, "r", encoding="utf8") as file:
+ css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
- usercss = file.read()
- css += usercss
+ css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
@@ -1835,9 +1877,9 @@ def load_javascript(raw_response):
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f'<script>{jsfile.read()}</script>'
- jsdir = os.path.join(script_path, "javascript")
- for filename in sorted(os.listdir(jsdir)):
- with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
+ scripts_list = modules.scripts.list_scripts("javascript", ".js")
+ for basedir, filename, path in scripts_list:
+ with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
if cmd_opts.theme is not None:
@@ -1855,6 +1897,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response
-reload_javascript = partial(load_javascript,
- gradio.routes.templates.TemplateResponse)
+reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()