aboutsummaryrefslogtreecommitdiff
path: root/modules/ui.py
diff options
context:
space:
mode:
authorliamkerr <liamkerr@users.noreply.github.com>2022-10-02 10:02:38 -0400
committerGitHub <noreply@github.com>2022-10-02 10:02:38 -0400
commit7308aeefd9d25212398068c1a35f0b879b9ba06f (patch)
tree60f81d876522c772f85911ec4065ab3f448a8479 /modules/ui.py
parent3c6a049fc3c6b54ada3736710a7e86663ea7f3d9 (diff)
parent4e72a1aab6d1b3a8d8c09fadc81843a07c05cc18 (diff)
Merge branch 'master' into token_updates
Diffstat (limited to 'modules/ui.py')
-rw-r--r--modules/ui.py139
1 files changed, 122 insertions, 17 deletions
diff --git a/modules/ui.py b/modules/ui.py
index 40c08984..491a9dac 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -22,6 +22,7 @@ import gradio as gr
import gradio.utils
import gradio.routes
+from modules import sd_hijack
from modules.paths import script_path
from modules.shared import opts, cmd_opts
import modules.shared as shared
@@ -34,6 +35,7 @@ import modules.codeformer_model
import modules.styles
import modules.generation_parameters_copypaste
from modules.prompt_parser import get_learned_conditioning_prompt_schedules
+import modules.textual_inversion.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
@@ -144,8 +146,8 @@ def save_files(js_data, images, index):
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
-def wrap_gradio_call(func):
- def f(*args, **kwargs):
+def wrap_gradio_call(func, extra_outputs=None):
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
if run_memmon:
shared.mem_mon.monitor()
@@ -161,7 +163,10 @@ def wrap_gradio_call(func):
shared.state.job = ""
shared.state.job_count = 0
- res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
+ if extra_outputs_array is None:
+ extra_outputs_array = [None, '']
+
+ res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t
@@ -181,6 +186,7 @@ def wrap_gradio_call(func):
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
shared.state.interrupted = False
+ shared.state.job_count = 0
return tuple(res)
@@ -189,7 +195,7 @@ def wrap_gradio_call(func):
def check_progress_call(id_part):
if shared.state.job_count == 0:
- return "", gr_show(False), gr_show(False)
+ return "", gr_show(False), gr_show(False), gr_show(False)
progress = 0
@@ -221,13 +227,19 @@ def check_progress_call(id_part):
else:
preview_visibility = gr_show(True)
- return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
+ if shared.state.textinfo is not None:
+ textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
+ else:
+ textinfo_result = gr_show(False)
+
+ return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
def check_progress_call_initial(id_part):
shared.state.job_count = -1
shared.state.current_latent = None
shared.state.current_image = None
+ shared.state.textinfo = None
return check_progress_call(id_part)
@@ -403,13 +415,16 @@ def create_toprow(is_img2img):
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
-def setup_progressbar(progressbar, preview, id_part):
+def setup_progressbar(progressbar, preview, id_part, textinfo=None):
+ if textinfo is None:
+ textinfo = gr.HTML(visible=False)
+
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
check_progress.click(
fn=lambda: check_progress_call(id_part),
show_progress=False,
inputs=[],
- outputs=[progressbar, preview, preview],
+ outputs=[progressbar, preview, preview, textinfo],
)
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
@@ -417,11 +432,14 @@ def setup_progressbar(progressbar, preview, id_part):
fn=lambda: check_progress_call_initial(id_part),
show_progress=False,
inputs=[],
- outputs=[progressbar, preview, preview],
+ outputs=[progressbar, preview, preview, textinfo],
)
-def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
+def create_ui(wrap_gradio_gpu_call):
+ import modules.img2img
+ import modules.txt2img
+
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
@@ -487,7 +505,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
- fn=txt2img,
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit",
inputs=[
txt2img_prompt,
@@ -681,7 +699,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
)
img2img_args = dict(
- fn=img2img,
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img),
_js="submit_img2img",
inputs=[
dummy_component,
@@ -838,7 +856,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click(
- fn=run_extras,
+ fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index",
inputs=[
dummy_component,
@@ -888,7 +906,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
pnginfo_send_to_img2img = gr.Button('Send to img2img')
image.change(
- fn=wrap_gradio_call(run_pnginfo),
+ fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)
@@ -897,7 +915,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
-
+
with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
@@ -906,10 +924,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
-
+
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+ with gr.Blocks() as textual_inversion_interface:
+ with gr.Row().style(equal_height=False):
+ with gr.Column():
+ with gr.Group():
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
+
+ new_embedding_name = gr.Textbox(label="Name")
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_embedding = gr.Button(value="Create", variant='primary')
+
+ with gr.Group():
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
+ train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
+ learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
+ template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
+ steps = gr.Number(label='Max steps', value=100000, precision=0)
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
+
+ with gr.Row():
+ with gr.Column(scale=2):
+ gr.HTML(value="")
+
+ with gr.Column():
+ with gr.Row():
+ interrupt_training = gr.Button(value="Interrupt")
+ train_embedding = gr.Button(value="Train", variant='primary')
+
+ with gr.Column():
+ progressbar = gr.HTML(elem_id="ti_progressbar")
+ ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
+
+ ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
+ ti_preview = gr.Image(elem_id='ti_preview', visible=False)
+ ti_progress = gr.HTML(elem_id="ti_progress", value="")
+ ti_outcome = gr.HTML(elem_id="ti_error", value="")
+ setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
+
+ create_embedding.click(
+ fn=modules.textual_inversion.ui.create_embedding,
+ inputs=[
+ new_embedding_name,
+ nvpt,
+ ],
+ outputs=[
+ train_embedding_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ train_embedding.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ train_embedding_name,
+ learn_rate,
+ dataset_directory,
+ log_directory,
+ steps,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ interrupt_training.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1021,6 +1125,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+ (textual_inversion_interface, "Textual inversion", "ti"),
(settings_interface, "Settings", "settings"),
]
@@ -1054,11 +1159,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
def modelmerger(*args):
try:
- results = run_modelmerger(*args)
+ results = modules.extras.run_modelmerger(*args)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- modules.sd_models.list_models() #To remove the potentially missing models from the list
+ modules.sd_models.list_models() # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results