aboutsummaryrefslogtreecommitdiff
path: root/modules/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/ui.py')
-rw-r--r--modules/ui.py34
1 files changed, 16 insertions, 18 deletions
diff --git a/modules/ui.py b/modules/ui.py
index aba13926..de2b5544 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -157,22 +157,6 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
-def save_pil_to_file(pil_image, dir=None):
- use_metadata = False
- metadata = PngImagePlugin.PngInfo()
- for key, value in pil_image.info.items():
- if isinstance(key, str) and isinstance(value, str):
- metadata.add_text(key, value)
- use_metadata = True
-
- file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
- pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
- return file_obj
-
-
-# override save to file function so that it also writes PNG info
-gr.processing_utils.save_pil_to_file = save_pil_to_file
-
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
@@ -1210,7 +1194,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
- new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
@@ -1256,7 +1240,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
with gr.Row():
interrupt_preprocessing = gr.Button("Interrupt")
- run_preprocess = gr.Button(value="Preprocess", variant='primary')
+ run_preprocess = gr.Button(value="Preprocess", variant='primary')
process_split.change(
fn=lambda show: gr_show(show),
@@ -1283,6 +1267,7 @@ def create_ui(wrap_gradio_gpu_call):
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
batch_size = gr.Number(label='Batch size', value=1, precision=0)
+ gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@@ -1293,6 +1278,11 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
+ with gr.Row():
+ shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
+ tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
+ with gr.Row():
+ latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
@@ -1381,11 +1371,15 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name,
embedding_learn_rate,
batch_size,
+ gradient_step,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
+ shuffle_tags,
+ tag_drop_out,
+ latent_sampling_method,
create_image_every,
save_embedding_every,
template_file,
@@ -1406,11 +1400,15 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork_name,
hypernetwork_learn_rate,
batch_size,
+ gradient_step,
dataset_directory,
log_directory,
training_width,
training_height,
steps,
+ shuffle_tags,
+ tag_drop_out,
+ latent_sampling_method,
create_image_every,
save_embedding_every,
template_file,