From 8b74b9aa9a20e4c5c1f72641f8b9617479eb276b Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Wed, 19 Oct 2022 19:06:14 -0500 Subject: add symbol for clear button and simplify roll_col css selector --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..9f6edc5f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -83,6 +83,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 +trash_prompt_symbol = '\U0001F5D1' # 🗑🗑🗑 def plaintext_to_html(text): @@ -498,6 +499,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt") token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") -- cgit v1.2.1 From c6345bd445463b7aa41723d6637e80dfa293a890 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Wed, 19 Oct 2022 21:23:57 -0500 Subject: nerf line length --- modules/ui.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 9f6edc5f..cb9a6c6e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -83,7 +83,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -trash_prompt_symbol = '\U0001F5D1' # 🗑🗑🗑 +trash_prompt_symbol = '\U0001F5D1' # def plaintext_to_html(text): @@ -617,7 +617,10 @@ def create_ui(wrap_gradio_gpu_call): return refresh_button with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ + txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ + token_button = create_toprow(is_img2img=False) + dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.1 From 158d678f596d7fc304a6ce2f0dc31f8abfe62250 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Thu, 20 Oct 2022 01:08:24 -0500 Subject: clear prompt button now works on both relevant tabs. Device detection stuff will be added later. --- modules/ui.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index cb9a6c6e..bde546cc 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -424,6 +424,16 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox +# setup button for clearing prompt input boxes on client side of webui +def connect_trash_prompt(dummy_component, button, is_img2img): + + button.click( + fn=lambda: print("Clearing prompt"), + _js="trash_prompt", + inputs=[], + outputs=[], + ) + def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): """ Connects a 'reuse (sub)seed' button's click event so that it copies last used (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength @@ -540,7 +550,7 @@ def create_toprow(is_img2img): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2.save_to_config = True - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, trash_prompt def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -619,10 +629,11 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ - token_button = create_toprow(is_img2img=False) + token_button, trash_prompt_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + connect_trash_prompt(dummy_component, trash_prompt_button, False) with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): @@ -807,7 +818,11 @@ def create_ui(wrap_gradio_gpu_call): token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\ + img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ + token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) + + connect_trash_prompt(dummy_component,trash_prompt_button, True) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.1 From a3b047b7c74dc6ca07f40aee778997fc1889d72f Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Thu, 20 Oct 2022 19:28:58 -0500 Subject: add settings option to toggle button visibility --- modules/shared.py | 1 + modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index faede821..7e9c2696 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -300,6 +300,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "trash_prompt_visible": OptionInfo(True, "Show trash prompt button"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index bde546cc..13c0b4ca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -509,7 +509,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt") + trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt", visible=opts.trash_prompt_visible) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") -- cgit v1.2.1 From 9ba372de90df81c4f1e992d8b33ae17c6630de95 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 13:55:42 -0500 Subject: initial work on getting prompts cleared on the backend and synchronizing token counter --- modules/ui.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d2cb528e..2748a2e0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,15 +429,16 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -# setup button for clearing prompt input boxes on client side of webui -def connect_trash_prompt(dummy_component, button, is_img2img): +def clear_prompt(prompt): + print(f"type: '{prompt}'") + print(prompt) + + # update_token_counter(prompt, steps) + return "" + +def connect_trash_prompt(prompt, confirmed): + if(confirmed): clear_prompt(prompt) - button.click( - fn=lambda: print("Clearing prompt"), - _js="trash_prompt", - inputs=[], - outputs=[], - ) def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): """ Connects a 'reuse (sub)seed' button's click event so that it copies last used @@ -640,7 +641,16 @@ def create_ui(wrap_gradio_gpu_call): dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - connect_trash_prompt(dummy_component, trash_prompt_button, False) + + + trash_prompt_button.click( + # fn=lambda: print("Clearing prompt"), + _js="trash_prompt", + fn=lambda: clear_prompt(), + inputs=[txt2img_prompt, dummy_component], + outputs=[txt2img_prompt, dummy_component], + ) + with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): @@ -848,7 +858,6 @@ def create_ui(wrap_gradio_gpu_call): img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) - connect_trash_prompt(dummy_component,trash_prompt_button, True) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.1 From ee0505dd0092ae7073b77aba93a858bda000dc60 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 14:24:14 -0500 Subject: only delete prompt on back end and remove client-side deletion --- modules/ui.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 2748a2e0..90c338da 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,15 +429,21 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -def clear_prompt(prompt): - print(f"type: '{prompt}'") - print(prompt) - # update_token_counter(prompt, steps) - return "" +def connect_trash_prompt(_, confirmed): + if(confirmed): + # update_token_counter(prompt, steps) + return ["", confirmed] -def connect_trash_prompt(prompt, confirmed): - if(confirmed): clear_prompt(prompt) +def trash_prompt_click(button, prompt): + dummy_component = gradio.Label(visible=False) + + button.click( + _js="trash_prompt", + fn=connect_trash_prompt, + inputs=[prompt, dummy_component], + outputs=[prompt, dummy_component], + ) def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): @@ -643,13 +649,7 @@ def create_ui(wrap_gradio_gpu_call): txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - trash_prompt_button.click( - # fn=lambda: print("Clearing prompt"), - _js="trash_prompt", - fn=lambda: clear_prompt(), - inputs=[txt2img_prompt, dummy_component], - outputs=[txt2img_prompt, dummy_component], - ) + trash_prompt_click(trash_prompt_button, txt2img_prompt) with gr.Row(elem_id='txt2img_progress_row'): @@ -858,6 +858,7 @@ def create_ui(wrap_gradio_gpu_call): img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) + trash_prompt_click(trash_prompt_button, img2img_prompt) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.1 From de70ddaf58fae98c561738a54f574abfa14cd8d1 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 15:00:35 -0500 Subject: update token counter when clearing prompt --- modules/ui.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 90c338da..d3a89bf7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -430,19 +430,16 @@ def create_seed_inputs(): -def connect_trash_prompt(_, confirmed): +def connect_trash_prompt(_prompt, confirmed, _token_counter): if(confirmed): - # update_token_counter(prompt, steps) - return ["", confirmed] - -def trash_prompt_click(button, prompt): - dummy_component = gradio.Label(visible=False) + return ["", confirmed, update_token_counter("", 1)] +def trash_prompt_click(button, prompt, _dummy_confirmed, token_counter): button.click( _js="trash_prompt", fn=connect_trash_prompt, - inputs=[prompt, dummy_component], - outputs=[prompt, dummy_component], + inputs=[prompt, _dummy_confirmed, token_counter], + outputs=[prompt, _dummy_confirmed, token_counter], ) @@ -649,7 +646,6 @@ def create_ui(wrap_gradio_gpu_call): txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - trash_prompt_click(trash_prompt_button, txt2img_prompt) with gr.Row(elem_id='txt2img_progress_row'): @@ -720,6 +716,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + trash_prompt_click(trash_prompt_button, txt2img_prompt, dummy_component, token_counter) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -858,7 +855,6 @@ def create_ui(wrap_gradio_gpu_call): img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) - trash_prompt_click(trash_prompt_button, img2img_prompt) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -958,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + trash_prompt_click(trash_prompt_button, img2img_prompt, dummy_component, token_counter) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.1 From 9e40520f00d836cfa93187f7f1e81e2a7bd100b9 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 15:13:12 -0500 Subject: refactor internal terminology to use 'clear' instead of 'trash' like #2728 --- modules/shared.py | 2 +- modules/ui.py | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 1585d532..ab5a0e9a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -317,7 +317,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), - "trash_prompt_visible": OptionInfo(True, "Show trash prompt button"), + "clear_prompt_visible": OptionInfo(True, "Show clear prompt button"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index d3a89bf7..31150800 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -88,7 +88,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -trash_prompt_symbol = '\U0001F5D1' # +clear_prompt_symbol = '\U0001F5D1' # 🗑️ def plaintext_to_html(text): @@ -430,14 +430,14 @@ def create_seed_inputs(): -def connect_trash_prompt(_prompt, confirmed, _token_counter): +def clear_prompt(_prompt, confirmed, _token_counter): if(confirmed): return ["", confirmed, update_token_counter("", 1)] -def trash_prompt_click(button, prompt, _dummy_confirmed, token_counter): +def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter): button.click( - _js="trash_prompt", - fn=connect_trash_prompt, + _js="clear_prompt", + fn=clear_prompt, inputs=[prompt, _dummy_confirmed, token_counter], outputs=[prompt, _dummy_confirmed, token_counter], ) @@ -518,7 +518,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt", visible=opts.trash_prompt_visible) + clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id="clear_prompt", visible=opts.clear_prompt_visible) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") @@ -559,7 +559,7 @@ def create_toprow(is_img2img): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2.save_to_config = True - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, trash_prompt + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, clear_prompt_button def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -640,7 +640,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ - token_button, trash_prompt_button = create_toprow(is_img2img=False) + token_button, clear_prompt_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -716,7 +716,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - trash_prompt_click(trash_prompt_button, txt2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, txt2img_prompt, dummy_component, token_counter) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -853,7 +853,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\ img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ - token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) + token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): @@ -954,7 +954,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - trash_prompt_click(trash_prompt_button, img2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, img2img_prompt, dummy_component, token_counter) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.1 From 0c7cf08b3d27a61bab4cd8b16f8be8ae74879423 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 15:32:26 -0500 Subject: some doc and formatting --- modules/ui.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 31150800..b26cf10a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -88,7 +88,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ +clear_prompt_symbol = '\U0001F5D1' # 🗑️ def plaintext_to_html(text): @@ -429,12 +429,14 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - def clear_prompt(_prompt, confirmed, _token_counter): - if(confirmed): - return ["", confirmed, update_token_counter("", 1)] + """Given confirmation from a user on the client-side, go ahead with clearing prompt""" + if confirmed: + return ["", confirmed, update_token_counter("", 1)] + def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter): + """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" button.click( _js="clear_prompt", fn=clear_prompt, @@ -518,7 +520,12 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id="clear_prompt", visible=opts.clear_prompt_visible) + + clear_prompt_button = gr.Button( + value=clear_prompt_symbol, + elem_id="clear_prompt", + visible=opts.clear_prompt_visible + ) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") -- cgit v1.2.1 From 700340448baa7412c7cc5ff3d1349ac79ee8ed0c Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 17:24:04 -0500 Subject: forgot to clear neg prompt after moving to back. Add tooltip to hints --- modules/ui.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index b26cf10a..25aeba3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,19 +429,19 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -def clear_prompt(_prompt, confirmed, _token_counter): +def clear_prompt(_prompt, _prompt_neg, confirmed, _token_counter): """Given confirmation from a user on the client-side, go ahead with clearing prompt""" if confirmed: - return ["", confirmed, update_token_counter("", 1)] + return ["", "", confirmed, update_token_counter("", 1)] -def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter): +def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter): """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" button.click( _js="clear_prompt", fn=clear_prompt, - inputs=[prompt, _dummy_confirmed, token_counter], - outputs=[prompt, _dummy_confirmed, token_counter], + inputs=[prompt, prompt_neg, _dummy_confirmed, token_counter], + outputs=[prompt, prompt_neg, _dummy_confirmed, token_counter], ) @@ -723,7 +723,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - connect_clear_prompt(clear_prompt_button, txt2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, txt2img_prompt, txt2img_negative_prompt, dummy_component, token_counter) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -961,7 +961,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - connect_clear_prompt(clear_prompt_button, img2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, img2img_prompt, img2img_negative_prompt, dummy_component, token_counter) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.1 From ce42879438bf2dbd76b5b346be656292e42ffb2b Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 22 Oct 2022 14:53:37 -0500 Subject: fix js func signature and not forget to initialize confirmation var to prevent exception upon cancelling confirmation --- modules/ui.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 25aeba3b..e58f040e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,10 +429,12 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -def clear_prompt(_prompt, _prompt_neg, confirmed, _token_counter): +def clear_prompt(prompt, _prompt_neg, confirmed, _token_counter): """Given confirmation from a user on the client-side, go ahead with clearing prompt""" if confirmed: return ["", "", confirmed, update_token_counter("", 1)] + else: + return [prompt, _prompt_neg, confirmed, _token_counter] def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter): -- cgit v1.2.1 From 2f4c91894d4c0a055c1069b2fda0e4da8fcda188 Mon Sep 17 00:00:00 2001 From: guaneec Date: Wed, 26 Oct 2022 12:10:30 +0800 Subject: Remove activation from final layer of HNs --- modules/hypernetworks/hypernetwork.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index d647ea55..54346b64 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -41,8 +41,8 @@ class HypernetworkModule(torch.nn.Module): # Add a fully-connected layer linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - # Add an activation func - if activation_func == "linear" or activation_func is None: + # Add an activation func except last layer + if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 3: pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) @@ -53,7 +53,7 @@ class HypernetworkModule(torch.nn.Module): if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - # Add dropout expect last layer + # Add dropout except last layer if use_dropout and i < len(layer_structure) - 3: linears.append(torch.nn.Dropout(p=0.3)) -- cgit v1.2.1 From c702d4d0df21790199d199818f25c449213ffe0f Mon Sep 17 00:00:00 2001 From: guaneec Date: Wed, 26 Oct 2022 13:43:04 +0800 Subject: Fix off-by-one --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 54346b64..3ce85bb5 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -42,7 +42,7 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # Add an activation func except last layer - if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 3: + if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 2: pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) @@ -54,7 +54,7 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) # Add dropout except last layer - if use_dropout and i < len(layer_structure) - 3: + if use_dropout and i < len(layer_structure) - 2: linears.append(torch.nn.Dropout(p=0.3)) self.linear = torch.nn.Sequential(*linears) -- cgit v1.2.1 From 877d94f97ca5491d8779440769b191e0dcd32c8e Mon Sep 17 00:00:00 2001 From: guaneec Date: Wed, 26 Oct 2022 14:50:58 +0800 Subject: Back compatibility --- modules/hypernetworks/hypernetwork.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3ce85bb5..dd317085 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -28,7 +28,7 @@ class HypernetworkModule(torch.nn.Module): "swish": torch.nn.Hardswish, } - def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): + def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -42,7 +42,7 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # Add an activation func except last layer - if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 2: + if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output): pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) @@ -105,7 +105,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False): self.filename = None self.name = name self.layers = {} @@ -116,11 +116,12 @@ class Hypernetwork: self.activation_func = activation_func self.add_layer_norm = add_layer_norm self.use_dropout = use_dropout + self.activate_output = activate_output for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), ) def weights(self): @@ -147,6 +148,7 @@ class Hypernetwork: state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name + state_dict['activate_output'] = self.activate_output torch.save(state_dict, filename) @@ -161,12 +163,13 @@ class Hypernetwork: self.activation_func = state_dict.get('activation_func', None) self.add_layer_norm = state_dict.get('is_layer_norm', False) self.use_dropout = state_dict.get('use_dropout', False) + self.activate_output = state_dict.get('activate_output', True) for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), + HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), ) self.name = state_dict.get('name', self.name) -- cgit v1.2.1 From 91bb35b1e6842b30ce7553009c8ecea3643de8d2 Mon Sep 17 00:00:00 2001 From: guaneec Date: Wed, 26 Oct 2022 15:00:03 +0800 Subject: Merge fix --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index eab8b32f..bd171793 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -190,7 +190,7 @@ class Hypernetwork: print(f"Weight initialization is {self.weight_init}") self.add_layer_norm = state_dict.get('is_layer_norm', False) print(f"Layer norm is set to {self.add_layer_norm}") - self.use_dropout = state_dict.get('use_dropout', False + self.use_dropout = state_dict.get('use_dropout', False) print(f"Dropout usage is set to {self.use_dropout}" ) self.activate_output = state_dict.get('activate_output', True) -- cgit v1.2.1 From b6a8bb123bd519736306417399f6441e504f1e8b Mon Sep 17 00:00:00 2001 From: guaneec Date: Wed, 26 Oct 2022 15:15:19 +0800 Subject: Fix merge --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index bd171793..2997cead 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -60,7 +60,7 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) # Add dropout except last layer - if use_dropout and i < len(layer_structure) - 2: + if use_dropout and i < len(layer_structure) - 3: linears.append(torch.nn.Dropout(p=0.3)) self.linear = torch.nn.Sequential(*linears) @@ -126,7 +126,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False) + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False): self.filename = None self.name = name self.layers = {} -- cgit v1.2.1 From 7bd8581e461623932ffbd5762ee931ee51f798db Mon Sep 17 00:00:00 2001 From: Sihan Wang <31711261+shwang95@users.noreply.github.com> Date: Wed, 26 Oct 2022 20:32:55 +0800 Subject: Fix error caused by EXIF transpose when using custom scripts Some custom scripts read image directly and no need to select image in UI, this will cause error. --- modules/img2img.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 9c0cf23e..86a19f37 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -80,7 +80,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro mask = None # Use the EXIF orientation of photos taken by smartphones. - image = ImageOps.exif_transpose(image) + if image is not None: + image = ImageOps.exif_transpose(image) assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' -- cgit v1.2.1 From 85fcccc105aa50f1d78de559233eaa9f384608b5 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Wed, 26 Oct 2022 22:24:33 +0900 Subject: Squashed commit of fixing dropout silently fix dropouts for future hypernetworks add kwargs for Hypernetwork class hypernet UI for gradio input add recommended options remove as options revert adding options in ui --- modules/hypernetworks/hypernetwork.py | 25 +++++++++++++++++-------- modules/ui.py | 4 ++-- 2 files changed, 19 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 2997cead..dd921153 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -34,7 +34,8 @@ class HypernetworkModule(torch.nn.Module): } activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) - def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False): + def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', + add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -60,7 +61,7 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) # Add dropout except last layer - if use_dropout and i < len(layer_structure) - 3: + if 'last_layer_dropout' in kwargs and kwargs['last_layer_dropout'] and use_dropout and i < len(layer_structure) - 2: linears.append(torch.nn.Dropout(p=0.3)) self.linear = torch.nn.Sequential(*linears) @@ -126,7 +127,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs): self.filename = None self.name = name self.layers = {} @@ -139,11 +140,14 @@ class Hypernetwork: self.add_layer_norm = add_layer_norm self.use_dropout = use_dropout self.activate_output = activate_output + self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output), - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, + self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, + self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), ) def weights(self): @@ -172,7 +176,8 @@ class Hypernetwork: state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['activate_output'] = self.activate_output - + state_dict['last_layer_dropout'] = self.last_layer_dropout + torch.save(state_dict, filename) def load(self, filename): @@ -193,12 +198,16 @@ class Hypernetwork: self.use_dropout = state_dict.get('use_dropout', False) print(f"Dropout usage is set to {self.use_dropout}" ) self.activate_output = state_dict.get('activate_output', True) + print(f"Activate last layer is set to {self.activate_output}") + self.last_layer_dropout = state_dict.get('last_layer_dropout', False) for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output), - HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output), + HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, + self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, + self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), ) self.name = state_dict.get('name', self.name) diff --git a/modules/ui.py b/modules/ui.py index 0a63e357..55cbe859 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1238,8 +1238,8 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") - new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys) - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) + new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Normal is default, for experiments, relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") -- cgit v1.2.1 From cc56df996e95c2c82295ab7b9928da2544791220 Mon Sep 17 00:00:00 2001 From: guaneec Date: Wed, 26 Oct 2022 23:51:51 +0800 Subject: Fix dropout logic --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index dd921153..b17598fe 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -35,7 +35,7 @@ class HypernetworkModule(torch.nn.Module): activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', - add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs): + add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -61,7 +61,7 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) # Add dropout except last layer - if 'last_layer_dropout' in kwargs and kwargs['last_layer_dropout'] and use_dropout and i < len(layer_structure) - 2: + if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2): linears.append(torch.nn.Dropout(p=0.3)) self.linear = torch.nn.Sequential(*linears) -- cgit v1.2.1 From 029d7c75436558f1e884bb127caed73caaecb83a Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Thu, 27 Oct 2022 14:44:53 +0900 Subject: Revert unresolved changes in Bias initialization it should be zeros_ or parameterized in future properly. --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b17598fe..25427a37 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -75,7 +75,7 @@ class HypernetworkModule(torch.nn.Module): w, b = layer.weight.data, layer.bias.data if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm: normal_(w, mean=0.0, std=0.01) - normal_(b, mean=0.0, std=0.005) + normal_(b, mean=0.0, std=0) elif weight_init == 'XavierUniform': xavier_uniform_(w) zeros_(b) -- cgit v1.2.1 From bdc90837987ed8919dd611fd01553b0c170ded5c Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Thu, 27 Oct 2022 15:20:15 -0400 Subject: Add a barebones interrogate API --- modules/api/api.py | 25 ++++++++++++++++++++++++- modules/api/models.py | 13 ++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 6e9d6097..eabdb7b8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,4 +1,4 @@ -from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI +from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI, InterrogateAPI from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers from modules.extras import run_pnginfo @@ -25,6 +25,11 @@ class ImageToImageResponse(BaseModel): parameters: Json info: Json +class InterrogateResponse(BaseModel): + caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") + parameters: Json + info: Json + class Api: def __init__(self, app, queue_lock): @@ -33,6 +38,7 @@ class Api: self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) def __base64_to_image(self, base64_string): # if has a comma, deal with prefix @@ -118,6 +124,23 @@ class Api: return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js()) + def interrogateapi(self, interrogatereq: InterrogateAPI): + image_b64 = interrogatereq.image + if image_b64 is None: + raise HTTPException(status_code=404, detail="Image not found") + + populate = interrogatereq.copy(update={ # Override __init__ params + } + ) + + img = self.__base64_to_image(image_b64) + + # Override object param + with self.queue_lock: + processed = shared.interrogator.interrogate(img) + + return InterrogateResponse(caption=processed, parameters=json.dumps(vars(interrogatereq)), info=None) + def extrasapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index 079e33d9..8be64749 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -63,7 +63,12 @@ class PydanticModelGenerator: self._model_name = model_name - self._class_data = merge_class_params(class_instance) + + if class_instance is not None: + self._class_data = merge_class_params(class_instance) + else: + self._class_data = {} + self._model_def = [ ModelDef( field=underscore(k), @@ -105,4 +110,10 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}] +).generate_model() + +InterrogateAPI = PydanticModelGenerator( + "Interrogate", + None, + [{"key": "image", "type": str, "default": None}] ).generate_model() \ No newline at end of file -- cgit v1.2.1 From 44ab954fabb9c1273366ebdca47f8da394d61aab Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Sat, 29 Oct 2022 10:02:56 -0700 Subject: Fix latent upscale highres fix #3888 --- modules/processing.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 548eec29..f18b7db2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -653,6 +653,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if opts.use_scale_latent_for_hires_fix: samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + image_conditioning = self.txt2img_image_conditioning(samples) else: decoded_samples = decode_first_stage(self.sd_model, samples) @@ -674,6 +675,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) + image_conditioning = self.img2img_image_conditioning( + decoded_samples, + samples, + decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3]) + ) + shared.state.nextjob() self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) @@ -684,11 +691,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - image_conditioning = self.img2img_image_conditioning( - decoded_samples, - samples, - decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3]) - ) samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning) return samples -- cgit v1.2.1 From 6e2ce4e735db64afcd0fe637327ca4ec78335706 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Sat, 29 Oct 2022 10:35:51 -0700 Subject: Added image conditioning to latent upscale. Only comuted if the mask weight is not 1.0 to avoid extra memory. Also includes some code cleanup. --- modules/processing.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index f18b7db2..ee0e9e34 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -134,11 +134,7 @@ class StableDiffusionProcessing(): # Dummy zero conditioning if we're not using inpainting model. # Still takes up a bit of memory, but no encoder call. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. - return torch.zeros( - x.shape[0], 5, 1, 1, - dtype=x.dtype, - device=x.device - ) + return x.new_zeros(x.shape[0], 5, 1, 1) height = height or self.height width = width or self.width @@ -156,11 +152,7 @@ class StableDiffusionProcessing(): def img2img_image_conditioning(self, source_image, latent_image, image_mask = None): if self.sampler.conditioning_key not in {'hybrid', 'concat'}: # Dummy zero conditioning if we're not using inpainting model. - return torch.zeros( - latent_image.shape[0], 5, 1, 1, - dtype=latent_image.dtype, - device=latent_image.device - ) + return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) # Handle the different mask inputs if image_mask is not None: @@ -174,11 +166,10 @@ class StableDiffusionProcessing(): # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 conditioning_mask = torch.round(conditioning_mask) else: - conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:]) + conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:]) # Create another latent image, this time with a masked version of the original input. # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter. - conditioning_mask = conditioning_mask.to(source_image.device) conditioning_image = torch.lerp( source_image, source_image * (1.0 - conditioning_mask), @@ -653,7 +644,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if opts.use_scale_latent_for_hires_fix: samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") - image_conditioning = self.txt2img_image_conditioning(samples) + + # Avoid making the inpainting conditioning unless necessary as + # this does need some extra compute to decode / encode the image again. + if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0: + image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples) + else: + image_conditioning = self.txt2img_image_conditioning(samples) else: decoded_samples = decode_first_stage(self.sd_model, samples) @@ -675,11 +672,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) - image_conditioning = self.img2img_image_conditioning( - decoded_samples, - samples, - decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3]) - ) + image_conditioning = self.img2img_image_conditioning(decoded_samples, samples) shared.state.nextjob() -- cgit v1.2.1 From 39f55c3c35873bc7dd9792cb2155746a1c3d4292 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Sat, 29 Oct 2022 14:13:02 -0700 Subject: Re-add explicit device move --- modules/processing.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index ee0e9e34..d07e3db9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -170,6 +170,7 @@ class StableDiffusionProcessing(): # Create another latent image, this time with a masked version of the original input. # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter. + conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype) conditioning_image = torch.lerp( source_image, source_image * (1.0 - conditioning_mask), -- cgit v1.2.1 From 71571e3f055237d71ba2d47756846ad1d73be00c Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Sun, 30 Oct 2022 00:35:40 -0700 Subject: Replaced master branch fix with updated fix. --- modules/processing.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 3dd44d3a..512c484f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -688,8 +688,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - image_conditioning = self.txt2img_image_conditioning(x) - # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() -- cgit v1.2.1 From be27fd4690b1eb6c74da1e31c9696a0f1901fbba Mon Sep 17 00:00:00 2001 From: evshiron Date: Sun, 30 Oct 2022 17:01:01 +0800 Subject: fix broken progress api by previous rework --- modules/shared.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..2c7d28a5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -4,6 +4,7 @@ import json import os import sys from collections import OrderedDict +import time import gradio as gr import tqdm @@ -132,6 +133,7 @@ class State: current_image = None current_image_sampling_step = 0 textinfo = None + time_start = None def skip(self): self.skipped = True @@ -168,6 +170,7 @@ class State: self.skipped = False self.interrupted = False self.textinfo = None + self.time_start = time.time() devices.torch_gc() -- cgit v1.2.1 From 1a4ff2de6a835cd8cc1590bbc1a8dedb5ad37e5b Mon Sep 17 00:00:00 2001 From: evshiron Date: Sun, 30 Oct 2022 17:02:47 +0800 Subject: fix current image in progress api when parallel processing enabled --- modules/api/api.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 6c06d449..97497f3f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -3,10 +3,9 @@ import uvicorn from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, HTTPException import modules.shared as shared -from modules import devices from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers +from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid from modules.extras import run_extras, run_pnginfo @@ -170,6 +169,16 @@ class Api: progress = min(progress, 1) + # copy from check_progress_call of ui.py + + if shared.parallel_processing_allowed: + if shared.state.sampling_step - shared.state.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.state.current_latent is not None: + if shared.opts.show_progress_grid: + shared.state.current_image = samples_to_image_grid(shared.state.current_latent) + else: + shared.state.current_image = sample_to_image(shared.state.current_latent) + shared.state.current_image_sampling_step = shared.state.sampling_step + current_image = None if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) -- cgit v1.2.1 From 4b8a192f680101de247dca79e48974b53bf961fe Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sat, 29 Oct 2022 16:36:43 +0900 Subject: add optimizer save option to shared.opts --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..065b893d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -286,6 +286,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.1 From 20194fd9752a280306fb66b57b258609b0918c46 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sat, 29 Oct 2022 16:56:42 +0900 Subject: We have duplicate linear now --- modules/hypernetworks/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index aad09ffc..c2d4b51c 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared from modules.hypernetworks import hypernetwork not_available = ["hardswish", "multiheadattention"] -keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) +keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): # Remove illegal characters from name. -- cgit v1.2.1 From 9d96d7d0a0aa0a966a9aefd24342345eb65952ed Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Sun, 30 Oct 2022 20:39:04 +0900 Subject: resolve conflicts --- modules/hypernetworks/hypernetwork.py | 44 ++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index a11e01d6..8f74cdea 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -21,6 +21,7 @@ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_norm from collections import defaultdict, deque from statistics import stdev, mean +optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} class HypernetworkModule(torch.nn.Module): multiplier = 1.0 @@ -139,6 +140,8 @@ class Hypernetwork: self.weight_init = weight_init self.add_layer_norm = add_layer_norm self.use_dropout = use_dropout + self.optimizer_name = None + self.optimizer_state_dict = None for size in enable_sizes or []: self.layers[size] = ( @@ -171,6 +174,10 @@ class Hypernetwork: state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name + if self.optimizer_name is not None: + state_dict['optimizer_name'] = self.optimizer_name + if self.optimizer_state_dict: + state_dict['optimizer_state_dict'] = self.optimizer_state_dict torch.save(state_dict, filename) @@ -190,7 +197,14 @@ class Hypernetwork: self.add_layer_norm = state_dict.get('is_layer_norm', False) print(f"Layer norm is set to {self.add_layer_norm}") self.use_dropout = state_dict.get('use_dropout', False) - print(f"Dropout usage is set to {self.use_dropout}" ) + print(f"Dropout usage is set to {self.use_dropout}") + self.optimizer_name = state_dict.get('optimizer_name', 'AdamW') + print(f"Optimizer name is {self.optimizer_name}") + self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None) + if self.optimizer_state_dict: + print("Loaded existing optimizer from checkpoint") + else: + print("No saved optimizer exists in checkpoint") for size, sd in state_dict.items(): if type(size) == int: @@ -392,8 +406,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log weights = hypernetwork.weights() for weight in weights: weight.requires_grad = True - # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... - optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) + # Here we use optimizer from saved HN, or we can specify as UI option. + if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict: + optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate) + else: + print(f"Optimizer type {optimizer_name} is not defined!") + optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) + optimizer_name = 'AdamW' + if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. + try: + optimizer.load_state_dict(hypernetwork.optimizer_state_dict) + except RuntimeError as e: + print("Cannot resume from saved optimizer!") + print(e) steps_without_grad = 0 @@ -455,8 +480,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log # Before saving, change name to match current checkpoint. hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') + hypernetwork.optimizer_name = optimizer_name + if shared.opts.save_optimizer_state: + hypernetwork.optimizer_state_dict = optimizer.state_dict() save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) - + hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{previous_mean_loss:.7f}", "learn_rate": scheduler.learn_rate @@ -514,14 +542,18 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + hypernetwork.optimizer_name = optimizer_name + if shared.opts.save_optimizer_state: + hypernetwork.optimizer_state_dict = optimizer.state_dict() save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) - + del optimizer + hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. return hypernetwork, filename + def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): old_hypernetwork_name = hypernetwork.name old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None -- cgit v1.2.1 From c9bb33dd43dbb9479ff1b70351df14508c89ac60 Mon Sep 17 00:00:00 2001 From: victorca25 Date: Sun, 30 Oct 2022 12:52:50 +0100 Subject: add resrgan 8x, allow use 1x and up to 8x extra models, move BSRGAN model, add nearest --- modules/esrgan_model.py | 17 +++++++++++++---- modules/modelloader.py | 3 +++ modules/ui.py | 2 +- modules/upscaler.py | 17 ++++++++++++++++- 4 files changed, 33 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a13cf6ac..c61669b4 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -50,6 +50,7 @@ def mod2normal(state_dict): def resrgan2normal(state_dict, nb=23): # this code is copied from https://github.com/victorca25/iNNfer if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: + re8x = 0 crt_net = {} items = [] for k, v in state_dict.items(): @@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23): crt_net['model.3.bias'] = state_dict['conv_up1.bias'] crt_net['model.6.weight'] = state_dict['conv_up2.weight'] crt_net['model.6.bias'] = state_dict['conv_up2.bias'] - crt_net['model.8.weight'] = state_dict['conv_hr.weight'] - crt_net['model.8.bias'] = state_dict['conv_hr.bias'] - crt_net['model.10.weight'] = state_dict['conv_last.weight'] - crt_net['model.10.bias'] = state_dict['conv_last.bias'] + + if 'conv_up3.weight' in state_dict: + # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py + re8x = 3 + crt_net['model.9.weight'] = state_dict['conv_up3.weight'] + crt_net['model.9.bias'] = state_dict['conv_up3.bias'] + + crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight'] + crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias'] + crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight'] + crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net return state_dict diff --git a/modules/modelloader.py b/modules/modelloader.py index b0f2f33d..e4a6f8ac 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -85,6 +85,9 @@ def cleanup_models(): src_path = os.path.join(root_path, "ESRGAN") dest_path = os.path.join(models_path, "ESRGAN") move_files(src_path, dest_path) + src_path = os.path.join(models_path, "BSRGAN") + dest_path = os.path.join(models_path, "ESRGAN") + move_files(src_path, dest_path, ".pth") src_path = os.path.join(root_path, "gfpgan") dest_path = os.path.join(models_path, "GFPGAN") move_files(src_path, dest_path) diff --git a/modules/ui.py b/modules/ui.py index 5055ca64..47610f5c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1059,7 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by'): - upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) with gr.TabItem('Scale to'): with gr.Group(): with gr.Row(): diff --git a/modules/upscaler.py b/modules/upscaler.py index 6ab2fb40..83fde7ca 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -10,6 +10,7 @@ import modules.shared from modules import modelloader, shared LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) from modules.paths import models_path @@ -57,7 +58,7 @@ class Upscaler: dest_w = img.width * scale dest_h = img.height * scale for i in range(3): - if img.width >= dest_w and img.height >= dest_h: + if img.width > dest_w and img.height > dest_h: break img = self.do_upscale(img, selected_model) if img.width != dest_w or img.height != dest_h: @@ -120,3 +121,17 @@ class UpscalerLanczos(Upscaler): self.name = "Lanczos" self.scalers = [UpscalerData("Lanczos", None, self)] + +class UpscalerNearest(Upscaler): + scalers = [] + + def do_upscale(self, img, selected_model=None): + return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST) + + def load_model(self, _): + pass + + def __init__(self, dirname=None): + super().__init__(False) + self.name = "Nearest" + self.scalers = [UpscalerData("Nearest", None, self)] \ No newline at end of file -- cgit v1.2.1 From cb31abcf58ea1f64266e6d821937eed058c35f4d Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 21:54:31 +0700 Subject: Settings to select VAE --- modules/sd_models.py | 31 +++++-------- modules/sd_vae.py | 121 +++++++++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 8 ++-- 3 files changed, 136 insertions(+), 24 deletions(-) create mode 100644 modules/sd_vae.py (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..91ad4b5e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks +from modules import shared, modelloader, devices, script_callbacks, sd_vae from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -160,12 +160,11 @@ def get_state_dict_from_checkpoint(pl_sd): vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - -def load_model_weights(model, checkpoint_info): +def load_model_weights(model, checkpoint_info, force=False): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if checkpoint_info not in checkpoints_loaded: + if force or checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -186,17 +185,7 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - - if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: - vae_file = shared.cmd_opts.vae_path - - if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - model.first_stage_model.load_state_dict(vae_dict) - + sd_vae.load_vae(model, checkpoint_file) model.first_stage_model.to(devices.dtype_vae) if shared.opts.sd_checkpoint_cache > 0: @@ -213,7 +202,7 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(checkpoint_info=None): +def load_model(checkpoint_info=None, force=False): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -234,7 +223,7 @@ def load_model(checkpoint_info=None): do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, force=force) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -252,16 +241,16 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model, info=None): +def reload_model_weights(sd_model, info=None, force=False): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - load_model(checkpoint_info) + load_model(checkpoint_info, force=force) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -271,7 +260,7 @@ def reload_model_weights(sd_model, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, force=force) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) diff --git a/modules/sd_vae.py b/modules/sd_vae.py new file mode 100644 index 00000000..82764e55 --- /dev/null +++ b/modules/sd_vae.py @@ -0,0 +1,121 @@ +import torch +import os +from collections import namedtuple +from modules import shared, devices +from modules.paths import models_path +import glob + +model_dir = "Stable-diffusion" +model_path = os.path.abspath(os.path.join(models_path, model_dir)) +vae_dir = "VAE" +vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) + +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} +default_vae_dict = {"auto": "auto", "None": "None"} +default_vae_list = ["auto", "None"] +default_vae_values = [default_vae_dict[x] for x in default_vae_list] +vae_dict = dict(default_vae_dict) +vae_list = list(default_vae_list) +first_load = True + +def get_filename(filepath): + return os.path.splitext(os.path.basename(filepath))[0] + +def refresh_vae_list(vae_path=vae_path, model_path=model_path): + global vae_dict, vae_list + res = {} + candidates = [ + *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), + *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True) + ] + if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): + candidates.append(shared.cmd_opts.vae_path) + for filepath in candidates: + name = get_filename(filepath) + res[name] = filepath + vae_list.clear() + vae_list.extend(default_vae_list) + vae_list.extend(list(res.keys())) + vae_dict.clear() + vae_dict.update(default_vae_dict) + vae_dict.update(res) + return vae_list + +def load_vae(model, checkpoint_file, vae_file="auto"): + global first_load, vae_dict, vae_list + # save_settings = False + + # if vae_file argument is provided, it takes priority + if vae_file and vae_file not in default_vae_list: + if not os.path.isfile(vae_file): + vae_file = "auto" + # save_settings = True + print("VAE provided as function argument doesn't exist") + # for the first load, if vae-path is provided, it takes priority and failure is reported + if first_load and shared.cmd_opts.vae_path is not None: + if os.path.isfile(shared.cmd_opts.vae_path): + vae_file = shared.cmd_opts.vae_path + # save_settings = True + # print("Using VAE provided as command line argument") + else: + print("VAE provided as command line argument doesn't exist") + # else, we load from settings + if vae_file == "auto" and shared.opts.sd_vae is not None: + # if saved VAE settings isn't recognized, fallback to auto + vae_file = vae_dict.get(shared.opts.sd_vae, "auto") + # if VAE selected but not found, fallback to auto + if vae_file not in default_vae_values and not os.path.isfile(vae_file): + vae_file = "auto" + print("Selected VAE doesn't exist") + # vae-path cmd arg takes priority for auto + if vae_file == "auto" and shared.cmd_opts.vae_path is not None: + if os.path.isfile(shared.cmd_opts.vae_path): + vae_file = shared.cmd_opts.vae_path + print("Using VAE provided as command line argument") + # if still not found, try look for ".vae.pt" beside model + model_path = os.path.splitext(checkpoint_file)[0] + if vae_file == "auto": + vae_file_try = model_path + ".vae.pt" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print("Using VAE found beside selected model") + # if still not found, try look for ".vae.ckpt" beside model + if vae_file == "auto": + vae_file_try = model_path + ".vae.ckpt" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print("Using VAE found beside selected model") + # No more fallbacks for auto + if vae_file == "auto": + vae_file = None + # Last check, just because + if vae_file and not os.path.exists(vae_file): + vae_file = None + + if vae_file: + print(f"Loading VAE weights from: {vae_file}") + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + model.first_stage_model.load_state_dict(vae_dict_1) + + # If vae used is not in dict, update it + # It will be removed on refresh though + if vae_file is not None: + vae_opt = get_filename(vae_file) + if vae_opt not in vae_dict: + vae_dict[vae_opt] = vae_file + vae_list.append(vae_opt) + + """ + # Save current VAE to VAE settings, maybe? will it work? + if save_settings: + if vae_file is None: + vae_opt = "None" + + # shared.opts.sd_vae = vae_opt + """ + + first_load = False + model.first_stage_model.to(devices.dtype_vae) diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..06440ac4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization +from modules import sd_samplers, sd_models, localization, sd_vae from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -295,6 +295,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), @@ -407,11 +408,12 @@ class Options: if bad_settings > 0: print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr) - def onchange(self, key, func): + def onchange(self, key, func, call=True): item = self.data_labels.get(key) item.onchange = func - func() + if call: + func() def dumpjson(self): d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} -- cgit v1.2.1 From e1b2ea6e0012ecc988385fc523d8fb50ea5d6be5 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 22:11:45 +0700 Subject: Change VAE search order and thus priority --- modules/sd_vae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 82764e55..0767b925 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -25,10 +25,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): global vae_dict, vae_list res = {} candidates = [ - *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True) + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), ] if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): candidates.append(shared.cmd_opts.vae_path) -- cgit v1.2.1 From d9e4e4d7a09d4aee8ce249a3c8e91ce165b10fa5 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Sun, 30 Oct 2022 15:33:02 -0700 Subject: Fix non-square full resolution inpainting. --- modules/masking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/masking.py b/modules/masking.py index fd8d9241..a5c4d2da 100644 --- a/modules/masking.py +++ b/modules/masking.py @@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w ratio_processing = processing_width / processing_height if ratio_crop_region > ratio_processing: - desired_height = (x2 - x1) * ratio_processing + desired_height = (x2 - x1) / ratio_processing desired_height_diff = int(desired_height - (y2-y1)) y1 -= desired_height_diff//2 y2 += desired_height_diff - desired_height_diff//2 -- cgit v1.2.1 From 21fba39c609859a60616420afda3b34a89e00761 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 30 Oct 2022 23:45:52 +0000 Subject: Add callbacks and param objects --- modules/script_callbacks.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 6ea58d61..a206ea59 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -24,12 +24,22 @@ class ImageSaveParams: """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'""" +class CGFDenoiserParams: + def __init__(self, x_in, image_cond_in, sigma_in, sampling_step, total_sampling_steps): + self.x_in = x_in + self.image_cond_in = image_cond_in + self.sigma_in = sigma_in + self.sampling_step = sampling_step + self.total_sampling_steps = total_sampling_steps + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] callbacks_before_image_saved = [] callbacks_image_saved = [] +callbacks_cfg_denoiser = [] def clear_callbacks(): @@ -84,6 +94,14 @@ def image_saved_callback(params: ImageSaveParams): report_exception(c, 'image_saved_callback') +def cfg_denoiser_callback(params: CGFDenoiserParams): + for c in callbacks_cfg_denoiser: + try: + c.callback(params) + except Exception: + report_exception(c, 'cfg_denoiser_callback') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -130,3 +148,12 @@ def on_image_saved(callback): - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ add_callback(callbacks_image_saved, callback) + + +def on_cfg_denoiser(callback): + """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. + The callback is called with one argument: + - params: CGFDenoiserParams - parameters to be passed to the inner model and sampling state details. + """ + add_callback(callbacks_cfg_denoiser, callback) + -- cgit v1.2.1 From 8906be85ac91310b37dccddc44f23631eb7a15f5 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 30 Oct 2022 23:47:08 +0000 Subject: add callback cleardown --- modules/script_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a206ea59..b0b8dc47 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -48,7 +48,7 @@ def clear_callbacks(): callbacks_ui_settings.clear() callbacks_before_image_saved.clear() callbacks_image_saved.clear() - + callbacks_cfg_denoiser.clear() def model_loaded_callback(sd_model): for c in callbacks_model_loaded: -- cgit v1.2.1 From 8ae0ea9deaa5a09d1e0aa8b2f8e97c38d71cdbda Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 30 Oct 2022 23:48:33 +0000 Subject: Add callback to sd_samplers --- modules/sd_samplers.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 3670b57d..30cb5c4b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -11,6 +11,7 @@ from modules import prompt_parser, devices, processing, images from modules.shared import opts, cmd_opts, state import modules.shared as shared +from modules.script_callbacks import CGFDenoiserParams, cfg_denoiser_callback SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -278,6 +279,8 @@ class CFGDenoiser(torch.nn.Module): image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + cfg_denoiser_callback(CGFDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)) + if tensor.shape[1] == uncond.shape[1]: cond_in = torch.cat([tensor, uncond]) -- cgit v1.2.1 From b96d0c4e9ecec3c856b9b4ec795dbd0d34fcac51 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 14:42:28 +0700 Subject: Fix typo from previous commit --- modules/sd_vae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0767b925..2ce44d5f 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -27,8 +27,8 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): candidates = [ *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True) - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True) ] if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): candidates.append(shared.cmd_opts.vae_path) -- cgit v1.2.1 From 726769da35970f4c100fa7edf11850f9dc059c41 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 15:19:34 +0700 Subject: Checkpoint cache by combination key of checkpoint and vae --- modules/sd_models.py | 27 ++++++++++++++++----------- modules/sd_vae.py | 8 +++++++- 2 files changed, 23 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 91ad4b5e..850f7b7b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -160,11 +160,15 @@ def get_state_dict_from_checkpoint(pl_sd): vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} -def load_model_weights(model, checkpoint_info, force=False): +def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if force or checkpoint_info not in checkpoints_loaded: + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + + checkpoint_key = (checkpoint_info, vae_file) + + if checkpoint_key not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -185,24 +189,25 @@ def load_model_weights(model, checkpoint_info, force=False): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - sd_vae.load_vae(model, checkpoint_file) + sd_vae.load_vae(model, vae_file) model.first_stage_model.to(devices.dtype_vae) if shared.opts.sd_checkpoint_cache > 0: - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + checkpoints_loaded[checkpoint_key] = model.state_dict().copy() while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU else: - print(f"Loading weights [{sd_model_hash}] from cache") - checkpoints_loaded.move_to_end(checkpoint_info) - model.load_state_dict(checkpoints_loaded[checkpoint_info]) + vae_name = sd_vae.get_filename(vae_file) + print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") + checkpoints_loaded.move_to_end(checkpoint_key) + model.load_state_dict(checkpoints_loaded[checkpoint_key]) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info -def load_model(checkpoint_info=None, force=False): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -223,7 +228,7 @@ def load_model(checkpoint_info=None, force=False): do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info, force=force) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -250,7 +255,7 @@ def reload_model_weights(sd_model, info=None, force=False): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - load_model(checkpoint_info, force=force) + load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -260,7 +265,7 @@ def reload_model_weights(sd_model, info=None, force=False): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info, force=force) + load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 2ce44d5f..e9239326 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -43,7 +43,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): vae_dict.update(res) return vae_list -def load_vae(model, checkpoint_file, vae_file="auto"): +def resolve_vae(checkpoint_file, vae_file="auto"): global first_load, vae_dict, vae_list # save_settings = False @@ -94,6 +94,12 @@ def load_vae(model, checkpoint_file, vae_file="auto"): if vae_file and not os.path.exists(vae_file): vae_file = None + return vae_file + +def load_vae(model, vae_file): + global first_load, vae_dict, vae_list + # save_settings = False + if vae_file: print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) -- cgit v1.2.1 From 36966e3200943dbf890b5338cfa939df552d3c47 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 15:38:58 +0700 Subject: Fix #4035 --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..a29c8c1a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -201,7 +201,7 @@ def load_model_weights(model, checkpoint_info): if shared.opts.sd_checkpoint_cache > 0: checkpoints_loaded[checkpoint_info] = model.state_dict().copy() - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: checkpoints_loaded.popitem(last=False) # LRU else: print(f"Loading weights [{sd_model_hash}] from cache") -- cgit v1.2.1 From bf7a699845675eefdabb9cfa40c55398976274ae Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 16:27:27 +0700 Subject: Fix #4035 for real now --- modules/sd_models.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index a29c8c1a..b2dd005a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,6 +165,9 @@ def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash + if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"): + checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy() + if checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") @@ -198,16 +201,14 @@ def load_model_weights(model, checkpoint_info): model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.to(devices.dtype_vae) - - if shared.opts.sd_checkpoint_cache > 0: - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: - checkpoints_loaded.popitem(last=False) # LRU else: print(f"Loading weights [{sd_model_hash}] from cache") - checkpoints_loaded.move_to_end(checkpoint_info) model.load_state_dict(checkpoints_loaded[checkpoint_info]) + if shared.opts.sd_checkpoint_cache > 0: + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) # LRU + model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info -- cgit v1.2.1 From df6a7ebfe8cc4da23861e3e2583693bb7808d573 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 31 Oct 2022 11:50:33 -0400 Subject: revert things to master --- modules/api/api.py | 2 -- modules/api/models.py | 6 +----- 2 files changed, 1 insertion(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index c510a833..6a903e4c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -117,8 +117,6 @@ class Api: return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js()) - def extrasapi(self): - raise NotImplementedError def extras_single_image_api(self, req: ExtrasSingleImageRequest): reqDict = setUpscalers(req) diff --git a/modules/api/models.py b/modules/api/models.py index 035a7179..82ab29b8 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -64,11 +64,7 @@ class PydanticModelGenerator: self._model_name = model_name - - if class_instance is not None: - self._class_data = merge_class_params(class_instance) - else: - self._class_data = {} + self._class_data = merge_class_params(class_instance) self._model_def = [ ModelDef( -- cgit v1.2.1 From 3f3d14afd5abd07d3843370dc1c28be299dbdbab Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 31 Oct 2022 11:51:21 -0400 Subject: nix unused thing --- modules/api/api.py | 4 ---- 1 file changed, 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 6a903e4c..536e3f16 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -182,10 +182,6 @@ class Api: if image_b64 is None: raise HTTPException(status_code=404, detail="Image not found") - populate = interrogatereq.copy(update={ # Override __init__ params - } - ) - img = self.__base64_to_image(image_b64) # Override object param -- cgit v1.2.1 From 467cae167a3066ffa2b2a5e6f16dd42642219aba Mon Sep 17 00:00:00 2001 From: TinkTheBoush Date: Tue, 1 Nov 2022 23:29:12 +0900 Subject: append_tag_shuffle --- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/textual_inversion/dataset.py | 10 ++++++++-- modules/textual_inversion/textual_inversion.py | 4 ++-- modules/ui.py | 3 +++ 4 files changed, 15 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index a11e01d6..7630fb81 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -331,7 +331,7 @@ def report_statistics(loss_info:dict): -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images @@ -376,7 +376,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index ad726577..e9d97cc1 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -24,7 +24,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token @@ -33,6 +33,7 @@ class PersonalizedBase(Dataset): self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) + self.shuffle_tags = shuffle_tags self.dataset = [] @@ -98,7 +99,12 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) - text = text.replace("[filewords]", filename_text) + if self.tag_shuffle: + tags = filename_text.split(',') + random.shuffle(tags) + text = text.replace("[filewords]", ','.join(tags)) + else: + text = text.replace("[filewords]", filename_text) return text def __len__(self): diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e0babb46..64700e23 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -224,7 +224,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") @@ -271,7 +271,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) embedding.vec.requires_grad = True optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) diff --git a/modules/ui.py b/modules/ui.py index 2c15abb7..ad383979 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1267,6 +1267,7 @@ def create_ui(wrap_gradio_gpu_call): save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + shuffle_tags = gr.Checkbox(label='Shuffleing tags by "," when create texts', value=True) with gr.Row(): interrupt_training = gr.Button(value="Interrupt") @@ -1361,6 +1362,7 @@ def create_ui(wrap_gradio_gpu_call): template_file, save_image_with_stored_embedding, preview_from_txt2img, + shuffle_tags, *txt2img_preview_params, ], outputs=[ @@ -1385,6 +1387,7 @@ def create_ui(wrap_gradio_gpu_call): save_embedding_every, template_file, preview_from_txt2img, + shuffle_tags, *txt2img_preview_params, ], outputs=[ -- cgit v1.2.1 From bc607686065b8c7751d1af7c05b960378fa256de Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Tue, 1 Nov 2022 23:26:55 +0800 Subject: Enable override_settings to take effect for hypernetworks --- modules/processing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 57d3a523..86d015af 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -422,13 +422,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): - opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible + opts.data[k] = v # we don't call onchange for simplicity which makes changing model impossible + if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not res = process_images_inner(p) - finally: + finally: # restore opts to original state for k, v in stored_opts.items(): opts.data[k] = v + if k == 'sd_hypernetwork': shared.reload_hypernetworks() return res -- cgit v1.2.1 From 401350cd59555439570ba5bc95f0ac5698e372e4 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Tue, 1 Nov 2022 14:03:56 -0500 Subject: clear on the client-side again --- modules/ui.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 447722cd..f43e79ab 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -400,19 +400,12 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -def clear_prompt(prompt, _prompt_neg, confirmed, _token_counter): - """Given confirmation from a user on the client-side, go ahead with clearing prompt""" - if confirmed: - return ["", "", confirmed, update_token_counter("", 1)] - else: - return [prompt, _prompt_neg, confirmed, _token_counter] - def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter): """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" button.click( _js="clear_prompt", - fn=clear_prompt, + fn=None, inputs=[prompt, prompt_neg, _dummy_confirmed, token_counter], outputs=[prompt, prompt_neg, _dummy_confirmed, token_counter], ) -- cgit v1.2.1 From 1dd5d6bafad7575f347056a29636cbab71c1c468 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Tue, 1 Nov 2022 14:33:55 -0500 Subject: clean py func defs --- modules/ui.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index f43e79ab..8a1f3887 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -401,13 +401,13 @@ def create_seed_inputs(): -def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter): +def connect_clear_prompt(button): """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" button.click( _js="clear_prompt", fn=None, - inputs=[prompt, prompt_neg, _dummy_confirmed, token_counter], - outputs=[prompt, prompt_neg, _dummy_confirmed, token_counter], + inputs=[], + outputs=[], ) @@ -746,7 +746,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - connect_clear_prompt(clear_prompt_button, txt2img_prompt, txt2img_negative_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -929,7 +929,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - connect_clear_prompt(clear_prompt_button, img2img_prompt, img2img_negative_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.1 From 86d35526a13a0e2432ab71d1d40b191615d3e343 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Tue, 1 Nov 2022 14:53:40 -0500 Subject: make line evil again --- modules/ui.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 8a1f3887..bd67c1bd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -692,9 +692,7 @@ def create_ui(wrap_gradio_gpu_call): parameters_copypaste.reset() with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ - txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ - token_button, clear_prompt_button = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -850,9 +848,7 @@ def create_ui(wrap_gradio_gpu_call): token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\ - img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ - token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=True) + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): -- cgit v1.2.1 From cd88e21dc5d5cfdfbd408454acd259b7db9d0ec8 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 2 Nov 2022 00:34:58 +0000 Subject: Class Name typo and add descriptions to fields. --- modules/script_callbacks.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index b0b8dc47..ff40b056 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -24,13 +24,22 @@ class ImageSaveParams: """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'""" -class CGFDenoiserParams: - def __init__(self, x_in, image_cond_in, sigma_in, sampling_step, total_sampling_steps): - self.x_in = x_in - self.image_cond_in = image_cond_in - self.sigma_in = sigma_in +class CFGDenoiserParams: + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps): + self.x = x + """Latent image representation in the process of being denoised""" + + self.image_cond = image_cond + """Conditioning image""" + + self.sigma = sigma + """Current sigma noise step value""" + self.sampling_step = sampling_step + """Current Sampling step number""" + self.total_sampling_steps = total_sampling_steps + """Total number of sampling steps planned""" ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) @@ -94,7 +103,7 @@ def image_saved_callback(params: ImageSaveParams): report_exception(c, 'image_saved_callback') -def cfg_denoiser_callback(params: CGFDenoiserParams): +def cfg_denoiser_callback(params: CFGDenoiserParams): for c in callbacks_cfg_denoiser: try: c.callback(params) @@ -153,7 +162,7 @@ def on_image_saved(callback): def on_cfg_denoiser(callback): """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. The callback is called with one argument: - - params: CGFDenoiserParams - parameters to be passed to the inner model and sampling state details. + - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. """ add_callback(callbacks_cfg_denoiser, callback) -- cgit v1.2.1 From 5b6bedf6f2ebacb7f1f5809af8e26a6a1af16e2a Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 2 Nov 2022 00:38:17 +0000 Subject: Update class name and assign back to vars --- modules/sd_samplers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 30cb5c4b..ebc0d896 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -11,7 +11,7 @@ from modules import prompt_parser, devices, processing, images from modules.shared import opts, cmd_opts, state import modules.shared as shared -from modules.script_callbacks import CGFDenoiserParams, cfg_denoiser_callback +from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -279,7 +279,11 @@ class CFGDenoiser(torch.nn.Module): image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - cfg_denoiser_callback(CGFDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)) + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) + cfg_denoiser_callback(denoiser_params) + x_in = denoiser_params.x + image_cond_in = denoiser_params.image_cond + sigma_in = denoiser_params.sigma if tensor.shape[1] == uncond.shape[1]: cond_in = torch.cat([tensor, uncond]) -- cgit v1.2.1 From c9148b2312b36fee8727f5233da9dbe32aa1f58c Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Tue, 1 Nov 2022 21:56:47 -0300 Subject: Release processing resources after it finishes --- modules/img2img.py | 2 ++ modules/processing.py | 7 ++++--- modules/txt2img.py | 2 ++ 3 files changed, 8 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 35c5df9b..fac010aa 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -137,6 +137,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro if processed is None: processed = process_images(p) + p.close() + shared.total_tqdm.clear() generation_info_js = processed.js() diff --git a/modules/processing.py b/modules/processing.py index 57d3a523..b541ee2b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -202,6 +202,10 @@ class StableDiffusionProcessing(): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): raise NotImplementedError() + def close(self): + self.sd_model = None + self.sampler = None + class Processed: def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): @@ -597,9 +601,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.postprocess(p, res) - p.sd_model = None - p.sampler = None - return res diff --git a/modules/txt2img.py b/modules/txt2img.py index c9d5a090..8e4e8677 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: if processed is None: processed = process_images(p) + p.close() + shared.total_tqdm.clear() generation_info_js = processed.js() -- cgit v1.2.1 From 5510c282b1f1974005790066b5e444f74a5178fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 2 Nov 2022 07:26:31 +0300 Subject: fix for extensions' javascript not loading --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 2c15abb7..a94f46ea 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -671,6 +671,8 @@ def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img + reload_javascript() + parameters_copypaste.reset() with gr.Blocks(analytics_enabled=False) as txt2img_interface: @@ -1782,4 +1784,3 @@ def load_javascript(raw_response): reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) -reload_javascript() -- cgit v1.2.1 From 056f06d3738c267b1014e6e8e1ef5bd97af1fb45 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Wed, 2 Nov 2022 12:51:46 +0700 Subject: Reload VAE without reloading sd checkpoint --- modules/sd_models.py | 15 ++++---- modules/sd_vae.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 97 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ab85b65..883639d1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -159,15 +159,13 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd -vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - checkpoint_key = (checkpoint_info, vae_file) + checkpoint_key = checkpoint_info if checkpoint_key not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") @@ -190,13 +188,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - sd_vae.load_vae(model, vae_file) - model.first_stage_model.to(devices.dtype_vae) - if shared.opts.sd_checkpoint_cache > 0: + # if PR #4035 were to get merged, restore base VAE first before caching checkpoints_loaded[checkpoint_key] = model.state_dict().copy() while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU + else: vae_name = sd_vae.get_filename(vae_file) print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") @@ -207,6 +204,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.load_vae(model, vae_file) + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack @@ -254,14 +253,14 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model=None, info=None, force=False): +def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() if not sd_model: sd_model = shared.sd_model - if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e9239326..78e14e8a 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,26 +1,65 @@ import torch import os from collections import namedtuple -from modules import shared, devices +from modules import shared, devices, script_callbacks from modules.paths import models_path import glob + model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) vae_dir = "VAE" vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) + vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + + default_vae_dict = {"auto": "auto", "None": "None"} default_vae_list = ["auto", "None"] + + default_vae_values = [default_vae_dict[x] for x in default_vae_list] vae_dict = dict(default_vae_dict) vae_list = list(default_vae_list) first_load = True + +base_vae = None +loaded_vae_file = None +checkpoint_info = None + + +def get_base_vae(model): + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: + return base_vae + return None + + +def store_base_vae(model): + global base_vae, checkpoint_info + if checkpoint_info != model.sd_checkpoint_info: + base_vae = model.first_stage_model.state_dict().copy() + checkpoint_info = model.sd_checkpoint_info + + +def delete_base_vae(): + global base_vae, checkpoint_info + base_vae = None + checkpoint_info = None + + +def restore_base_vae(model): + global base_vae, checkpoint_info + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: + load_vae_dict(model, base_vae) + delete_base_vae() + + def get_filename(filepath): return os.path.splitext(os.path.basename(filepath))[0] + def refresh_vae_list(vae_path=vae_path, model_path=model_path): global vae_dict, vae_list res = {} @@ -43,6 +82,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): vae_dict.update(res) return vae_list + def resolve_vae(checkpoint_file, vae_file="auto"): global first_load, vae_dict, vae_list # save_settings = False @@ -96,24 +136,26 @@ def resolve_vae(checkpoint_file, vae_file="auto"): return vae_file -def load_vae(model, vae_file): - global first_load, vae_dict, vae_list + +def load_vae(model, vae_file=None): + global first_load, vae_dict, vae_list, loaded_vae_file # save_settings = False if vae_file: print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - model.first_stage_model.load_state_dict(vae_dict_1) + load_vae_dict(model, vae_dict_1) - # If vae used is not in dict, update it - # It will be removed on refresh though - if vae_file is not None: + # If vae used is not in dict, update it + # It will be removed on refresh though vae_opt = get_filename(vae_file) if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file vae_list.append(vae_opt) + loaded_vae_file = vae_file + """ # Save current VAE to VAE settings, maybe? will it work? if save_settings: @@ -124,4 +166,45 @@ def load_vae(model, vae_file): """ first_load = False + + +# don't call this from outside +def load_vae_dict(model, vae_dict_1=None): + if vae_dict_1: + store_base_vae(model) + model.first_stage_model.load_state_dict(vae_dict_1) + else: + restore_base_vae() model.first_stage_model.to(devices.dtype_vae) + + +def reload_vae_weights(sd_model=None, vae_file="auto"): + from modules import lowvram, devices, sd_hijack + + if not sd_model: + sd_model = shared.sd_model + + checkpoint_info = sd_model.sd_checkpoint_info + checkpoint_file = checkpoint_info.filename + vae_file = resolve_vae(checkpoint_file, vae_file=vae_file) + + if loaded_vae_file == vae_file: + return + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + sd_model.to(devices.cpu) + + sd_hijack.model_hijack.undo_hijack(sd_model) + + load_vae(sd_model, vae_file) + + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) + + print(f"VAE Weights loaded.") + return sd_model -- cgit v1.2.1 From 95c6308ccd2e075d1fb804f5b98a4f0b07b87b7d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 2 Nov 2022 09:47:53 +0300 Subject: switch to gradio 3.8 --- modules/ui.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a94f46ea..45cd8c3f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1572,8 +1572,7 @@ def create_ui(wrap_gradio_gpu_call): reload_script_bodies.click( fn=reload_scripts, inputs=[], - outputs=[], - _js='function(){}' + outputs=[] ) def request_restart(): @@ -1585,7 +1584,7 @@ def create_ui(wrap_gradio_gpu_call): fn=request_restart, inputs=[], outputs=[], - _js='function(){restart_reload()}' + _js='restart_reload' ) if column is not None: -- cgit v1.2.1 From dd2108fdac2ebf943d4ac3563a49202222b88acf Mon Sep 17 00:00:00 2001 From: Maiko Tan Date: Wed, 2 Nov 2022 15:04:35 +0800 Subject: fix: should invoke callback as well in api only mode --- modules/script_callbacks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index da88635b..c28e220e 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,6 +2,7 @@ import sys import traceback from collections import namedtuple import inspect +from typing import Optional from fastapi import FastAPI from gradio import Blocks @@ -62,7 +63,7 @@ def clear_callbacks(): callbacks_image_saved.clear() callbacks_cfg_denoiser.clear() -def app_started_callback(demo: Blocks, app: FastAPI): +def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callbacks_app_started: try: c.callback(demo, app) -- cgit v1.2.1 From a5409a6e4bc3eaa9757a7505d4564ad8e0d899ea Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Wed, 2 Nov 2022 14:37:22 +0700 Subject: Save VAE provided by cmd_opts.vae_path --- modules/sd_vae.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 78e14e8a..71e7a6e6 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -78,27 +78,24 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): vae_list.extend(default_vae_list) vae_list.extend(list(res.keys())) vae_dict.clear() - vae_dict.update(default_vae_dict) vae_dict.update(res) + vae_dict.update(default_vae_dict) return vae_list def resolve_vae(checkpoint_file, vae_file="auto"): global first_load, vae_dict, vae_list - # save_settings = False - # if vae_file argument is provided, it takes priority + # if vae_file argument is provided, it takes priority, but not saved if vae_file and vae_file not in default_vae_list: if not os.path.isfile(vae_file): vae_file = "auto" - # save_settings = True print("VAE provided as function argument doesn't exist") - # for the first load, if vae-path is provided, it takes priority and failure is reported + # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported if first_load and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): vae_file = shared.cmd_opts.vae_path - # save_settings = True - # print("Using VAE provided as command line argument") + shared.opts.data['sd_vae'] = get_filename(vae_file) else: print("VAE provided as command line argument doesn't exist") # else, we load from settings -- cgit v1.2.1 From 4a8cf01f6f7f072cc9c67d6b31662384b212dd9c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 2 Nov 2022 12:12:32 +0300 Subject: remove duplicate code from #3970 --- modules/api/api.py | 10 +--------- modules/shared.py | 14 ++++++++++++++ modules/ui.py | 10 +--------- 3 files changed, 16 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index b3d85e46..71c9c160 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -178,15 +178,7 @@ class Api: progress = min(progress, 1) - # copy from check_progress_call of ui.py - - if shared.parallel_processing_allowed: - if shared.state.sampling_step - shared.state.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.state.current_latent is not None: - if shared.opts.show_progress_grid: - shared.state.current_image = samples_to_image_grid(shared.state.current_latent) - else: - shared.state.current_image = sample_to_image(shared.state.current_latent) - shared.state.current_image_sampling_step = shared.state.sampling_step + shared.state.set_current_image() current_image = None if shared.state.current_image and not req.skip_current_image: diff --git a/modules/shared.py b/modules/shared.py index 04aaa648..e65f6080 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -184,6 +184,20 @@ class State: devices.torch_gc() + """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" + def set_current_image(self): + if not parallel_processing_allowed: + return + + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None: + if opts.show_progress_grid: + self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) + else: + self.current_image = sd_samplers.sample_to_image(self.current_latent) + + self.current_image_sampling_step = self.sampling_step + + state = State() artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) diff --git a/modules/ui.py b/modules/ui.py index 45cd8c3f..784439ba 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -277,15 +277,7 @@ def check_progress_call(id_part): preview_visibility = gr_show(False) if opts.show_progress_every_n_steps > 0: - if shared.parallel_processing_allowed: - - if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None: - if opts.show_progress_grid: - shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent) - else: - shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent) - shared.state.current_image_sampling_step = shared.state.sampling_step - + shared.state.set_current_image() image = shared.state.current_image if image is None: -- cgit v1.2.1 From 9c67408004ed132637d10321bf44565f82055fd2 Mon Sep 17 00:00:00 2001 From: timntorres <116157310+timntorres@users.noreply.github.com> Date: Wed, 2 Nov 2022 02:18:21 -0700 Subject: Allow saving "before-highres-fix. (#4150) * Save image/s before doing highres fix. --- modules/processing.py | 17 +++++++++++++++-- modules/sd_samplers.py | 5 ++--- modules/shared.py | 1 + 3 files changed, 18 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index b541ee2b..2dcf4879 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -521,7 +521,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.autocast(): - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + # Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix. + if isinstance(p, StableDiffusionProcessingTxt2Img): + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, n=n) + else: + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) samples_ddim = samples_ddim.to(devices.dtype_vae) x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) @@ -649,7 +653,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, n=0): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: @@ -685,6 +689,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) + # Save a copy of the image/s before doing highres fix, if applicable. + if opts.save and not self.do_not_save_samples and opts.save_images_before_highres_fix: + for i in range(self.batch_size): + # This batch's ith image. + img = sd_samplers.sample_to_image(samples, i) + # Index that accounts for both batch size and batch count. + ind = i + self.batch_size*n + images.save_image(img, self.outpath_samples, "", self.all_seeds[ind], self.all_prompts[ind], opts.samples_format, suffix=f"-before-highres-fix") + shared.state.nextjob() self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 44d4c189..d7fa89a0 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -93,9 +93,8 @@ def single_sample_to_image(sample): return Image.fromarray(x_sample) -def sample_to_image(samples): - return single_sample_to_image(samples[0]) - +def sample_to_image(samples, index=0): + return single_sample_to_image(samples[index]) def samples_to_image_grid(samples): return images.image_grid([single_sample_to_image(sample) for sample in samples]) diff --git a/modules/shared.py b/modules/shared.py index e65f6080..ce991424 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -255,6 +255,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), + "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), -- cgit v1.2.1 From eb5e82c7ddf5e72fa13b83bd1f12d3a07a4de1a4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 2 Nov 2022 12:45:03 +0300 Subject: do not unnecessarily run VAE one more time when saving intermediate image with hires fix --- modules/processing.py | 39 ++++++++++++++++++++------------------- modules/sd_samplers.py | 1 + modules/shared.py | 2 +- 3 files changed, 22 insertions(+), 20 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 2dcf4879..3a364b5f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -199,7 +199,7 @@ class StableDiffusionProcessing(): def init(self, all_prompts, all_seeds, all_subseeds): pass - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): raise NotImplementedError() def close(self): @@ -521,11 +521,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.autocast(): - # Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix. - if isinstance(p, StableDiffusionProcessingTxt2Img): - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, n=n) - else: - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) samples_ddim = samples_ddim.to(devices.dtype_vae) x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) @@ -653,7 +649,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, n=0): + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: @@ -666,9 +662,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] + """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images""" + def save_intermediate(image, index): + if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: + return + + if not isinstance(image, Image.Image): + image = sd_samplers.sample_to_image(image, index) + + images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") + if opts.use_scale_latent_for_hires_fix: samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + for i in range(samples.shape[0]): + save_intermediate(samples, i) else: decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) @@ -678,6 +686,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) image = Image.fromarray(x_sample) + + save_intermediate(image, i) + image = images.resize_image(0, image, self.width, self.height) image = np.array(image).astype(np.float32) / 255.0 image = np.moveaxis(image, 2, 0) @@ -689,15 +700,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) - # Save a copy of the image/s before doing highres fix, if applicable. - if opts.save and not self.do_not_save_samples and opts.save_images_before_highres_fix: - for i in range(self.batch_size): - # This batch's ith image. - img = sd_samplers.sample_to_image(samples, i) - # Index that accounts for both batch size and batch count. - ind = i + self.batch_size*n - images.save_image(img, self.outpath_samples, "", self.all_seeds[ind], self.all_prompts[ind], opts.samples_format, suffix=f"-before-highres-fix") - shared.state.nextjob() self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) @@ -844,8 +846,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask) - - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) @@ -856,4 +857,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): del x devices.torch_gc() - return samples \ No newline at end of file + return samples diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index d7fa89a0..c7c414ef 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -96,6 +96,7 @@ def single_sample_to_image(sample): def sample_to_image(samples, index=0): return single_sample_to_image(samples[index]) + def samples_to_image_grid(samples): return images.image_grid([single_sample_to_image(sample) for sample in samples]) diff --git a/modules/shared.py b/modules/shared.py index ce991424..01f47e38 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -256,6 +256,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), + "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), @@ -322,7 +323,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), - "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), -- cgit v1.2.1 From f2a5cbe6f55592c4c5527b8e0bf99ea8d658f057 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 2 Nov 2022 14:41:29 +0300 Subject: fix #3986 breaking --no-half-vae --- modules/sd_models.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 883639d1..5075fadb 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -183,11 +183,20 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.to(memory_format=torch.channels_last) if not shared.cmd_opts.no_half: + vae = model.first_stage_model + + # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 + if shared.cmd_opts.no_half_vae: + model.first_stage_model = None + model.half() + model.first_stage_model = vae devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + model.first_stage_model.to(devices.dtype_vae) + if shared.opts.sd_checkpoint_cache > 0: # if PR #4035 were to get merged, restore base VAE first before caching checkpoints_loaded[checkpoint_key] = model.state_dict().copy() -- cgit v1.2.1 From 3178c35224467893cf8dcedb1028c59c6c23db58 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Wed, 2 Nov 2022 22:16:32 +0900 Subject: resolve conflicts --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 065b893d..959937d7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -285,7 +285,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), + "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), -- cgit v1.2.1 From 9b5f85ac83f864310fe19c9deab6670bad695b0d Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Wed, 2 Nov 2022 22:18:04 +0900 Subject: first revert --- modules/shared.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 959937d7..7e8c552b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -286,7 +286,6 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), - "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.1 From 7ea5956ad5fa925f92116e8a3bf78d7f6517b654 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Wed, 2 Nov 2022 22:18:55 +0900 Subject: now add --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index d8e99f85..7ecb40d8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -309,6 +309,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.1 From e21fcd72fcf147904a1df060226c4df12acf251e Mon Sep 17 00:00:00 2001 From: evshiron Date: Wed, 2 Nov 2022 22:37:45 +0800 Subject: add back png info in image api --- modules/api/api.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 71c9c160..ceaf08b0 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -7,8 +7,9 @@ from fastapi import APIRouter, Depends, HTTPException import modules.shared as shared from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid +from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo +from PIL import PngImagePlugin def upscaler_to_index(name: str): @@ -31,9 +32,21 @@ def setUpscalers(req: dict): def encode_pil_to_base64(image): - buffer = io.BytesIO() - image.save(buffer, format="png") - return base64.b64encode(buffer.getvalue()) + with io.BytesIO() as output_bytes: + + # Copy any text-only metadata + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + image.save( + output_bytes, "PNG", pnginfo=(metadata if use_metadata else None) + ) + bytes_data = output_bytes.getvalue() + return base64.b64encode(bytes_data) class Api: -- cgit v1.2.1 From a9e979977a8e3999b01b6a086bb1332ab7ab308b Mon Sep 17 00:00:00 2001 From: Artem Zagidulin Date: Wed, 2 Nov 2022 19:05:01 +0300 Subject: process_one --- modules/processing.py | 3 +++ modules/scripts.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 3a364b5f..72a2ee4e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -509,6 +509,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if len(prompts) == 0: break + if p.scripts is not None: + p.scripts.process_one(p) + with devices.autocast(): uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) diff --git a/modules/scripts.py b/modules/scripts.py index 533db45c..9f82efea 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -70,6 +70,13 @@ class Script: pass + def process_one(self, p, *args): + """ + Same as process(), but called for every iteration + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -294,6 +301,15 @@ class ScriptRunner: print(f"Error running process: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def process_one(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.process_one(p, *script_args) + except Exception: + print(f"Error running process_one: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def postprocess(self, p, processed): for script in self.alwayson_scripts: try: -- cgit v1.2.1 From f1b6ac64e451036fb4dfabe66d79488c56c06776 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kyu=E2=99=A5?= <3ad4gum@gmail.com> Date: Wed, 2 Nov 2022 17:24:42 +0100 Subject: Added option to preview Created images on batch completion. --- modules/shared.py | 25 ++++++++++++++++--------- modules/ui.py | 2 +- 2 files changed, 17 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index d8e99f85..d4cf32a4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -146,6 +146,9 @@ class State: self.interrupted = True def nextjob(self): + if opts.show_progress_every_n_steps == -1: + self.do_set_current_image() + self.job_no += 1 self.sampling_step = 0 self.current_image_sampling_step = 0 @@ -186,17 +189,21 @@ class State: """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" def set_current_image(self): + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0: + self.do_set_current_image() + + def do_set_current_image(self): if not parallel_processing_allowed: return + if self.current_latent is None: + return + + if opts.show_progress_grid: + self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) + else: + self.current_image = sd_samplers.sample_to_image(self.current_latent) - if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None: - if opts.show_progress_grid: - self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) - else: - self.current_image = sd_samplers.sample_to_image(self.current_latent) - - self.current_image_sampling_step = self.sampling_step - + self.current_image_sampling_step = self.sampling_step state = State() @@ -351,7 +358,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), + "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), diff --git a/modules/ui.py b/modules/ui.py index 2609857e..29de1e10 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -276,7 +276,7 @@ def check_progress_call(id_part): image = gr_show(False) preview_visibility = gr_show(False) - if opts.show_progress_every_n_steps > 0: + if opts.show_progress_every_n_steps != 0: shared.state.set_current_image() image = shared.state.current_image -- cgit v1.2.1 From c07f1d0d7821f85b9ce1419992c118963d605bd7 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 2 Nov 2022 16:59:10 +0000 Subject: Convert callbacks into a private map, add utility functions for removing callbacks --- modules/script_callbacks.py | 68 +++++++++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 27 deletions(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index c28e220e..4a7fb944 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -46,25 +46,23 @@ class CFGDenoiserParams: ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) -callbacks_app_started = [] -callbacks_model_loaded = [] -callbacks_ui_tabs = [] -callbacks_ui_settings = [] -callbacks_before_image_saved = [] -callbacks_image_saved = [] -callbacks_cfg_denoiser = [] +__callback_map = dict( + callbacks_app_started=[], + callbacks_model_loaded=[], + callbacks_ui_tabs=[], + callbacks_ui_settings=[], + callbacks_before_image_saved=[], + callbacks_image_saved=[], + callbacks_cfg_denoiser=[] +) def clear_callbacks(): - callbacks_model_loaded.clear() - callbacks_ui_tabs.clear() - callbacks_ui_settings.clear() - callbacks_before_image_saved.clear() - callbacks_image_saved.clear() - callbacks_cfg_denoiser.clear() + for callback_list in __callback_map.values(): + callback_list.clear() def app_started_callback(demo: Optional[Blocks], app: FastAPI): - for c in callbacks_app_started: + for c in __callback_map['callbacks_app_started']: try: c.callback(demo, app) except Exception: @@ -72,7 +70,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): def model_loaded_callback(sd_model): - for c in callbacks_model_loaded: + for c in __callback_map['callbacks_model_loaded']: try: c.callback(sd_model) except Exception: @@ -82,7 +80,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - for c in callbacks_ui_tabs: + for c in __callback_map['callbacks_ui_tabs']: try: res += c.callback() or [] except Exception: @@ -92,7 +90,7 @@ def ui_tabs_callback(): def ui_settings_callback(): - for c in callbacks_ui_settings: + for c in __callback_map['callbacks_ui_settings']: try: c.callback() except Exception: @@ -100,7 +98,7 @@ def ui_settings_callback(): def before_image_saved_callback(params: ImageSaveParams): - for c in callbacks_before_image_saved: + for c in __callback_map['callbacks_before_image_saved']: try: c.callback(params) except Exception: @@ -108,7 +106,7 @@ def before_image_saved_callback(params: ImageSaveParams): def image_saved_callback(params: ImageSaveParams): - for c in callbacks_image_saved: + for c in __callback_map['callbacks_image_saved']: try: c.callback(params) except Exception: @@ -116,7 +114,7 @@ def image_saved_callback(params: ImageSaveParams): def cfg_denoiser_callback(params: CFGDenoiserParams): - for c in callbacks_cfg_denoiser: + for c in __callback_map['callbacks_cfg_denoiser']: try: c.callback(params) except Exception: @@ -129,17 +127,33 @@ def add_callback(callbacks, fun): callbacks.append(ScriptCallback(filename, fun)) + +def remove_current_script_callbacks(): + stack = [x for x in inspect.stack() if x.filename != __file__] + filename = stack[0].filename if len(stack) > 0 else 'unknown file' + if filename == 'unknown file': + return + for callback_list in __callback_map.values(): + for callback_to_remove in [cb for cb in callback_list if cb.script == filename]: + callback_list.remove(callback_to_remove) + + +def remove_callbacks_for_function(callback_func): + for callback_list in __callback_map.values(): + for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]: + callback_list.remove(callback_to_remove) + def on_app_started(callback): """register a function to be called when the webui started, the gradio `Block` component and fastapi `FastAPI` object are passed as the arguments""" - add_callback(callbacks_app_started, callback) + add_callback(__callback_map['callbacks_app_started'], callback) def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" - add_callback(callbacks_model_loaded, callback) + add_callback(__callback_map['callbacks_model_loaded'], callback) def on_ui_tabs(callback): @@ -152,13 +166,13 @@ def on_ui_tabs(callback): title is tab text displayed to user in the UI elem_id is HTML id for the tab """ - add_callback(callbacks_ui_tabs, callback) + add_callback(__callback_map['callbacks_ui_tabs'], callback) def on_ui_settings(callback): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ - add_callback(callbacks_ui_settings, callback) + add_callback(__callback_map['callbacks_ui_settings'], callback) def on_before_image_saved(callback): @@ -166,7 +180,7 @@ def on_before_image_saved(callback): The callback is called with one argument: - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. """ - add_callback(callbacks_before_image_saved, callback) + add_callback(__callback_map['callbacks_before_image_saved'], callback) def on_image_saved(callback): @@ -174,7 +188,7 @@ def on_image_saved(callback): The callback is called with one argument: - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ - add_callback(callbacks_image_saved, callback) + add_callback(__callback_map['callbacks_image_saved'], callback) def on_cfg_denoiser(callback): @@ -182,5 +196,5 @@ def on_cfg_denoiser(callback): The callback is called with one argument: - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. """ - add_callback(callbacks_cfg_denoiser, callback) + add_callback(__callback_map['callbacks_cfg_denoiser'], callback) -- cgit v1.2.1 From de64146ad2fc2030a4cd3545676f9e18c93b8b18 Mon Sep 17 00:00:00 2001 From: Artem Zagidulin Date: Wed, 2 Nov 2022 21:30:50 +0300 Subject: add number of itter --- modules/processing.py | 2 +- modules/scripts.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 72a2ee4e..17f4a5ec 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -510,7 +510,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: break if p.scripts is not None: - p.scripts.process_one(p) + p.scripts.process_one(p, n) with devices.autocast(): uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) diff --git a/modules/scripts.py b/modules/scripts.py index 9f82efea..7aa0d56a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -70,7 +70,7 @@ class Script: pass - def process_one(self, p, *args): + def process_one(self, p, n, *args): """ Same as process(), but called for every iteration """ @@ -301,11 +301,11 @@ class ScriptRunner: print(f"Error running process: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - def process_one(self, p): + def process_one(self, p, n): for script in self.alwayson_scripts: try: script_args = p.script_args[script.args_from:script.args_to] - script.process_one(p, *script_args) + script.process_one(p, n, *script_args) except Exception: print(f"Error running process_one: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) -- cgit v1.2.1 From 2ac25ea64f31fd0e7dea35d27a52f3646618c3b6 Mon Sep 17 00:00:00 2001 From: digburn Date: Wed, 2 Nov 2022 21:52:23 +0000 Subject: fix: Add required parameter to API extras route --- modules/api/models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index 9ee42a17..9069c0ac 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -131,6 +131,7 @@ class ExtrasBaseRequest(BaseModel): upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") + upscale_first: bool = Field(default=True, title="Upscale first", description="Should the upscaler run before restoring faces?") class ExtraBaseResponse(BaseModel): html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.") -- cgit v1.2.1 From 313e14de04d9955c6ad077341feceb0fc7f2f1d3 Mon Sep 17 00:00:00 2001 From: Chris OBryan <13701027+cobryan05@users.noreply.github.com> Date: Wed, 2 Nov 2022 21:37:43 -0500 Subject: extras - skip unnecessary second hash of image There is no need to re-hash the input image each iteration of the loop. This also reverts PR #4026 as it was determined the cache hits it avoids were actually valid. --- modules/extras.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 8e2ab35c..71b93a06 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -136,12 +136,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: blended_result: Image.Image = None + image_hash: str = hash(np.array(image.getdata()).tobytes()) for upscaler in params: upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()), + cache_key = LruCache.Key(image_hash=image_hash, info_hash=hash(info), - args_hash=hash((upscale_args, upscale_first))) + args_hash=hash(upscale_args)) cached_entry = cached_images.get(cache_key) if cached_entry is None: res = upscale(image, *upscale_args) -- cgit v1.2.1 From 7a2e36b583ef9eaefa44322e16faff6f9f1af169 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Thu, 3 Nov 2022 00:51:22 -0300 Subject: Add config and lists endpoints --- modules/api/api.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++--- modules/api/models.py | 70 +++++++++++++++++++++++++++++++++++-- 2 files changed, 159 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 71c9c160..ed2dce5d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -2,14 +2,17 @@ import base64 import io import time import uvicorn -from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image -from fastapi import APIRouter, Depends, HTTPException +from threading import Lock +from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image +from fastapi import APIRouter, Depends, FastAPI, HTTPException import modules.shared as shared from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid +from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo - +from modules.sd_models import checkpoints_list +from modules.realesrgan_model import get_realesrgan_models +from typing import List def upscaler_to_index(name: str): try: @@ -37,7 +40,7 @@ def encode_pil_to_base64(image): class Api: - def __init__(self, app, queue_lock): + def __init__(self, app: FastAPI, queue_lock: Lock): self.router = APIRouter() self.app = app self.queue_lock = queue_lock @@ -48,6 +51,19 @@ class Api: self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) + self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) + self.app.add_api_route("/sdapi/v1/info", self.get_info, methods=["GET"]) + self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) + self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) + self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) + self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) + self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) + self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) + self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) + self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) + self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -190,6 +206,77 @@ class Api: shared.state.interrupt() return {} + + def get_config(self): + options = {} + for key in shared.opts.data.keys(): + metadata = shared.opts.data_labels.get(key) + if(metadata is not None): + options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)}) + else: + options.update({key: shared.opts.data.get(key, None)}) + + return options + + def set_config(self, req: OptionsModel): + reqDict = vars(req) + for o in reqDict: + setattr(shared.opts, o, reqDict[o]) + + shared.opts.save(shared.config_filename) + return + + def get_cmd_flags(self): + return vars(shared.cmd_opts) + + def get_info(self): + + return { + "hypernetworks": [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks], + "face_restorers": [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers], + "realesrgan_models":[{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)], + "promp_styles":[shared.prompt_styles.styles[k] for k in shared.prompt_styles.styles], + "artists_categories": shared.artist_db.cats, + # "artists": [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists] + } + + def get_samplers(self): + return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] + + def get_upscalers(self): + upscalers = [] + + for upscaler in shared.sd_upscalers: + u = upscaler.scaler + upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url}) + + return upscalers + + def get_sd_models(self): + return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()] + + def get_hypernetworks(self): + return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] + + def get_face_restorers(self): + return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers] + + def get_realesrgan_models(self): + return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)] + + def get_promp_styles(self): + styleList = [] + for k in shared.prompt_styles.styles: + style = shared.prompt_styles.styles[k] + styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]}) + + return styleList + + def get_artists_categories(self): + return shared.artist_db.cats + + def get_artists(self): + return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists] def launch(self, server_name, port): self.app.include_router(self.router) diff --git a/modules/api/models.py b/modules/api/models.py index 9ee42a17..b54b188a 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,11 +1,10 @@ import inspect -from click import prompt from pydantic import BaseModel, Field, create_model -from typing import Any, Optional +from typing import Any, Optional, Union from typing_extensions import Literal from inflection import underscore from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img -from modules.shared import sd_upscalers +from modules.shared import sd_upscalers, opts, parser API_NOT_ALLOWED = [ "self", @@ -165,3 +164,68 @@ class ProgressResponse(BaseModel): eta_relative: float = Field(title="ETA in secs") state: dict = Field(title="State", description="The current state snapshot") current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") + +fields = {} +for key, value in opts.data.items(): + metadata = opts.data_labels.get(key) + optType = opts.typemap.get(type(value), type(value)) + + if (metadata is not None): + fields.update({key: (Optional[optType], Field( + default=metadata.default ,description=metadata.label))}) + else: + fields.update({key: (Optional[optType], Field())}) + +OptionsModel = create_model("Options", **fields) + +flags = {} +_options = vars(parser)['_option_string_actions'] +for key in _options: + if(_options[key].dest != 'help'): + flag = _options[key] + _type = str + if(_options[key].default != None): _type = type(_options[key].default) + flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))}) + +FlagsModel = create_model("Flags", **flags) + +class SamplerItem(BaseModel): + name: str = Field(title="Name") + aliases: list[str] = Field(title="Aliases") + options: dict[str, str] = Field(title="Options") + +class UpscalerItem(BaseModel): + name: str = Field(title="Name") + model_name: str | None = Field(title="Model Name") + model_path: str | None = Field(title="Path") + model_url: str | None = Field(title="URL") + +class SDModelItem(BaseModel): + title: str = Field(title="Title") + model_name: str = Field(title="Model Name") + hash: str = Field(title="Hash") + filename: str = Field(title="Filename") + config: str = Field(title="Config file") + +class HypernetworkItem(BaseModel): + name: str = Field(title="Name") + path: str | None = Field(title="Path") + +class FaceRestorerItem(BaseModel): + name: str = Field(title="Name") + cmd_dir: str | None = Field(title="Path") + +class RealesrganItem(BaseModel): + name: str = Field(title="Name") + path: str | None = Field(title="Path") + scale: int | None = Field(title="Scale") + +class PromptStyleItem(BaseModel): + name: str = Field(title="Name") + prompt: str | None = Field(title="Prompt") + negative_prompt: str | None = Field(title="Negative Prompt") + +class ArtistItem(BaseModel): + name: str = Field(title="Name") + score: float = Field(title="Score") + category: str = Field(title="Category") \ No newline at end of file -- cgit v1.2.1 From 743fffa3d6c2e9e6bb5f48093a4c88f3b53e001d Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Thu, 3 Nov 2022 00:52:01 -0300 Subject: Remove unused endpoint --- modules/api/api.py | 12 ------------ 1 file changed, 12 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index ed2dce5d..a49f3755 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -54,7 +54,6 @@ class Api: self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) - self.app.add_api_route("/sdapi/v1/info", self.get_info, methods=["GET"]) self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) @@ -229,17 +228,6 @@ class Api: def get_cmd_flags(self): return vars(shared.cmd_opts) - def get_info(self): - - return { - "hypernetworks": [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks], - "face_restorers": [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers], - "realesrgan_models":[{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)], - "promp_styles":[shared.prompt_styles.styles[k] for k in shared.prompt_styles.styles], - "artists_categories": shared.artist_db.cats, - # "artists": [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists] - } - def get_samplers(self): return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] -- cgit v1.2.1 From e33d6cbddd08870e348d10a58af41fb677a39fd6 Mon Sep 17 00:00:00 2001 From: Ju1-js <40339350+Ju1-js@users.noreply.github.com> Date: Wed, 2 Nov 2022 21:04:49 -0700 Subject: Make extension manager Remote links open a new tab --- modules/ui_extensions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index ab807722..a81de9a7 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -86,7 +86,7 @@ def extension_table(): code += f""" - {html.escape(ext.remote or '')} + {html.escape(ext.remote or '')} {ext_status} """ -- cgit v1.2.1 From 0b143c1163a96b193a4e8512be9c5831c661a50d Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Thu, 3 Nov 2022 14:30:53 +0900 Subject: Separate .optim file from model --- modules/hypernetworks/hypernetwork.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8f74cdea..63c25de8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -161,6 +161,7 @@ class Hypernetwork: def save(self, filename): state_dict = {} + optimizer_saved_dict = {} for k, v in self.layers.items(): state_dict[k] = (v[0].state_dict(), v[1].state_dict()) @@ -175,9 +176,10 @@ class Hypernetwork: state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name if self.optimizer_name is not None: - state_dict['optimizer_name'] = self.optimizer_name + optimizer_saved_dict['optimizer_name'] = self.optimizer_name if self.optimizer_state_dict: - state_dict['optimizer_state_dict'] = self.optimizer_state_dict + optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict + torch.save(optimizer_saved_dict, filename + '.optim') torch.save(state_dict, filename) @@ -198,9 +200,11 @@ class Hypernetwork: print(f"Layer norm is set to {self.add_layer_norm}") self.use_dropout = state_dict.get('use_dropout', False) print(f"Dropout usage is set to {self.use_dropout}") - self.optimizer_name = state_dict.get('optimizer_name', 'AdamW') + + optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} + self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') print(f"Optimizer name is {self.optimizer_name}") - self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None) + self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) if self.optimizer_state_dict: print("Loaded existing optimizer from checkpoint") else: -- cgit v1.2.1 From 1764ac3c8bc482bd575987850e96630d9115e51a Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Thu, 3 Nov 2022 14:49:26 +0900 Subject: use hash to check valid optim --- modules/hypernetworks/hypernetwork.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 63c25de8..4230b8cf 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -177,12 +177,13 @@ class Hypernetwork: state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name if self.optimizer_name is not None: optimizer_saved_dict['optimizer_name'] = self.optimizer_name + + torch.save(state_dict, filename) if self.optimizer_state_dict: + optimizer_saved_dict['hash'] = sd_models.model_hash(filename) optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict torch.save(optimizer_saved_dict, filename + '.optim') - torch.save(state_dict, filename) - def load(self, filename): self.filename = filename if self.name is None: @@ -204,7 +205,10 @@ class Hypernetwork: optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') print(f"Optimizer name is {self.optimizer_name}") - self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) + if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None): + self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) + else: + self.optimizer_state_dict = None if self.optimizer_state_dict: print("Loaded existing optimizer from checkpoint") else: @@ -229,7 +233,7 @@ def list_hypernetworks(path): name = os.path.splitext(os.path.basename(filename))[0] # Prevent a hypothetical "None.pt" from being listed. if name != "None": - res[name] = filename + res[name + f"({sd_models.model_hash(filename)})"] = filename return res @@ -375,6 +379,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log else: hypernetwork_dir = None + hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] if create_image_every > 0: images_dir = os.path.join(log_directory, "images") os.makedirs(images_dir, exist_ok=True) -- cgit v1.2.1 From 86b7fc6e5ed56327fa12b444ca2444b13eb98aa8 Mon Sep 17 00:00:00 2001 From: thesved <2893181+thesved@users.noreply.github.com> Date: Thu, 3 Nov 2022 19:44:47 +0100 Subject: Make DDIM and PLMS work on Mac OS Fix register_buffer error on Mac OS --- modules/sd_hijack_inpainting.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index fd92a335..202b42cf 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -1,4 +1,5 @@ import torch +import modules.devices as devices from einops import repeat from omegaconf import ListConfig @@ -314,6 +315,20 @@ class LatentInpaintDiffusion(LatentDiffusion): self.masked_image_key = masked_image_key assert self.masked_image_key in concat_keys self.concat_keys = concat_keys + + +# ================================================================================================= +# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here +# ================================================================================================= +def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + optimal_type = devices.get_optimal_device() + if attr.device != optimal_type: + if getattr(torch, 'has_mps', False): + attr = attr.to(device="mps", dtype=torch.float32) + else: + attr = attr.to(optimal_type) + setattr(self, name, attr) def should_hijack_inpainting(checkpoint_info): @@ -326,6 +341,8 @@ def do_inpainting_hijack(): ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim + ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms - ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms \ No newline at end of file + ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms + ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer -- cgit v1.2.1 From b2c48091db394c2b7d375a33f18d90c924cd4363 Mon Sep 17 00:00:00 2001 From: Gur Date: Fri, 4 Nov 2022 06:55:03 +0800 Subject: fixed api compatibility with python 3.8 --- modules/api/models.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index 9ee42a17..29a934ba 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -6,6 +6,7 @@ from typing_extensions import Literal from inflection import underscore from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img from modules.shared import sd_upscalers +from typing import List API_NOT_ALLOWED = [ "self", @@ -109,12 +110,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( ).generate_model() class TextToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: dict info: str class ImageToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: dict info: str @@ -146,10 +147,10 @@ class FileData(BaseModel): name: str = Field(title="File name") class ExtrasBatchImagesRequest(ExtrasBaseRequest): - imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") + imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: list[str] = Field(title="Images", description="The generated images in base64 format.") + images: List[str] = Field(title="Images", description="The generated images in base64 format.") class PNGInfoRequest(BaseModel): image: str = Field(title="Image", description="The base64 encoded PNG image") -- cgit v1.2.1 From 8eb64dab3e9e40531f6a3fa606a1c23a62987249 Mon Sep 17 00:00:00 2001 From: digburn <115176097+digburn@users.noreply.github.com> Date: Fri, 4 Nov 2022 00:35:18 +0000 Subject: fix: correct default val of upscale_first to False --- modules/api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index 9069c0ac..68fb45c6 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -131,7 +131,7 @@ class ExtrasBaseRequest(BaseModel): upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") - upscale_first: bool = Field(default=True, title="Upscale first", description="Should the upscaler run before restoring faces?") + upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?") class ExtraBaseResponse(BaseModel): html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.") -- cgit v1.2.1 From 3780ad3ad837dd406da39eebd5d91009b5a58445 Mon Sep 17 00:00:00 2001 From: digburn Date: Fri, 4 Nov 2022 00:40:21 +0000 Subject: fix: loading models without vae from cache --- modules/sd_models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 5075fadb..ae427a5c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -204,8 +204,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoints_loaded.popitem(last=False) # LRU else: - vae_name = sd_vae.get_filename(vae_file) - print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") + vae_name = sd_vae.get_filename(vae_file) if vae_file else None + vae_message = f" with {vae_name} VAE" if vae_name else "" + print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") checkpoints_loaded.move_to_end(checkpoint_key) model.load_state_dict(checkpoints_loaded[checkpoint_key]) -- cgit v1.2.1 From e533ff61c1baa4ad047f9c8dc05c17b64ee89ddf Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 3 Nov 2022 22:28:22 -0700 Subject: Lift extras generate button a la #4246. --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 2609857e..6461002a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1052,6 +1052,8 @@ def create_ui(wrap_gradio_gpu_call): extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") show_extras_results = gr.Checkbox(label='Show result images', value=True) + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by'): upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) @@ -1079,8 +1081,6 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) submit.click( -- cgit v1.2.1 From 4dd898b8c15e342f817d3fb1c8dc9f2d5d111022 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 08:38:11 +0300 Subject: do not mess with components' visibility for scripts; instead create group components and show/hide those; this will break scripts that create invisible components and rely on UI but the earlier i make this change the better --- modules/scripts.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 533db45c..28ce07f4 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -18,6 +18,9 @@ class Script: args_to = None alwayson = False + """A gr.Group component that has all script's UI inside it""" + group = None + infotext_fields = None """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example @@ -218,8 +221,6 @@ class ScriptRunner: for control in controls: control.custom_script_source = os.path.basename(script.filename) - if not script.alwayson: - control.visible = False if script.infotext_fields is not None: self.infotext_fields += script.infotext_fields @@ -229,40 +230,41 @@ class ScriptRunner: script.args_to = len(inputs) for script in self.alwayson_scripts: - with gr.Group(): + with gr.Group() as group: create_script_ui(script, inputs, inputs_alwayson) + script.group = group + dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") dropdown.save_to_config = True inputs[0] = dropdown for script in self.selectable_scripts: - create_script_ui(script, inputs, inputs_alwayson) + with gr.Group(visible=False) as group: + create_script_ui(script, inputs, inputs_alwayson) + + script.group = group def select_script(script_index): - if 0 < script_index <= len(self.selectable_scripts): - script = self.selectable_scripts[script_index-1] - args_from = script.args_from - args_to = script.args_to - else: - args_from = 0 - args_to = 0 + selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None - return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)] + return [gr.update(visible=selected_script == s) for s in self.selectable_scripts] def init_field(title): + """called when an initial value is set from ui-config.json to show script's UI components""" + if title == 'None': return + script_index = self.titles.index(title) - script = self.selectable_scripts[script_index] - for i in range(script.args_from, script.args_to): - inputs[i].visible = True + self.selectable_scripts[script_index].group.visible = True dropdown.init_field = init_field + dropdown.change( fn=select_script, inputs=[dropdown], - outputs=inputs + outputs=[script.group for script in self.selectable_scripts] ) return inputs -- cgit v1.2.1 From f2b69709eaff88fc3a2bd49585556ec0883bf5ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 09:42:25 +0300 Subject: move option access checking to options class out of various places scattered through code --- modules/processing.py | 4 ++-- modules/shared.py | 11 +++++++++++ modules/ui.py | 20 +++++--------------- 3 files changed, 18 insertions(+), 17 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 2168208c..a46e592d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -418,13 +418,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): - opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible + setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible res = process_images_inner(p) finally: for k, v in stored_opts.items(): - opts.data[k] = v + setattr(opts, k, v) return res diff --git a/modules/shared.py b/modules/shared.py index d8e99f85..024c771a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -396,6 +396,15 @@ class Options: def __setattr__(self, key, value): if self.data is not None: if key in self.data or key in self.data_labels: + assert not cmd_opts.freeze_settings, "changing settings is disabled" + + comp_args = opts.data_labels[key].component_args + if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: + raise RuntimeError(f"not possible to set {key} because it is restricted") + + if cmd_opts.hide_ui_dir_config and key in restricted_opts: + raise RuntimeError(f"not possible to set {key} because it is restricted") + self.data[key] = value return @@ -412,6 +421,8 @@ class Options: return super(Options, self).__getattribute__(item) def save(self, filename): + assert not cmd_opts.freeze_settings, "saving settings is disabled" + with open(filename, "w", encoding="utf8") as file: json.dump(self.data, file, indent=4) diff --git a/modules/ui.py b/modules/ui.py index b2b1c854..633b56ef 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1438,8 +1438,6 @@ def create_ui(wrap_gradio_gpu_call): def run_settings(*args): changed = 0 - assert not shared.cmd_opts.freeze_settings, "changing settings is disabled" - for key, value, comp in zip(opts.data_labels.keys(), args, components): if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default): return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson() @@ -1448,15 +1446,9 @@ def create_ui(wrap_gradio_gpu_call): if comp == dummy_component: continue - comp_args = opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - continue - - if cmd_opts.hide_ui_dir_config and key in restricted_opts: - continue - oldval = opts.data.get(key, None) - opts.data[key] = value + + setattr(opts, key, value) if oldval != value: if opts.data_labels[key].onchange is not None: @@ -1469,17 +1461,15 @@ def create_ui(wrap_gradio_gpu_call): return f'{changed} settings changed.', opts.dumpjson() def run_settings_single(value, key): - assert not shared.cmd_opts.freeze_settings, "changing settings is disabled" - if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() oldval = opts.data.get(key, None) - if cmd_opts.hide_ui_dir_config and key in restricted_opts: + try: + setattr(opts, key, value) + except Exception: return gr.update(value=oldval), opts.dumpjson() - opts.data[key] = value - if oldval != value: if opts.data_labels[key].onchange is not None: opts.data_labels[key].onchange() -- cgit v1.2.1 From 0abb39f461baa343ae7c23abffb261e57c3168d4 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 4 Nov 2022 15:47:19 +0900 Subject: resolve conflict - first revert --- modules/hypernetworks/hypernetwork.py | 123 ++++++++++++++-------------------- 1 file changed, 52 insertions(+), 71 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4230b8cf..674fcedd 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -21,7 +21,6 @@ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_norm from collections import defaultdict, deque from statistics import stdev, mean -optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} class HypernetworkModule(torch.nn.Module): multiplier = 1.0 @@ -34,9 +33,12 @@ class HypernetworkModule(torch.nn.Module): "tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid, } - activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) + activation_dict.update( + {cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if + inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) - def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False): + def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', + add_layer_norm=False, use_dropout=False): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -47,7 +49,7 @@ class HypernetworkModule(torch.nn.Module): for i in range(len(layer_structure) - 1): # Add a fully-connected layer - linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1]))) # Add an activation func if activation_func == "linear" or activation_func is None: @@ -59,7 +61,7 @@ class HypernetworkModule(torch.nn.Module): # Add layer normalization if add_layer_norm: - linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) + linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1]))) # Add dropout expect last layer if use_dropout and i < len(layer_structure) - 3: @@ -128,7 +130,8 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, + add_layer_norm=False, use_dropout=False): self.filename = None self.name = name self.layers = {} @@ -140,13 +143,13 @@ class Hypernetwork: self.weight_init = weight_init self.add_layer_norm = add_layer_norm self.use_dropout = use_dropout - self.optimizer_name = None - self.optimizer_state_dict = None for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, + self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, + self.add_layer_norm, self.use_dropout), ) def weights(self): @@ -161,7 +164,6 @@ class Hypernetwork: def save(self, filename): state_dict = {} - optimizer_saved_dict = {} for k, v in self.layers.items(): state_dict[k] = (v[0].state_dict(), v[1].state_dict()) @@ -175,14 +177,8 @@ class Hypernetwork: state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name - if self.optimizer_name is not None: - optimizer_saved_dict['optimizer_name'] = self.optimizer_name torch.save(state_dict, filename) - if self.optimizer_state_dict: - optimizer_saved_dict['hash'] = sd_models.model_hash(filename) - optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict - torch.save(optimizer_saved_dict, filename + '.optim') def load(self, filename): self.filename = filename @@ -202,23 +198,13 @@ class Hypernetwork: self.use_dropout = state_dict.get('use_dropout', False) print(f"Dropout usage is set to {self.use_dropout}") - optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} - self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') - print(f"Optimizer name is {self.optimizer_name}") - if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None): - self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) - else: - self.optimizer_state_dict = None - if self.optimizer_state_dict: - print("Loaded existing optimizer from checkpoint") - else: - print("No saved optimizer exists in checkpoint") - for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, + self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, + self.add_layer_norm, self.use_dropout), ) self.name = state_dict.get('name', self.name) @@ -233,7 +219,7 @@ def list_hypernetworks(path): name = os.path.splitext(os.path.basename(filename))[0] # Prevent a hypothetical "None.pt" from being listed. if name != "None": - res[name + f"({sd_models.model_hash(filename)})"] = filename + res[name] = filename return res @@ -330,7 +316,7 @@ def statistics(data): std = 0 else: std = stdev(data) - total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})" + total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})" recent_data = data[-32:] if len(recent_data) < 2: std = 0 @@ -340,7 +326,7 @@ def statistics(data): return total_information, recent_information -def report_statistics(loss_info:dict): +def report_statistics(loss_info: dict): keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) for key in keys: try: @@ -352,14 +338,18 @@ def report_statistics(loss_info:dict): print(e) - -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, + training_height, steps, create_image_every, save_hypernetwork_every, template_file, + preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, + preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 - textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, + save_hypernetwork_every, create_image_every, log_directory, + name="hypernetwork") path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() @@ -379,7 +369,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log else: hypernetwork_dir = None - hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] if create_image_every > 0: images_dir = os.path.join(log_directory, "images") os.makedirs(images_dir, exist_ok=True) @@ -395,39 +384,34 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, + height=training_height, + repeats=shared.opts.training_image_repeats_per_epoch, + placeholder_token=hypernetwork_name, + model=shared.sd_model, device=devices.device, + template_file=template_file, include_cond=True, + batch_size=batch_size) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) size = len(ds.indexes) - loss_dict = defaultdict(lambda : deque(maxlen = 1024)) + loss_dict = defaultdict(lambda: deque(maxlen=1024)) losses = torch.zeros((size,)) previous_mean_losses = [0] previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) - + weights = hypernetwork.weights() for weight in weights: weight.requires_grad = True - # Here we use optimizer from saved HN, or we can specify as UI option. - if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict: - optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate) - else: - print(f"Optimizer type {optimizer_name} is not defined!") - optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) - optimizer_name = 'AdamW' - if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. - try: - optimizer.load_state_dict(hypernetwork.optimizer_state_dict) - except RuntimeError as e: - print("Cannot resume from saved optimizer!") - print(e) + # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... + optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) steps_without_grad = 0 @@ -441,7 +425,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if len(loss_dict) > 0: previous_mean_losses = [i[-1] for i in loss_dict.values()] previous_mean_loss = mean(previous_mean_losses) - + scheduler.apply(optimizer, hypernetwork.step) if scheduler.finished: break @@ -460,7 +444,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() for entry in entries: loss_dict[entry.filename].append(loss.item()) - + optimizer.zero_grad() weights[0].grad = None loss.backward() @@ -475,9 +459,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log steps_done = hypernetwork.step + 1 - if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): + if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): raise RuntimeError("Loss diverged.") - + if len(previous_mean_losses) > 1: std = stdev(previous_mean_losses) else: @@ -489,11 +473,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log # Before saving, change name to match current checkpoint. hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') - hypernetwork.optimizer_name = optimizer_name - if shared.opts.save_optimizer_state: - hypernetwork.optimizer_state_dict = optimizer.state_dict() save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) - hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{previous_mean_loss:.7f}", "learn_rate": scheduler.learn_rate @@ -529,7 +510,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log preview_text = p.prompt processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None + image = processed.images[0] if len(processed.images) > 0 else None if unload: shared.sd_model.cond_stage_model.to(devices.cpu) @@ -537,7 +518,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if image is not None: shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, + shared.opts.samples_format, processed.infotexts[0], + p=p, forced_filename=forced_filename, + save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step @@ -551,15 +535,12 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" + report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') - hypernetwork.optimizer_name = optimizer_name - if shared.opts.save_optimizer_state: - hypernetwork.optimizer_state_dict = optimizer.state_dict() save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) - del optimizer - hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. + return hypernetwork, filename @@ -576,4 +557,4 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): hypernetwork.sd_checkpoint = old_sd_checkpoint hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name hypernetwork.name = old_hypernetwork_name - raise + raise \ No newline at end of file -- cgit v1.2.1 From 0d07cbfa15d34294a4fa22d74359cdd6fe2f799c Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Fri, 4 Nov 2022 15:50:54 +0900 Subject: I blame code autocomplete --- modules/hypernetworks/hypernetwork.py | 76 +++++++++++++---------------------- 1 file changed, 27 insertions(+), 49 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 674fcedd..a11e01d6 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -33,12 +33,9 @@ class HypernetworkModule(torch.nn.Module): "tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid, } - activation_dict.update( - {cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if - inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) + activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) - def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', - add_layer_norm=False, use_dropout=False): + def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -49,7 +46,7 @@ class HypernetworkModule(torch.nn.Module): for i in range(len(layer_structure) - 1): # Add a fully-connected layer - linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1]))) + linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # Add an activation func if activation_func == "linear" or activation_func is None: @@ -61,7 +58,7 @@ class HypernetworkModule(torch.nn.Module): # Add layer normalization if add_layer_norm: - linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1]))) + linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) # Add dropout expect last layer if use_dropout and i < len(layer_structure) - 3: @@ -130,8 +127,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, - add_layer_norm=False, use_dropout=False): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): self.filename = None self.name = name self.layers = {} @@ -146,10 +142,8 @@ class Hypernetwork: for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), ) def weights(self): @@ -196,15 +190,13 @@ class Hypernetwork: self.add_layer_norm = state_dict.get('is_layer_norm', False) print(f"Layer norm is set to {self.add_layer_norm}") self.use_dropout = state_dict.get('use_dropout', False) - print(f"Dropout usage is set to {self.use_dropout}") + print(f"Dropout usage is set to {self.use_dropout}" ) for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), ) self.name = state_dict.get('name', self.name) @@ -316,7 +308,7 @@ def statistics(data): std = 0 else: std = stdev(data) - total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})" + total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})" recent_data = data[-32:] if len(recent_data) < 2: std = 0 @@ -326,7 +318,7 @@ def statistics(data): return total_information, recent_information -def report_statistics(loss_info: dict): +def report_statistics(loss_info:dict): keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) for key in keys: try: @@ -338,18 +330,14 @@ def report_statistics(loss_info: dict): print(e) -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, - training_height, steps, create_image_every, save_hypernetwork_every, template_file, - preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, - preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 - textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, - save_hypernetwork_every, create_image_every, log_directory, - name="hypernetwork") + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() @@ -384,29 +372,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, - height=training_height, - repeats=shared.opts.training_image_repeats_per_epoch, - placeholder_token=hypernetwork_name, - model=shared.sd_model, device=devices.device, - template_file=template_file, include_cond=True, - batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) size = len(ds.indexes) - loss_dict = defaultdict(lambda: deque(maxlen=1024)) + loss_dict = defaultdict(lambda : deque(maxlen = 1024)) losses = torch.zeros((size,)) previous_mean_losses = [0] previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) - + weights = hypernetwork.weights() for weight in weights: weight.requires_grad = True @@ -425,7 +407,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if len(loss_dict) > 0: previous_mean_losses = [i[-1] for i in loss_dict.values()] previous_mean_loss = mean(previous_mean_losses) - + scheduler.apply(optimizer, hypernetwork.step) if scheduler.finished: break @@ -444,7 +426,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() for entry in entries: loss_dict[entry.filename].append(loss.item()) - + optimizer.zero_grad() weights[0].grad = None loss.backward() @@ -459,9 +441,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log steps_done = hypernetwork.step + 1 - if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): + if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): raise RuntimeError("Loss diverged.") - + if len(previous_mean_losses) > 1: std = stdev(previous_mean_losses) else: @@ -510,7 +492,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log preview_text = p.prompt processed = processing.process_images(p) - image = processed.images[0] if len(processed.images) > 0 else None + image = processed.images[0] if len(processed.images)>0 else None if unload: shared.sd_model.cond_stage_model.to(devices.cpu) @@ -518,10 +500,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if image is not None: shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, - shared.opts.samples_format, processed.infotexts[0], - p=p, forced_filename=forced_filename, - save_to_dirs=False) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step @@ -535,7 +514,7 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - + report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') @@ -543,7 +522,6 @@ Last saved image: {html.escape(last_saved_image)}
return hypernetwork, filename - def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): old_hypernetwork_name = hypernetwork.name old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None @@ -557,4 +535,4 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): hypernetwork.sd_checkpoint = old_sd_checkpoint hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name hypernetwork.name = old_hypernetwork_name - raise \ No newline at end of file + raise -- cgit v1.2.1 From 283249d2390f0f3a1c8a55d5d9aa551e3e9b2f9c Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 4 Nov 2022 15:57:17 +0900 Subject: apply --- modules/hypernetworks/hypernetwork.py | 54 +++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6e1a10cf..de8688a9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -22,6 +22,8 @@ from collections import defaultdict, deque from statistics import stdev, mean +optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} + class HypernetworkModule(torch.nn.Module): multiplier = 1.0 activation_dict = { @@ -142,6 +144,8 @@ class Hypernetwork: self.use_dropout = use_dropout self.activate_output = activate_output self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True + self.optimizer_name = None + self.optimizer_state_dict = None for size in enable_sizes or []: self.layers[size] = ( @@ -163,6 +167,7 @@ class Hypernetwork: def save(self, filename): state_dict = {} + optimizer_saved_dict = {} for k, v in self.layers.items(): state_dict[k] = (v[0].state_dict(), v[1].state_dict()) @@ -178,8 +183,15 @@ class Hypernetwork: state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['activate_output'] = self.activate_output state_dict['last_layer_dropout'] = self.last_layer_dropout - + + if self.optimizer_name is not None: + optimizer_saved_dict['optimizer_name'] = self.optimizer_name + torch.save(state_dict, filename) + if self.optimizer_state_dict: + optimizer_saved_dict['hash'] = sd_models.model_hash(filename) + optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict + torch.save(optimizer_saved_dict, filename + '.optim') def load(self, filename): self.filename = filename @@ -202,6 +214,18 @@ class Hypernetwork: print(f"Activate last layer is set to {self.activate_output}") self.last_layer_dropout = state_dict.get('last_layer_dropout', False) + optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} + self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') + print(f"Optimizer name is {self.optimizer_name}") + if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None): + self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) + else: + self.optimizer_state_dict = None + if self.optimizer_state_dict: + print("Loaded existing optimizer from checkpoint") + else: + print("No saved optimizer exists in checkpoint") + for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( @@ -223,7 +247,7 @@ def list_hypernetworks(path): name = os.path.splitext(os.path.basename(filename))[0] # Prevent a hypothetical "None.pt" from being listed. if name != "None": - res[name] = filename + res[name + f"({sd_models.model_hash(filename)})"] = filename return res @@ -369,6 +393,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log else: hypernetwork_dir = None + hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] if create_image_every > 0: images_dir = os.path.join(log_directory, "images") os.makedirs(images_dir, exist_ok=True) @@ -404,8 +429,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log weights = hypernetwork.weights() for weight in weights: weight.requires_grad = True - # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... - optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) + # Here we use optimizer from saved HN, or we can specify as UI option. + if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict: + optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate) + else: + print(f"Optimizer type {optimizer_name} is not defined!") + optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) + optimizer_name = 'AdamW' + if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. + try: + optimizer.load_state_dict(hypernetwork.optimizer_state_dict) + except RuntimeError as e: + print("Cannot resume from saved optimizer!") + print(e) steps_without_grad = 0 @@ -467,7 +503,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log # Before saving, change name to match current checkpoint. hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') + hypernetwork.optimizer_name = optimizer_name + if shared.opts.save_optimizer_state: + hypernetwork.optimizer_state_dict = optimizer.state_dict() save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) + hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{previous_mean_loss:.7f}", @@ -530,8 +570,12 @@ Last saved image: {html.escape(last_saved_image)}
report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + hypernetwork.optimizer_name = optimizer_name + if shared.opts.save_optimizer_state: + hypernetwork.optimizer_state_dict = optimizer.state_dict() save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) - + del optimizer + hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. return hypernetwork, filename def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): -- cgit v1.2.1 From f5d394214d6ee74a682d0a1016bcbebc4b43c13a Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 4 Nov 2022 16:04:03 +0900 Subject: split before declaring file name --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index de8688a9..9b6a3e62 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -382,6 +382,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = "Initializing hypernetwork training..." shared.state.job_count = steps + hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) @@ -393,7 +394,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log else: hypernetwork_dir = None - hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0] if create_image_every > 0: images_dir = os.path.join(log_directory, "images") os.makedirs(images_dir, exist_ok=True) -- cgit v1.2.1 From 1ca0bcd3a7003dd2c1324de7d97fd2a6fc5ddc53 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 4 Nov 2022 16:09:19 +0900 Subject: only save if option is enabled --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 9b6a3e62..b1f308e2 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -188,7 +188,7 @@ class Hypernetwork: optimizer_saved_dict['optimizer_name'] = self.optimizer_name torch.save(state_dict, filename) - if self.optimizer_state_dict: + if shared.opts.save_optimizer_state and self.optimizer_state_dict: optimizer_saved_dict['hash'] = sd_models.model_hash(filename) optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict torch.save(optimizer_saved_dict, filename + '.optim') -- cgit v1.2.1 From ccf1a15412ef6b518f9f54cc26a0ee5edf458108 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 10:16:19 +0300 Subject: add an option to enable installing extensions with --listen or --share --- modules/shared.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 024c771a..0a39cdf2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -44,6 +44,7 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision", parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") +parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options") parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) @@ -99,7 +100,7 @@ restricted_opts = { "outdir_save", } -cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen +cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen) and not cmd_opts.enable_insecure_extension_access devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) -- cgit v1.2.1 From 321e13ca176b256177c4a752d1f2bbee79b5532e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 10:35:30 +0300 Subject: produce a readable error message when setting an option fails on the settings screen --- modules/ui.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 633b56ef..3ac7540c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1439,8 +1439,7 @@ def create_ui(wrap_gradio_gpu_call): changed = 0 for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default): - return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson() + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" for key, value, comp in zip(opts.data_labels.keys(), args, components): if comp == dummy_component: @@ -1458,7 +1457,7 @@ def create_ui(wrap_gradio_gpu_call): opts.save(shared.config_filename) - return f'{changed} settings changed.', opts.dumpjson() + return opts.dumpjson(), f'{changed} settings changed.' def run_settings_single(value, key): if not opts.same_type(value, opts.data_labels[key].default): @@ -1622,9 +1621,9 @@ def create_ui(wrap_gradio_gpu_call): text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) settings_submit.click( - fn=run_settings, + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), inputs=components, - outputs=[result, text_settings], + outputs=[text_settings, result], ) for i, k, item in quicksettings_list: -- cgit v1.2.1 From f674c488d9701e577e2aaf25e331fb44ada4f1ef Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 10:45:34 +0300 Subject: bugfix: save image for hires fix BEFORE upscaling latent space --- modules/processing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index a46e592d..7a2fc218 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -665,17 +665,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") if opts.use_scale_latent_for_hires_fix: + for i in range(samples.shape[0]): + save_intermediate(samples, i) + samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") - + # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0: image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples) else: image_conditioning = self.txt2img_image_conditioning(samples) - - for i in range(samples.shape[0]): - save_intermediate(samples, i) else: decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) -- cgit v1.2.1 From 7278897982bfb640ee95f144c97ed25fb3f77ea3 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Fri, 4 Nov 2022 17:12:28 +0900 Subject: Update shared.py --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 4d6e1c8b..6e7a02e0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -309,7 +309,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), - "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.1 From 99043f33606d3057f83ea52a403e10cd29d1f7e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 11:20:42 +0300 Subject: fix one of previous merges breaking the program --- modules/sd_models.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 63e07a12..34c57bfa 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -167,6 +167,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): sd_vae.restore_base_vae(model) checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy() + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + if checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") -- cgit v1.2.1 From eeb07330131012c0294afb79165b90270679b9c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 11:21:40 +0300 Subject: change process_one virtual function for script to process_batch, add extra args and docs --- modules/processing.py | 2 +- modules/scripts.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index e20d8fc4..03c9143d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -502,7 +502,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: break if p.scripts is not None: - p.scripts.process_one(p, n) + p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) with devices.autocast(): uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) diff --git a/modules/scripts.py b/modules/scripts.py index 75e47cd2..366c90d7 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -73,9 +73,15 @@ class Script: pass - def process_one(self, p, n, *args): + def process_batch(self, p, *args, **kwargs): """ - Same as process(), but called for every iteration + Same as process(), but called for every batch. + + **kwargs will have those items: + - batch_number - index of current batch, from 0 to number of batches-1 + - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things + - seeds - list of seeds for current batch + - subseeds - list of subseeds for current batch """ pass @@ -303,13 +309,13 @@ class ScriptRunner: print(f"Error running process: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - def process_one(self, p, n): + def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: try: script_args = p.script_args[script.args_from:script.args_to] - script.process_one(p, n, *script_args) + script.process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running process_one: {script.filename}", file=sys.stderr) + print(f"Error running process_batch: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) def postprocess(self, p, processed): -- cgit v1.2.1 From 39541d7725bc42f456a604b07c50aba503a5a09a Mon Sep 17 00:00:00 2001 From: Fampai <> Date: Fri, 4 Nov 2022 04:50:22 -0400 Subject: Fixes race condition in training when VAE is unloaded set_current_image can attempt to use the VAE when it is unloaded to the CPU while training --- modules/hypernetworks/hypernetwork.py | 4 ++++ modules/textual_inversion/textual_inversion.py | 5 +++++ 2 files changed, 9 insertions(+) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6e1a10cf..fcb96059 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -390,7 +390,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + old_parallel_processing_allowed = shared.parallel_processing_allowed + if unload: + shared.parallel_processing_allowed = False shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) @@ -531,6 +534,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) + shared.parallel_processing_allowed = old_parallel_processing_allowed return hypernetwork, filename diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 0aeb0459..55892c57 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -273,7 +273,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + + old_parallel_processing_allowed = shared.parallel_processing_allowed + if unload: + shared.parallel_processing_allowed = False shared.sd_model.first_stage_model.to(devices.cpu) embedding.vec.requires_grad = True @@ -410,6 +414,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) shared.sd_model.first_stage_model.to(devices.device) + shared.parallel_processing_allowed = old_parallel_processing_allowed return embedding, filename -- cgit v1.2.1 From 821e2b883dbb42a187bc37379175cd55b7cd7e81 Mon Sep 17 00:00:00 2001 From: TinkTheBoush Date: Fri, 4 Nov 2022 19:39:03 +0900 Subject: change option position to Training setting --- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/shared.py | 1 + modules/textual_inversion/dataset.py | 5 ++--- modules/textual_inversion/textual_inversion.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7630fb81..a11e01d6 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -331,7 +331,7 @@ def report_statistics(loss_info:dict): -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images @@ -376,7 +376,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) diff --git a/modules/shared.py b/modules/shared.py index 1ccb269a..e1d9bdf1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -290,6 +290,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), + "shuffle_tags": OptionInfo(False, "Shuffleing tags by "," when create texts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index e9d97cc1..df278dc2 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -24,7 +24,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token @@ -33,7 +33,6 @@ class PersonalizedBase(Dataset): self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) - self.shuffle_tags = shuffle_tags self.dataset = [] @@ -99,7 +98,7 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) - if self.tag_shuffle: + if shared.opts.shuffle_tags: tags = filename_text.split(',') random.shuffle(tags) text = text.replace("[filewords]", ','.join(tags)) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 82dde931..0aeb0459 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -224,7 +224,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") @@ -272,7 +272,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) if unload: shared.sd_model.first_stage_model.to(devices.cpu) -- cgit v1.2.1 From 45b65e87e0ef64b3e457f7d20c62d591cdcd0e7b Mon Sep 17 00:00:00 2001 From: TinkTheBoush Date: Fri, 4 Nov 2022 19:48:28 +0900 Subject: remove ui option --- modules/ui.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 6f3836c6..45cd8c3f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1269,7 +1269,6 @@ def create_ui(wrap_gradio_gpu_call): save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) - shuffle_tags = gr.Checkbox(label='Shuffleing tags by "," when create texts', value=True) with gr.Row(): interrupt_training = gr.Button(value="Interrupt") @@ -1364,7 +1363,6 @@ def create_ui(wrap_gradio_gpu_call): template_file, save_image_with_stored_embedding, preview_from_txt2img, - shuffle_tags, *txt2img_preview_params, ], outputs=[ @@ -1389,7 +1387,6 @@ def create_ui(wrap_gradio_gpu_call): save_embedding_every, template_file, preview_from_txt2img, - shuffle_tags, *txt2img_preview_params, ], outputs=[ -- cgit v1.2.1 From fd62727893f9face287b0a9620251afaa38a627d Mon Sep 17 00:00:00 2001 From: Isaac Poulton Date: Fri, 4 Nov 2022 18:34:35 +0700 Subject: Sort hypernetworks --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6e1a10cf..f1f04a70 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -224,7 +224,7 @@ def list_hypernetworks(path): # Prevent a hypothetical "None.pt" from being listed. if name != "None": res[name] = filename - return res + return dict(sorted(res.items())) def load_hypernetwork(filename): -- cgit v1.2.1 From c3cd0d7a86f35a5bfc58fdc3ecfaf203c0aee06f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 4 Nov 2022 12:19:16 +0000 Subject: Should be one underscore for module privates not two --- modules/script_callbacks.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 4a7fb944..83da7ca4 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -46,7 +46,7 @@ class CFGDenoiserParams: ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) -__callback_map = dict( +_callback_map = dict( callbacks_app_started=[], callbacks_model_loaded=[], callbacks_ui_tabs=[], @@ -58,11 +58,11 @@ __callback_map = dict( def clear_callbacks(): - for callback_list in __callback_map.values(): + for callback_list in _callback_map.values(): callback_list.clear() def app_started_callback(demo: Optional[Blocks], app: FastAPI): - for c in __callback_map['callbacks_app_started']: + for c in _callback_map['callbacks_app_started']: try: c.callback(demo, app) except Exception: @@ -70,7 +70,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): def model_loaded_callback(sd_model): - for c in __callback_map['callbacks_model_loaded']: + for c in _callback_map['callbacks_model_loaded']: try: c.callback(sd_model) except Exception: @@ -80,7 +80,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - for c in __callback_map['callbacks_ui_tabs']: + for c in _callback_map['callbacks_ui_tabs']: try: res += c.callback() or [] except Exception: @@ -90,7 +90,7 @@ def ui_tabs_callback(): def ui_settings_callback(): - for c in __callback_map['callbacks_ui_settings']: + for c in _callback_map['callbacks_ui_settings']: try: c.callback() except Exception: @@ -98,7 +98,7 @@ def ui_settings_callback(): def before_image_saved_callback(params: ImageSaveParams): - for c in __callback_map['callbacks_before_image_saved']: + for c in _callback_map['callbacks_before_image_saved']: try: c.callback(params) except Exception: @@ -106,7 +106,7 @@ def before_image_saved_callback(params: ImageSaveParams): def image_saved_callback(params: ImageSaveParams): - for c in __callback_map['callbacks_image_saved']: + for c in _callback_map['callbacks_image_saved']: try: c.callback(params) except Exception: @@ -114,7 +114,7 @@ def image_saved_callback(params: ImageSaveParams): def cfg_denoiser_callback(params: CFGDenoiserParams): - for c in __callback_map['callbacks_cfg_denoiser']: + for c in _callback_map['callbacks_cfg_denoiser']: try: c.callback(params) except Exception: @@ -133,13 +133,13 @@ def remove_current_script_callbacks(): filename = stack[0].filename if len(stack) > 0 else 'unknown file' if filename == 'unknown file': return - for callback_list in __callback_map.values(): + for callback_list in _callback_map.values(): for callback_to_remove in [cb for cb in callback_list if cb.script == filename]: callback_list.remove(callback_to_remove) def remove_callbacks_for_function(callback_func): - for callback_list in __callback_map.values(): + for callback_list in _callback_map.values(): for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]: callback_list.remove(callback_to_remove) @@ -147,13 +147,13 @@ def remove_callbacks_for_function(callback_func): def on_app_started(callback): """register a function to be called when the webui started, the gradio `Block` component and fastapi `FastAPI` object are passed as the arguments""" - add_callback(__callback_map['callbacks_app_started'], callback) + add_callback(_callback_map['callbacks_app_started'], callback) def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" - add_callback(__callback_map['callbacks_model_loaded'], callback) + add_callback(_callback_map['callbacks_model_loaded'], callback) def on_ui_tabs(callback): @@ -166,13 +166,13 @@ def on_ui_tabs(callback): title is tab text displayed to user in the UI elem_id is HTML id for the tab """ - add_callback(__callback_map['callbacks_ui_tabs'], callback) + add_callback(_callback_map['callbacks_ui_tabs'], callback) def on_ui_settings(callback): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ - add_callback(__callback_map['callbacks_ui_settings'], callback) + add_callback(_callback_map['callbacks_ui_settings'], callback) def on_before_image_saved(callback): @@ -180,7 +180,7 @@ def on_before_image_saved(callback): The callback is called with one argument: - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. """ - add_callback(__callback_map['callbacks_before_image_saved'], callback) + add_callback(_callback_map['callbacks_before_image_saved'], callback) def on_image_saved(callback): @@ -188,7 +188,7 @@ def on_image_saved(callback): The callback is called with one argument: - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ - add_callback(__callback_map['callbacks_image_saved'], callback) + add_callback(_callback_map['callbacks_image_saved'], callback) def on_cfg_denoiser(callback): @@ -196,5 +196,4 @@ def on_cfg_denoiser(callback): The callback is called with one argument: - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. """ - add_callback(__callback_map['callbacks_cfg_denoiser'], callback) - + add_callback(_callback_map['callbacks_cfg_denoiser'], callback) -- cgit v1.2.1 From f316280ad3634a2343b086a6de0bfcd473e18599 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 16:48:40 +0300 Subject: fix the error that prevents from setting some options --- modules/shared.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index a9e28b9c..962115f6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -406,7 +406,8 @@ class Options: if key in self.data or key in self.data_labels: assert not cmd_opts.freeze_settings, "changing settings is disabled" - comp_args = opts.data_labels[key].component_args + info = opts.data_labels.get(key, None) + comp_args = info.component_args if info else None if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: raise RuntimeError(f"not possible to set {key} because it is restricted") -- cgit v1.2.1 From 116bcf730ade8d3ac5d76d04c5887b6bba000970 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 16:48:46 +0300 Subject: disable setting options via API until it is fixed by the author --- modules/api/api.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index a49f3755..8a7ab2f5 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -218,6 +218,10 @@ class Api: return options def set_config(self, req: OptionsModel): + # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will + # overwrite all options with default values. + raise RuntimeError('Setting options via API is not supported') + reqDict = vars(req) for o in reqDict: setattr(shared.opts, o, reqDict[o]) -- cgit v1.2.1 From 08feb4c364e8b2aed929fd7d22dfa21a93d78b2c Mon Sep 17 00:00:00 2001 From: Isaac Poulton Date: Fri, 4 Nov 2022 20:53:11 +0700 Subject: Sort straight out of the glob --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index f1f04a70..a441ab10 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -219,12 +219,12 @@ class Hypernetwork: def list_hypernetworks(path): res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)): name = os.path.splitext(os.path.basename(filename))[0] # Prevent a hypothetical "None.pt" from being listed. if name != "None": res[name] = filename - return dict(sorted(res.items())) + return res def load_hypernetwork(filename): -- cgit v1.2.1 From 5844ef8a9a165e0f456a4658bda830282cf5a55e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 4 Nov 2022 16:02:25 +0000 Subject: remove private underscore indicator --- modules/script_callbacks.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 83da7ca4..74dfb880 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -46,7 +46,7 @@ class CFGDenoiserParams: ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) -_callback_map = dict( +callback_map = dict( callbacks_app_started=[], callbacks_model_loaded=[], callbacks_ui_tabs=[], @@ -58,11 +58,11 @@ _callback_map = dict( def clear_callbacks(): - for callback_list in _callback_map.values(): + for callback_list in callback_map.values(): callback_list.clear() def app_started_callback(demo: Optional[Blocks], app: FastAPI): - for c in _callback_map['callbacks_app_started']: + for c in callback_map['callbacks_app_started']: try: c.callback(demo, app) except Exception: @@ -70,7 +70,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): def model_loaded_callback(sd_model): - for c in _callback_map['callbacks_model_loaded']: + for c in callback_map['callbacks_model_loaded']: try: c.callback(sd_model) except Exception: @@ -80,7 +80,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - for c in _callback_map['callbacks_ui_tabs']: + for c in callback_map['callbacks_ui_tabs']: try: res += c.callback() or [] except Exception: @@ -90,7 +90,7 @@ def ui_tabs_callback(): def ui_settings_callback(): - for c in _callback_map['callbacks_ui_settings']: + for c in callback_map['callbacks_ui_settings']: try: c.callback() except Exception: @@ -98,7 +98,7 @@ def ui_settings_callback(): def before_image_saved_callback(params: ImageSaveParams): - for c in _callback_map['callbacks_before_image_saved']: + for c in callback_map['callbacks_before_image_saved']: try: c.callback(params) except Exception: @@ -106,7 +106,7 @@ def before_image_saved_callback(params: ImageSaveParams): def image_saved_callback(params: ImageSaveParams): - for c in _callback_map['callbacks_image_saved']: + for c in callback_map['callbacks_image_saved']: try: c.callback(params) except Exception: @@ -114,7 +114,7 @@ def image_saved_callback(params: ImageSaveParams): def cfg_denoiser_callback(params: CFGDenoiserParams): - for c in _callback_map['callbacks_cfg_denoiser']: + for c in callback_map['callbacks_cfg_denoiser']: try: c.callback(params) except Exception: @@ -133,13 +133,13 @@ def remove_current_script_callbacks(): filename = stack[0].filename if len(stack) > 0 else 'unknown file' if filename == 'unknown file': return - for callback_list in _callback_map.values(): + for callback_list in callback_map.values(): for callback_to_remove in [cb for cb in callback_list if cb.script == filename]: callback_list.remove(callback_to_remove) def remove_callbacks_for_function(callback_func): - for callback_list in _callback_map.values(): + for callback_list in callback_map.values(): for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]: callback_list.remove(callback_to_remove) @@ -147,13 +147,13 @@ def remove_callbacks_for_function(callback_func): def on_app_started(callback): """register a function to be called when the webui started, the gradio `Block` component and fastapi `FastAPI` object are passed as the arguments""" - add_callback(_callback_map['callbacks_app_started'], callback) + add_callback(callback_map['callbacks_app_started'], callback) def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" - add_callback(_callback_map['callbacks_model_loaded'], callback) + add_callback(callback_map['callbacks_model_loaded'], callback) def on_ui_tabs(callback): @@ -166,13 +166,13 @@ def on_ui_tabs(callback): title is tab text displayed to user in the UI elem_id is HTML id for the tab """ - add_callback(_callback_map['callbacks_ui_tabs'], callback) + add_callback(callback_map['callbacks_ui_tabs'], callback) def on_ui_settings(callback): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ - add_callback(_callback_map['callbacks_ui_settings'], callback) + add_callback(callback_map['callbacks_ui_settings'], callback) def on_before_image_saved(callback): @@ -180,7 +180,7 @@ def on_before_image_saved(callback): The callback is called with one argument: - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. """ - add_callback(_callback_map['callbacks_before_image_saved'], callback) + add_callback(callback_map['callbacks_before_image_saved'], callback) def on_image_saved(callback): @@ -188,7 +188,7 @@ def on_image_saved(callback): The callback is called with one argument: - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ - add_callback(_callback_map['callbacks_image_saved'], callback) + add_callback(callback_map['callbacks_image_saved'], callback) def on_cfg_denoiser(callback): @@ -196,4 +196,4 @@ def on_cfg_denoiser(callback): The callback is called with one argument: - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. """ - add_callback(_callback_map['callbacks_cfg_denoiser'], callback) + add_callback(callback_map['callbacks_cfg_denoiser'], callback) -- cgit v1.2.1 From 0d7e01d9950e013784c4b77c05aa7583ea69edc8 Mon Sep 17 00:00:00 2001 From: innovaciones Date: Fri, 4 Nov 2022 12:14:32 -0600 Subject: Open extensions links in new tab Fixed for "Available" tab --- modules/ui_extensions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index a81de9a7..8e0d41d5 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -188,7 +188,7 @@ def refresh_available_extensions_from_data(): code += f""" - {html.escape(name)} + {html.escape(name)} {html.escape(description)} {install_code} -- cgit v1.2.1 From b8435e632f7ba0da12a2c8e9c788dda519279d24 Mon Sep 17 00:00:00 2001 From: evshiron Date: Sat, 5 Nov 2022 02:36:47 +0800 Subject: add --cors-allow-origins cmd opt --- modules/shared.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index a9e28b9c..e83cbcdf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) +parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None) cmd_opts = parser.parse_args() restricted_opts = { @@ -147,9 +148,9 @@ class State: self.interrupted = True def nextjob(self): - if opts.show_progress_every_n_steps == -1: + if opts.show_progress_every_n_steps == -1: self.do_set_current_image() - + self.job_no += 1 self.sampling_step = 0 self.current_image_sampling_step = 0 @@ -198,7 +199,7 @@ class State: return if self.current_latent is None: return - + if opts.show_progress_grid: self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) else: -- cgit v1.2.1 From 467d8b967b5d1b1984ab113bec3fff217736e7ac Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sat, 5 Nov 2022 04:24:42 +0900 Subject: Fix errors from commit f2b697 with --hide-ui-dir-config https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/f2b69709eaff88fc3a2bd49585556ec0883bf5ea --- modules/ui.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 4c2829af..76ca9b07 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1446,17 +1446,19 @@ def create_ui(wrap_gradio_gpu_call): continue oldval = opts.data.get(key, None) - - setattr(opts, key, value) - + try: + setattr(opts, key, value) + except RuntimeError: + continue if oldval != value: if opts.data_labels[key].onchange is not None: opts.data_labels[key].onchange() changed += 1 - - opts.save(shared.config_filename) - + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{changed} settings changed without save.' return opts.dumpjson(), f'{changed} settings changed.' def run_settings_single(value, key): -- cgit v1.2.1 From 30b1bcc64e67ad50c5d3af3a6fe1bd1e9553f34e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 22:56:18 +0300 Subject: fix upscale loop erroneously applied multiple times --- modules/upscaler.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/upscaler.py b/modules/upscaler.py index 83fde7ca..c4e6e6bd 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -57,10 +57,18 @@ class Upscaler: self.scale = scale dest_w = img.width * scale dest_h = img.height * scale + for i in range(3): - if img.width > dest_w and img.height > dest_h: - break + shape = (img.width, img.height) + img = self.do_upscale(img, selected_model) + + if shape == (img.width, img.height): + break + + if img.width >= dest_w and img.height >= dest_h: + break + if img.width != dest_w or img.height != dest_h: img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) -- cgit v1.2.1 From 6008c0773ea575353f9b87da8a58454e20cc7857 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Fri, 4 Nov 2022 23:03:05 +0000 Subject: Add support for new DPM-Solver++ samplers --- modules/sd_samplers.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index c7c414ef..7ece6556 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -29,6 +29,10 @@ samplers_k_diffusion = [ ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}), ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), + ('DPM-Solver++(2S) a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), + ('DPM-Solver++(2M)', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), + ('DPM-Solver++(2S) Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), + ('DPM-Solver++(2M) Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ] samplers_data_k_diffusion = [ -- cgit v1.2.1 From f92dc505a013af9e385c7edbdf97539be62503d6 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Fri, 4 Nov 2022 23:12:48 +0000 Subject: Fix name --- modules/sd_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 7ece6556..b28a2e4c 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -31,7 +31,7 @@ samplers_k_diffusion = [ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), ('DPM-Solver++(2S) a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), ('DPM-Solver++(2M)', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM-Solver++(2S) Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), + ('DPM-Solver++(2S) a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), ('DPM-Solver++(2M) Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ] -- cgit v1.2.1 From 1b6c2fc749e12f12bbee4705e65f217d23fa9072 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Fri, 4 Nov 2022 23:28:13 +0000 Subject: Reorder samplers --- modules/sd_samplers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index b28a2e4c..1e88f7ee 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -24,13 +24,13 @@ samplers_k_diffusion = [ ('Heun', 'sample_heun', ['k_heun'], {}), ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}), ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}), + ('DPM-Solver++(2S) a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), + ('DPM-Solver++(2M)', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}), ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), - ('DPM-Solver++(2S) a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), - ('DPM-Solver++(2M)', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM-Solver++(2S) a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), ('DPM-Solver++(2M) Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ] -- cgit v1.2.1 From ebce0c57c78a3f22178e3a38938d19ec0dfb703d Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Sat, 5 Nov 2022 11:38:24 +0800 Subject: Use typing.Optional instead of | to add support for Python 3.9 and below. --- modules/api/models.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index 2ae75f43..a44c5ddd 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,6 +1,6 @@ import inspect from pydantic import BaseModel, Field, create_model -from typing import Any, Optional, Union +from typing import Any, Optional from typing_extensions import Literal from inflection import underscore from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img @@ -185,22 +185,22 @@ _options = vars(parser)['_option_string_actions'] for key in _options: if(_options[key].dest != 'help'): flag = _options[key] - _type = str - if(_options[key].default != None): _type = type(_options[key].default) + _type = str + if _options[key].default is not None: _type = type(_options[key].default) flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))}) FlagsModel = create_model("Flags", **flags) class SamplerItem(BaseModel): name: str = Field(title="Name") - aliases: list[str] = Field(title="Aliases") + aliases: list[str] = Field(title="Aliases") options: dict[str, str] = Field(title="Options") class UpscalerItem(BaseModel): name: str = Field(title="Name") - model_name: str | None = Field(title="Model Name") - model_path: str | None = Field(title="Path") - model_url: str | None = Field(title="URL") + model_name: Optional[str] = Field(title="Model Name") + model_path: Optional[str] = Field(title="Path") + model_url: Optional[str] = Field(title="URL") class SDModelItem(BaseModel): title: str = Field(title="Title") @@ -211,21 +211,21 @@ class SDModelItem(BaseModel): class HypernetworkItem(BaseModel): name: str = Field(title="Name") - path: str | None = Field(title="Path") + path: Optional[str] = Field(title="Path") class FaceRestorerItem(BaseModel): name: str = Field(title="Name") - cmd_dir: str | None = Field(title="Path") + cmd_dir: Optional[str] = Field(title="Path") class RealesrganItem(BaseModel): name: str = Field(title="Name") - path: str | None = Field(title="Path") - scale: int | None = Field(title="Scale") + path: Optional[str] = Field(title="Path") + scale: Optional[int] = Field(title="Scale") class PromptStyleItem(BaseModel): name: str = Field(title="Name") - prompt: str | None = Field(title="Prompt") - negative_prompt: str | None = Field(title="Negative Prompt") + prompt: Optional[str] = Field(title="Prompt") + negative_prompt: Optional[str] = Field(title="Negative Prompt") class ArtistItem(BaseModel): name: str = Field(title="Name") -- cgit v1.2.1 From e9a5562b9b27a1a4f9c282637b111cefd9727a41 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 5 Nov 2022 04:06:51 -0500 Subject: add support for tls (gradio tls options) --- modules/shared.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 962115f6..7a20c3af 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,6 +86,9 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) +parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) +parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) +parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) cmd_opts = parser.parse_args() restricted_opts = { -- cgit v1.2.1 From 03b08c4a6b0609f24ec789d40100529b92ef0612 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 5 Nov 2022 15:04:48 +0300 Subject: do not die when an extension's repo has no remote --- modules/extensions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/extensions.py b/modules/extensions.py index 897af96e..8e0977fd 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -34,8 +34,11 @@ class Extension: if repo is None or repo.bare: self.remote = None else: - self.remote = next(repo.remote().urls, None) - self.status = 'unknown' + try: + self.remote = next(repo.remote().urls, None) + self.status = 'unknown' + except Exception: + self.remote = None def list_files(self, subdir, extension): from modules import scripts -- cgit v1.2.1 From a170e3d22231e145f42bb878a76ae5f76fdca230 Mon Sep 17 00:00:00 2001 From: Evgeniy Date: Sat, 5 Nov 2022 17:06:56 +0300 Subject: Python 3.8 typing compatibility Solves problems with ```Traceback (most recent call last): File "webui.py", line 201, in webui() File "webui.py", line 178, in webui create_api(app) File "webui.py", line 117, in create_api from modules.api.api import Api File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\api.py", line 9, in from modules.api.models import * File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\models.py", line 194, in class SamplerItem(BaseModel): File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\models.py", line 196, in SamplerItem aliases: list[str] = Field(title="Aliases") TypeError: 'type' object is not subscriptable``` and ```Traceback (most recent call last): File "webui.py", line 201, in webui() File "webui.py", line 178, in webui create_api(app) File "webui.py", line 117, in create_api from modules.api.api import Api File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\api.py", line 9, in from modules.api.models import * File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\models.py", line 194, in class SamplerItem(BaseModel): File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\models.py", line 197, in SamplerItem options: dict[str, str] = Field(title="Options") TypeError: 'type' object is not subscriptable``` --- modules/api/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index a44c5ddd..f89da1ff 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -5,7 +5,7 @@ from typing_extensions import Literal from inflection import underscore from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img from modules.shared import sd_upscalers, opts, parser -from typing import List +from typing import Dict, List API_NOT_ALLOWED = [ "self", @@ -193,8 +193,8 @@ FlagsModel = create_model("Flags", **flags) class SamplerItem(BaseModel): name: str = Field(title="Name") - aliases: list[str] = Field(title="Aliases") - options: dict[str, str] = Field(title="Options") + aliases: List[str] = Field(title="Aliases") + options: Dict[str, str] = Field(title="Options") class UpscalerItem(BaseModel): name: str = Field(title="Name") @@ -230,4 +230,4 @@ class PromptStyleItem(BaseModel): class ArtistItem(BaseModel): name: str = Field(title="Name") score: float = Field(title="Score") - category: str = Field(title="Category") \ No newline at end of file + category: str = Field(title="Category") -- cgit v1.2.1 From 62e3d71aa778928d63cab81d9d8cde33e55bebb3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 5 Nov 2022 17:09:42 +0300 Subject: rework the code to not use the walrus operator because colab's 3.7 does not support it --- modules/hypernetworks/hypernetwork.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 5ceed6ee..7f182712 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -429,13 +429,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log weights = hypernetwork.weights() for weight in weights: weight.requires_grad = True + # Here we use optimizer from saved HN, or we can specify as UI option. - if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict: + if hypernetwork.optimizer_name in optimizer_dict: optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate) + optimizer_name = hypernetwork.optimizer_name else: - print(f"Optimizer type {optimizer_name} is not defined!") + print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) optimizer_name = 'AdamW' + if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. try: optimizer.load_state_dict(hypernetwork.optimizer_state_dict) -- cgit v1.2.1 From 159475e072f2ed3db8235aab9c3fa18640b93b80 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 5 Nov 2022 18:32:22 +0300 Subject: tweak names a bit for new samplers --- modules/sd_samplers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 1e88f7ee..783992d2 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -24,15 +24,15 @@ samplers_k_diffusion = [ ('Heun', 'sample_heun', ['k_heun'], {}), ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}), ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}), - ('DPM-Solver++(2S) a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), - ('DPM-Solver++(2M)', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), + ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), + ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}), ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), - ('DPM-Solver++(2S) a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), - ('DPM-Solver++(2M) Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), + ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), + ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ] samplers_data_k_diffusion = [ -- cgit v1.2.1 From 99b05addb1c98169d78957f13efef308aef0af94 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 5 Nov 2022 18:46:47 -0300 Subject: Fix options endpoint not showing the full list of options --- modules/api/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index f89da1ff..0ea62155 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -168,9 +168,9 @@ class ProgressResponse(BaseModel): current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") fields = {} -for key, value in opts.data.items(): - metadata = opts.data_labels.get(key) - optType = opts.typemap.get(type(value), type(value)) +for key, metadata in opts.data_labels.items(): + value = opts.data.get(key) + optType = opts.typemap.get(type(metadata.default), type(value)) if (metadata is not None): fields.update({key: (Optional[optType], Field( -- cgit v1.2.1 From 0ebf66b575f008a027097946eb2f6845feffd010 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 5 Nov 2022 18:58:19 -0300 Subject: Fix set config endpoint --- modules/api/api.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 112000b8..a924c83a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -230,14 +230,10 @@ class Api: return options - def set_config(self, req: OptionsModel): - # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will - # overwrite all options with default values. - raise RuntimeError('Setting options via API is not supported') - - reqDict = vars(req) - for o in reqDict: - setattr(shared.opts, o, reqDict[o]) + def set_config(self, req: Dict[str, Any]): + + for o in req: + setattr(shared.opts, o, req[o]) shared.opts.save(shared.config_filename) return -- cgit v1.2.1 From 3c72055c22425dcde0739b5246e3501f4a3ec794 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 5 Nov 2022 19:05:15 -0300 Subject: Add skip endpoint --- modules/api/api.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index a924c83a..c7ceb787 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -64,6 +64,7 @@ class Api: self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) @@ -219,6 +220,11 @@ class Api: return {} + def skip(self): + shared.state.skip() + + return + def get_config(self): options = {} for key in shared.opts.data.keys(): -- cgit v1.2.1 From 7f63980e479c7ffaec907fb659b5024e96eb72e7 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 5 Nov 2022 19:09:13 -0300 Subject: Remove unnecesary return --- modules/api/api.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index c7ceb787..33e6c6dc 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -223,8 +223,6 @@ class Api: def skip(self): shared.state.skip() - return - def get_config(self): options = {} for key in shared.opts.data.keys(): -- cgit v1.2.1 From 6603f63b7b8af39ab815091460c5c2a12d3f253e Mon Sep 17 00:00:00 2001 From: Han Lin Date: Sun, 6 Nov 2022 11:08:20 +0800 Subject: Fixes LDSR upscaler producing black bars --- modules/ldsr_model_arch.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py index 14db5076..90e0a2f0 100644 --- a/modules/ldsr_model_arch.py +++ b/modules/ldsr_model_arch.py @@ -101,8 +101,8 @@ class LDSR: down_sample_rate = target_scale / 4 wd = width_og * down_sample_rate hd = height_og * down_sample_rate - width_downsampled_pre = int(wd) - height_downsampled_pre = int(hd) + width_downsampled_pre = int(np.ceil(wd)) + height_downsampled_pre = int(np.ceil(hd)) if down_sample_rate != 1: print( @@ -110,7 +110,12 @@ class LDSR: im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) else: print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") - logs = self.run(model["model"], im_og, diffusion_steps, eta) + + # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts + pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size + im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) + + logs = self.run(model["model"], im_padded, diffusion_steps, eta) sample = logs["sample"] sample = sample.detach().cpu() @@ -120,6 +125,9 @@ class LDSR: sample = np.transpose(sample, (0, 2, 3, 1)) a = Image.fromarray(sample[0]) + # remove padding + a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4)) + del model gc.collect() torch.cuda.empty_cache() -- cgit v1.2.1 From a2a1a2f7270a865175f64475229838a8d64509ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 09:02:25 +0300 Subject: add ability to create extensions that add localizations --- modules/localization.py | 6 ++++++ modules/scripts.py | 1 - modules/shared.py | 2 -- modules/ui.py | 3 +-- 4 files changed, 7 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/localization.py b/modules/localization.py index b1810cda..f6a6f2fb 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -3,6 +3,7 @@ import os import sys import traceback + localizations = {} @@ -16,6 +17,11 @@ def list_localizations(dirname): localizations[fn] = os.path.join(dirname, file) + from modules import scripts + for file in scripts.list_scripts("localizations", ".json"): + fn, ext = os.path.splitext(file.filename) + localizations[fn] = file.path + def localization_js(current_localization_name): fn = localizations.get(current_localization_name, None) diff --git a/modules/scripts.py b/modules/scripts.py index 366c90d7..637b2329 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -3,7 +3,6 @@ import sys import traceback from collections import namedtuple -import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing diff --git a/modules/shared.py b/modules/shared.py index 70b998ff..e8bacd3c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -221,8 +221,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] -localization.list_localizations(cmd_opts.localizations_dir) - def realesrgan_models_names(): import modules.realesrgan_model diff --git a/modules/ui.py b/modules/ui.py index 76ca9b07..23643c22 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1563,11 +1563,10 @@ def create_ui(wrap_gradio_gpu_call): shared.state.need_restart = True restart_gradio.click( - fn=request_restart, + _js='restart_reload', inputs=[], outputs=[], - _js='restart_reload' ) if column is not None: -- cgit v1.2.1 From e5b4e3f820cd09e751f1d168ab05d606d078a0d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 10:12:53 +0300 Subject: add tags to extensions, and ability to filter out tags list changed Settings keys in UI do not print VRAM/etc stats everywhere but in calls that use GPU --- modules/ui.py | 25 ++++++++++++---------- modules/ui_extensions.py | 55 ++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 58 insertions(+), 22 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 23643c22..c946ad59 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -174,9 +174,9 @@ def save_pil_to_file(pil_image, dir=None): gr.processing_utils.save_pil_to_file = save_pil_to_file -def wrap_gradio_call(func, extra_outputs=None): +def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def f(*args, extra_outputs_array=extra_outputs, **kwargs): - run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled + run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats if run_memmon: shared.mem_mon.monitor() t = time.perf_counter() @@ -203,11 +203,18 @@ def wrap_gradio_call(func, extra_outputs=None): res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] + shared.state.skipped = False + shared.state.interrupted = False + shared.state.job_count = 0 + + if not add_stats: + return tuple(res) + elapsed = time.perf_counter() - t elapsed_m = int(elapsed // 60) elapsed_s = elapsed % 60 elapsed_text = f"{elapsed_s:.2f}s" - if (elapsed_m > 0): + if elapsed_m > 0: elapsed_text = f"{elapsed_m}m "+elapsed_text if run_memmon: @@ -225,10 +232,6 @@ def wrap_gradio_call(func, extra_outputs=None): # last item is always HTML res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" - shared.state.skipped = False - shared.state.interrupted = False - shared.state.job_count = 0 - return tuple(res) return f @@ -1436,7 +1439,7 @@ def create_ui(wrap_gradio_gpu_call): opts.reorder() def run_settings(*args): - changed = 0 + changed = [] for key, value, comp in zip(opts.data_labels.keys(), args, components): assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" @@ -1454,12 +1457,12 @@ def create_ui(wrap_gradio_gpu_call): if opts.data_labels[key].onchange is not None: opts.data_labels[key].onchange() - changed += 1 + changed.append(key) try: opts.save(shared.config_filename) except RuntimeError: - return opts.dumpjson(), f'{changed} settings changed without save.' - return opts.dumpjson(), f'{changed} settings changed.' + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' def run_settings_single(value, key): if not opts.same_type(value, opts.data_labels[key].default): diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 8e0d41d5..02ab9643 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -140,13 +140,15 @@ def install_extension_from_url(dirname, url): shutil.rmtree(tmpdir, True) -def install_extension_from_index(url): +def install_extension_from_index(url, hide_tags): ext_table, message = install_extension_from_url(None, url) - return refresh_available_extensions_from_data(), ext_table, message + code, _ = refresh_available_extensions_from_data(hide_tags) + return code, ext_table, message -def refresh_available_extensions(url): + +def refresh_available_extensions(url, hide_tags): global available_extensions import urllib.request @@ -155,13 +157,25 @@ def refresh_available_extensions(url): available_extensions = json.loads(text) - return url, refresh_available_extensions_from_data(), '' + code, tags = refresh_available_extensions_from_data(hide_tags) + + return url, code, gr.CheckboxGroup.update(choices=tags), '' + + +def refresh_available_extensions_for_tags(hide_tags): + code, _ = refresh_available_extensions_from_data(hide_tags) + return code, '' -def refresh_available_extensions_from_data(): + +def refresh_available_extensions_from_data(hide_tags): extlist = available_extensions["extensions"] installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} + tags = available_extensions.get("tags", {}) + tags_to_hide = set(hide_tags) + hidden = 0 + code = f""" @@ -178,17 +192,24 @@ def refresh_available_extensions_from_data(): name = ext.get("name", "noname") url = ext.get("url", None) description = ext.get("description", "") + extension_tags = ext.get("tags", []) if url is None: continue + if len([x for x in extension_tags if x in tags_to_hide]) > 0: + hidden += 1 + continue + existing = installed_extension_urls.get(normalize_git_url(url), None) install_code = f"""""" + tags_text = ", ".join([f"{x}" for x in extension_tags]) + code += f""" - + @@ -199,7 +220,10 @@ def refresh_available_extensions_from_data():
{html.escape(name)}{html.escape(name)}
{tags_text}
{html.escape(description)} {install_code}
""" - return code + if hidden > 0: + code += f"

Extension hidden: {hidden}

" + + return code, list(tags) def create_ui(): @@ -238,21 +262,30 @@ def create_ui(): extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) + with gr.Row(): + hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"]) + install_result = gr.HTML() available_extensions_table = gr.HTML() refresh_available_extensions_button.click( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]), - inputs=[available_extensions_index], - outputs=[available_extensions_index, available_extensions_table, install_result], + fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), + inputs=[available_extensions_index, hide_tags], + outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result], ) install_extension_button.click( fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), - inputs=[extension_to_install], + inputs=[extension_to_install, hide_tags], outputs=[available_extensions_table, extensions_table, install_result], ) + hide_tags.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags], + outputs=[available_extensions_table, install_result] + ) + with gr.TabItem("Install from URL"): install_url = gr.Text(label="URL for extension's git repository") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") -- cgit v1.2.1 From 6e4de5b4422dfc0d45063b2c8c78b19f00321615 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 11:20:23 +0300 Subject: add load_with_extra function for modules to load checkpoints with extended whitelist --- modules/safe.py | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/safe.py b/modules/safe.py index 348a24fc..a9209e38 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -23,11 +23,18 @@ def encode(*args): class RestrictedUnpickler(pickle.Unpickler): + extra_handler = None + def persistent_load(self, saved_id): assert saved_id[0] == 'storage' return TypedStorage() def find_class(self, module, name): + if self.extra_handler is not None: + res = self.extra_handler(module, name) + if res is not None: + return res + if module == 'collections' and name == 'OrderedDict': return getattr(collections, name) if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: @@ -52,7 +59,7 @@ class RestrictedUnpickler(pickle.Unpickler): return set # Forbid everything else. - raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") + raise Exception(f"global '{module}/{name}' is forbidden") allowed_zip_names = ["archive/data.pkl", "archive/version"] @@ -69,7 +76,7 @@ def check_zip_filenames(filename, names): raise Exception(f"bad file inside {filename}: {name}") -def check_pt(filename): +def check_pt(filename, extra_handler): try: # new pytorch format is a zip file @@ -78,6 +85,7 @@ def check_pt(filename): with z.open('archive/data.pkl') as file: unpickler = RestrictedUnpickler(file) + unpickler.extra_handler = extra_handler unpickler.load() except zipfile.BadZipfile: @@ -85,16 +93,42 @@ def check_pt(filename): # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle with open(filename, "rb") as file: unpickler = RestrictedUnpickler(file) + unpickler.extra_handler = extra_handler for i in range(5): unpickler.load() def load(filename, *args, **kwargs): + return load_with_extra(filename, *args, **kwargs) + + +def load_with_extra(filename, extra_handler=None, *args, **kwargs): + """ + this functon is intended to be used by extensions that want to load models with + some extra classes in them that the usual unpickler would find suspicious. + + Use the extra_handler argument to specify a function that takes module and field name as text, + and returns that field's value: + + ```python + def extra(module, name): + if module == 'collections' and name == 'OrderedDict': + return collections.OrderedDict + + return None + + safe.load_with_extra('model.pt', extra_handler=extra) + ``` + + The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is + definitely unsafe. + """ + from modules import shared try: if not shared.cmd_opts.disable_safe_unpickle: - check_pt(filename) + check_pt(filename, extra_handler) except pickle.UnpicklingError: print(f"Error verifying pickled file from {filename}:", file=sys.stderr) -- cgit v1.2.1 From 55ca04095845b41bf66333b3b7343e3ea0babed1 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Sun, 6 Nov 2022 16:31:44 +0800 Subject: Resolve conflict --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 86d015af..db35983b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -422,14 +422,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): - opts.data[k] = v # we don't call onchange for simplicity which makes changing model impossible + setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not res = process_images_inner(p) finally: # restore opts to original state for k, v in stored_opts.items(): - opts.data[k] = v + setattr(opts, k, v) if k == 'sd_hypernetwork': shared.reload_hypernetworks() return res -- cgit v1.2.1 From 32c0eab89538ba3900bf499291720f80ae4b43e5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 14:39:41 +0300 Subject: load all settings in one call instead of one by one when the page loads --- modules/ui.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index c946ad59..34c31ef1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1141,7 +1141,7 @@ def create_ui(wrap_gradio_gpu_call): outputs=[html, generation_info, html2], ) - with gr.Blocks() as modelmerger_interface: + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") @@ -1161,7 +1161,7 @@ def create_ui(wrap_gradio_gpu_call): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - with gr.Blocks() as train_interface: + with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Row().style(equal_height=False): gr.HTML(value="

See wiki for detailed explanation.

") @@ -1420,15 +1420,14 @@ def create_ui(wrap_gradio_gpu_call): if info.refresh is not None: if is_quicksettings: - res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: with gr.Row(variant="compact"): - res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: - res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) - + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) return res @@ -1639,6 +1638,17 @@ def create_ui(wrap_gradio_gpu_call): outputs=[component, text_settings], ) + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [getattr(opts, key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + def modelmerger(*args): try: results = modules.extras.run_modelmerger(*args) -- cgit v1.2.1 From 67c8e11be74180be19341aebbd6a246c37a79fbb Mon Sep 17 00:00:00 2001 From: snowmeow2 Date: Mon, 7 Nov 2022 02:32:06 +0800 Subject: Adding DeepDanbooru to the interrogation API --- modules/api/api.py | 16 ++++++++++++++-- modules/api/models.py | 1 + 2 files changed, 15 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 688469ad..596a6616 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -15,6 +15,9 @@ from modules.sd_models import checkpoints_list from modules.realesrgan_model import get_realesrgan_models from typing import List +if shared.cmd_opts.deepdanbooru: + from modules.deepbooru import get_deepbooru_tags + def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) @@ -220,11 +223,20 @@ class Api: if image_b64 is None: raise HTTPException(status_code=404, detail="Image not found") - img = self.__base64_to_image(image_b64) + img = decode_base64_to_image(image_b64) + img = img.convert('RGB') # Override object param with self.queue_lock: - processed = shared.interrogator.interrogate(img) + if interrogatereq.model == "clip": + processed = shared.interrogator.interrogate(img) + elif interrogatereq.model == "deepdanbooru": + if shared.cmd_opts.deepdanbooru: + processed = get_deepbooru_tags(img) + else: + raise HTTPException(status_code=404, detail="Model not found. Add --deepdanbooru when launching for using the model.") + else: + raise HTTPException(status_code=404, detail="Model not found") return InterrogateResponse(caption=processed) diff --git a/modules/api/models.py b/modules/api/models.py index 34dbfa16..f9cd929e 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -170,6 +170,7 @@ class ProgressResponse(BaseModel): class InterrogateRequest(BaseModel): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") + model: str = Field(default="clip", title="Model", description="The interrogate model used.") class InterrogateResponse(BaseModel): caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") -- cgit v1.2.1 From cd6c55c1ab14fcab15329cde599cf79e8d555657 Mon Sep 17 00:00:00 2001 From: pepe10-gpu Date: Sun, 6 Nov 2022 17:05:51 -0800 Subject: 16xx card fix cudnn --- modules/devices.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 7511e1dc..858bf399 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -39,10 +39,13 @@ def torch_gc(): def enable_tf32(): if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True + errors.run(enable_tf32, "Enabling TF32") device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None -- cgit v1.2.1 From a258fd60dbe2d68325339405a2aa72816d06d2fd Mon Sep 17 00:00:00 2001 From: Keavon Chambers Date: Mon, 7 Nov 2022 00:13:58 -0800 Subject: Add CORS-allow policy launch argument using regex --- modules/shared.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index e8bacd3c..55de286d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -81,12 +81,13 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) -parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") -parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") +parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) -parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None) +parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None) +parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None) parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) -- cgit v1.2.1 From 9ed4a126bd6421f91bf4a9bdd348b6aef0a378c6 Mon Sep 17 00:00:00 2001 From: kavorite Date: Mon, 7 Nov 2022 19:58:49 -0500 Subject: add gradio-inpaint-tool; color-sketch --- modules/img2img.py | 19 +++++++++++++------ modules/shared.py | 1 + modules/ui.py | 11 ++++++++++- 3 files changed, 24 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index be9f3653..00c6f827 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,18 +59,25 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_inpaint = mode == 1 is_batch = mode == 2 if is_inpaint: # Drawn mask if mask_mode == 0: - image = init_img_with_mask['image'] - mask = init_img_with_mask['mask'] - alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') - image = image.convert('RGB') + image = init_img_with_mask + is_mask_sketch = isinstance(image, dict) + if is_mask_sketch: + # Sketch: mask iff. not transparent + image, mask = image["image"], image["mask"] + mask = np.array(mask)[..., -1] > 0 + else: + # Color-sketch: mask iff. painted over + orig = init_img_with_mask_orig or image + mask = np.any(np.array(image) != np.array(orig), axis=-1) + mask = Image.fromarray(mask.astype(np.uint8) * 255, "L") + image = image.convert("RGB") # Uploaded mask else: image = init_img_inpaint diff --git a/modules/shared.py b/modules/shared.py index d8e99f85..325e37d9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -71,6 +71,7 @@ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") +parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) diff --git a/modules/ui.py b/modules/ui.py index 2609857e..db323e9c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -840,8 +840,17 @@ def create_ui(wrap_gradio_gpu_call): init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480) with gr.TabItem('Inpaint', id='inpaint'): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) + init_img_with_mask_orig = gr.State(None) + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") -- cgit v1.2.1 From 29eff4a194d22f0f0e7a7a976d746a71a4193cf5 Mon Sep 17 00:00:00 2001 From: pepe10-gpu Date: Mon, 7 Nov 2022 18:06:48 -0800 Subject: terrible hack --- modules/devices.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 858bf399..4c63f465 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -39,8 +39,15 @@ def torch_gc(): def enable_tf32(): if torch.cuda.is_available(): - torch.backends.cudnn.benchmark = True - torch.backends.cudnn.enabled = True + #TODO: make this better; find a way to check if it is a turing card + turing = ["1630","1650","1660","Quadro RTX 3000","Quadro RTX 4000","Quadro RTX 4000","Quadro RTX 5000","Quadro RTX 5000","Quadro RTX 6000","Quadro RTX 6000","Quadro RTX 8000","Quadro RTX T400","Quadro RTX T400","Quadro RTX T600","Quadro RTX T1000","Quadro RTX T1000","2060","2070","2080","Titan RTX","Tesla T4","MX450","MX550"] + for devid in range(0,torch.cuda.device_count()): + for i in turing: + if i in torch.cuda.get_device_name(devid): + shd = True + if shd: + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True -- cgit v1.2.1 From c5334fc56b3d44976425da2e6d0a303ae96836a1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 8 Nov 2022 08:35:01 +0300 Subject: fix javascript duplication bug after pressing the restart UI button --- modules/ui.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 34c31ef1..67cf1d6a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1752,7 +1752,7 @@ def create_ui(wrap_gradio_gpu_call): return demo -def load_javascript(raw_response): +def reload_javascript(): with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: javascript = f'' @@ -1768,7 +1768,7 @@ def load_javascript(raw_response): javascript += f"\n" def template_response(*args, **kwargs): - res = raw_response(*args, **kwargs) + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) res.body = res.body.replace( b'', f'{javascript}'.encode("utf8")) res.init_headers() @@ -1777,4 +1777,5 @@ def load_javascript(raw_response): gradio.routes.templates.TemplateResponse = template_response -reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse -- cgit v1.2.1 From 8011be33c36eb7aa9e9498fc714614034e07f67a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 8 Nov 2022 08:37:05 +0300 Subject: move functions out of main body for image preprocessing for easier hijacking --- modules/textual_inversion/preprocess.py | 162 ++++++++++++++++++-------------- 1 file changed, 93 insertions(+), 69 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index e13b1894..488aa5b5 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -35,6 +35,84 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce deepbooru.release_process() +def listfiles(dirname): + return os.listdir(dirname) + + +class PreprocessParams: + src = None + dstdir = None + subindex = 0 + flip = False + process_caption = False + process_caption_deepbooru = False + preprocess_txt_action = None + + +def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None): + caption = "" + + if params.process_caption: + caption += shared.interrogator.generate_caption(image) + + if params.process_caption_deepbooru: + if len(caption) > 0: + caption += ", " + caption += deepbooru.get_tags_from_process(image) + + filename_part = params.src + filename_part = os.path.splitext(filename_part)[0] + filename_part = os.path.basename(filename_part) + + basename = f"{index:05}-{params.subindex}-{filename_part}" + image.save(os.path.join(params.dstdir, f"{basename}.png")) + + if params.preprocess_txt_action == 'prepend' and existing_caption: + caption = existing_caption + ' ' + caption + elif params.preprocess_txt_action == 'append' and existing_caption: + caption = caption + ' ' + existing_caption + elif params.preprocess_txt_action == 'copy' and existing_caption: + caption = existing_caption + + caption = caption.strip() + + if len(caption) > 0: + with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: + file.write(caption) + + params.subindex += 1 + + +def save_pic(image, index, params, existing_caption=None): + save_pic_with_caption(image, index, params, existing_caption=existing_caption) + + if params.flip: + save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption) + + +def split_pic(image, inverse_xy, width, height, overlap_ratio): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, y, to_w, y + to_h)) + yield splitted + def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): width = process_width @@ -48,82 +126,28 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre os.makedirs(dst, exist_ok=True) - files = os.listdir(src) + files = listfiles(src) shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) - def save_pic_with_caption(image, index, existing_caption=None): - caption = "" - - if process_caption: - caption += shared.interrogator.generate_caption(image) - - if process_caption_deepbooru: - if len(caption) > 0: - caption += ", " - caption += deepbooru.get_tags_from_process(image) - - filename_part = filename - filename_part = os.path.splitext(filename_part)[0] - filename_part = os.path.basename(filename_part) - - basename = f"{index:05}-{subindex[0]}-{filename_part}" - image.save(os.path.join(dst, f"{basename}.png")) - - if preprocess_txt_action == 'prepend' and existing_caption: - caption = existing_caption + ' ' + caption - elif preprocess_txt_action == 'append' and existing_caption: - caption = caption + ' ' + existing_caption - elif preprocess_txt_action == 'copy' and existing_caption: - caption = existing_caption - - caption = caption.strip() - - if len(caption) > 0: - with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file: - file.write(caption) - - subindex[0] += 1 - - def save_pic(image, index, existing_caption=None): - save_pic_with_caption(image, index, existing_caption=existing_caption) - - if process_flip: - save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption) - - def split_pic(image, inverse_xy): - if inverse_xy: - from_w, from_h = image.height, image.width - to_w, to_h = height, width - else: - from_w, from_h = image.width, image.height - to_w, to_h = width, height - h = from_h * to_w // from_w - if inverse_xy: - image = image.resize((h, to_w)) - else: - image = image.resize((to_w, h)) - - split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) - y_step = (h - to_h) / (split_count - 1) - for i in range(split_count): - y = int(y_step * i) - if inverse_xy: - splitted = image.crop((y, 0, y + to_h, to_w)) - else: - splitted = image.crop((0, y, to_w, y + to_h)) - yield splitted - + params = PreprocessParams() + params.dstdir = dst + params.flip = process_flip + params.process_caption = process_caption + params.process_caption_deepbooru = process_caption_deepbooru + params.preprocess_txt_action = preprocess_txt_action for index, imagefile in enumerate(tqdm.tqdm(files)): - subindex = [0] + params.subindex = 0 filename = os.path.join(src, imagefile) try: img = Image.open(filename).convert("RGB") except Exception: continue + params.src = filename + existing_caption = None existing_caption_filename = os.path.splitext(filename)[0] + '.txt' if os.path.exists(existing_caption_filename): @@ -143,8 +167,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre process_default_resize = True if process_split and ratio < 1.0 and ratio <= split_threshold: - for splitted in split_pic(img, inverse_xy): - save_pic(splitted, index, existing_caption=existing_caption) + for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio): + save_pic(splitted, index, params, existing_caption=existing_caption) process_default_resize = False if process_focal_crop and img.height != img.width: @@ -165,11 +189,11 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre dnn_model_path = dnn_model_path, ) for focal in autocrop.crop_image(img, autocrop_settings): - save_pic(focal, index, existing_caption=existing_caption) + save_pic(focal, index, params, existing_caption=existing_caption) process_default_resize = False if process_default_resize: img = images.resize_image(1, img, width, height) - save_pic(img, index, existing_caption=existing_caption) + save_pic(img, index, params, existing_caption=existing_caption) - shared.state.nextjob() \ No newline at end of file + shared.state.nextjob() -- cgit v1.2.1 From 1610b3258458025025e9c4faae57d290e4519745 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 8 Nov 2022 08:38:10 +0300 Subject: add callback for creating a tab in train UI --- modules/script_callbacks.py | 27 +++++++++++++++++++++++++-- modules/ui.py | 4 ++++ 2 files changed, 29 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 74dfb880..f19e164c 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -7,6 +7,7 @@ from typing import Optional from fastapi import FastAPI from gradio import Blocks + def report_exception(c, job): print(f"Error executing callback {job} for {c.script}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) @@ -45,15 +46,21 @@ class CFGDenoiserParams: """Total number of sampling steps planned""" +class UiTrainTabParams: + def __init__(self, txt2img_preview_params): + self.txt2img_preview_params = txt2img_preview_params + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callback_map = dict( callbacks_app_started=[], callbacks_model_loaded=[], callbacks_ui_tabs=[], + callbacks_ui_train_tabs=[], callbacks_ui_settings=[], callbacks_before_image_saved=[], callbacks_image_saved=[], - callbacks_cfg_denoiser=[] + callbacks_cfg_denoiser=[], ) @@ -61,6 +68,7 @@ def clear_callbacks(): for callback_list in callback_map.values(): callback_list.clear() + def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callback_map['callbacks_app_started']: try: @@ -79,7 +87,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - + for c in callback_map['callbacks_ui_tabs']: try: res += c.callback() or [] @@ -89,6 +97,14 @@ def ui_tabs_callback(): return res +def ui_train_tabs_callback(params: UiTrainTabParams): + for c in callback_map['callbacks_ui_train_tabs']: + try: + c.callback(params) + except Exception: + report_exception(c, 'callbacks_ui_train_tabs') + + def ui_settings_callback(): for c in callback_map['callbacks_ui_settings']: try: @@ -169,6 +185,13 @@ def on_ui_tabs(callback): add_callback(callback_map['callbacks_ui_tabs'], callback) +def on_ui_train_tabs(callback): + """register a function to be called when the UI is creating new tabs for the train tab. + Create your new tabs with gr.Tab. + """ + add_callback(callback_map['callbacks_ui_train_tabs'], callback) + + def on_ui_settings(callback): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ diff --git a/modules/ui.py b/modules/ui.py index 67cf1d6a..7ea1177f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1270,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call): train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') train_embedding = gr.Button(value="Train Embedding", variant='primary') + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + with gr.Column(): progressbar = gr.HTML(elem_id="ti_progressbar") ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) -- cgit v1.2.1 From c34542a48376e4972de955aab00ffc8359f7d792 Mon Sep 17 00:00:00 2001 From: kavorite Date: Tue, 8 Nov 2022 03:25:59 -0500 Subject: add new color-sketch state to img2img invocation --- modules/ui.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index db323e9c..29954f2a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -941,6 +941,7 @@ def create_ui(wrap_gradio_gpu_call): img2img_prompt_style2, init_img, init_img_with_mask, + init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, -- cgit v1.2.1 From cfcadeae9a61e1aff32960864f90299412c86d5c Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 8 Nov 2022 10:03:56 -0600 Subject: Add option to preload extensions By creating a file called "preload.py" in an extension folder and declaring a preload(parser) method, we can add extra command-line args for an extension. --- modules/extensions.py | 23 ++++++++++++++++++++++- modules/shared.py | 5 ++++- 2 files changed, 26 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/extensions.py b/modules/extensions.py index 8e0977fd..544f3580 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,12 +1,12 @@ import os import sys import traceback +from importlib.machinery import SourceFileLoader import git from modules import paths, shared - extensions = [] extensions_dir = os.path.join(paths.script_path, "extensions") @@ -84,3 +84,24 @@ def list_extensions(): extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) extensions.append(extension) + + +def preload_extensions(parser): + if not os.path.isdir(extensions_dir): + return + + for dirname in sorted(os.listdir(extensions_dir)): + path = os.path.join(extensions_dir, dirname) + if not os.path.isdir(path): + continue + for file in os.listdir(path): + if "preload.py" in file: + full_file = os.path.join(path, file) + print(f"Got preload file: {full_file}") + + try: + ext = SourceFileLoader("preload", full_file).load_module() + parser = ext.preload(parser) + except Exception as e: + print(f"Exception preloading script: {e}") + return parser \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index e8bacd3c..222ad4fb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -15,7 +15,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization, sd_vae +from modules import sd_samplers, sd_models, localization, sd_vae, extensions from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -91,7 +91,10 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) +extensions.preload_extensions(parser) + cmd_opts = parser.parse_args() + restricted_opts = { "samples_filename_pattern", "directories_filename_pattern", -- cgit v1.2.1 From 62e9fec3df8518da3a2c35fa090bb54946c856b2 Mon Sep 17 00:00:00 2001 From: pepe10-gpu Date: Tue, 8 Nov 2022 15:19:09 -0800 Subject: actual better fix thanks C43H66N12O12S2 --- modules/devices.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 4c63f465..058a5e00 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -39,12 +39,9 @@ def torch_gc(): def enable_tf32(): if torch.cuda.is_available(): - #TODO: make this better; find a way to check if it is a turing card - turing = ["1630","1650","1660","Quadro RTX 3000","Quadro RTX 4000","Quadro RTX 4000","Quadro RTX 5000","Quadro RTX 5000","Quadro RTX 6000","Quadro RTX 6000","Quadro RTX 8000","Quadro RTX T400","Quadro RTX T400","Quadro RTX T600","Quadro RTX T1000","Quadro RTX T1000","2060","2070","2080","Titan RTX","Tesla T4","MX450","MX550"] for devid in range(0,torch.cuda.device_count()): - for i in turing: - if i in torch.cuda.get_device_name(devid): - shd = True + if torch.cuda.get_device_capability(devid) == (7, 5): + shd = True if shd: torch.backends.cudnn.benchmark = True torch.backends.cudnn.enabled = True -- cgit v1.2.1 From 59bb1d36ea69db449cfe23be4988ab4f6711bf4b Mon Sep 17 00:00:00 2001 From: kavorite Date: Tue, 8 Nov 2022 22:06:29 -0500 Subject: blur mask with color-sketch + add paint transparency slider --- modules/img2img.py | 21 +++++++++++++-------- modules/ui.py | 3 +++ 2 files changed, 16 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 00c6f827..644297da 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -4,7 +4,7 @@ import sys import traceback import numpy as np -from PIL import Image, ImageOps, ImageChops +from PIL import Image, ImageOps, ImageFilter, ImageEnhance from modules import devices from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images @@ -40,7 +40,7 @@ def process_batch(p, input_dir, output_dir, args): img = Image.open(image) # Use the EXIF orientation of photos taken by smartphones. - img = ImageOps.exif_transpose(img) + img = ImageOps.exif_transpose(img) p.init_images = [img] * p.batch_size proc = modules.scripts.scripts_img2img.run(p, *args) @@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_inpaint = mode == 1 is_batch = mode == 2 @@ -68,15 +68,20 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro if mask_mode == 0: image = init_img_with_mask is_mask_sketch = isinstance(image, dict) - if is_mask_sketch: + is_mask_paint = not is_mask_sketch + if is_mask_sketch: # Sketch: mask iff. not transparent image, mask = image["image"], image["mask"] - mask = np.array(mask)[..., -1] > 0 + pred = np.array(mask)[..., -1] > 0 else: # Color-sketch: mask iff. painted over orig = init_img_with_mask_orig or image - mask = np.any(np.array(image) != np.array(orig), axis=-1) - mask = Image.fromarray(mask.astype(np.uint8) * 255, "L") + pred = np.any(np.array(image) != np.array(orig), axis=-1) + mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") + if is_mask_paint: + mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) + blur = ImageFilter.GaussianBlur(mask_blur) + image = Image.composite(image.filter(blur), orig, mask.filter(blur)) image = image.convert("RGB") # Uploaded mask else: @@ -89,7 +94,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro # Use the EXIF orientation of photos taken by smartphones. if image is not None: - image = ImageOps.exif_transpose(image) + image = ImageOps.exif_transpose(image) assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' diff --git a/modules/ui.py b/modules/ui.py index 29954f2a..16982abf 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -854,6 +854,8 @@ def create_ui(wrap_gradio_gpu_call): init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + show_mask_alpha = cmd_opts.gradio_inpaint_tool == "color-sketch" + mask_alpha = gr.Slider(label="Mask transparency", interactive=show_mask_alpha, visible=show_mask_alpha) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) with gr.Row(): @@ -948,6 +950,7 @@ def create_ui(wrap_gradio_gpu_call): steps, sampler_index, mask_blur, + mask_alpha, inpainting_fill, restore_faces, tiling, -- cgit v1.2.1 From 3b51d239ac9201228c6032fc109111e347e8e6b0 Mon Sep 17 00:00:00 2001 From: cluder <1590330+cluder@users.noreply.github.com> Date: Wed, 9 Nov 2022 04:54:21 +0100 Subject: - do not use ckpt cache, if disabled - cache model after is has been loaded from file --- modules/sd_models.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 34c57bfa..720c2a96 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -163,13 +163,21 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"): + cache_enabled = shared.opts.sd_checkpoint_cache > 0 + + if cache_enabled: sd_vae.restore_base_vae(model) - checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy() vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - if checkpoint_info not in checkpoints_loaded: + if cache_enabled and checkpoint_info in checkpoints_loaded: + # use checkpoint cache + vae_name = sd_vae.get_filename(vae_file) if vae_file else None + vae_message = f" with {vae_name} VAE" if vae_name else "" + print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") + model.load_state_dict(checkpoints_loaded[checkpoint_info]) + else: + # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -180,6 +188,10 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): del pl_sd model.load_state_dict(sd, strict=False) del sd + + if cache_enabled: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) @@ -199,13 +211,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.first_stage_model.to(devices.dtype_vae) - else: - vae_name = sd_vae.get_filename(vae_file) if vae_file else None - vae_message = f" with {vae_name} VAE" if vae_name else "" - print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") - model.load_state_dict(checkpoints_loaded[checkpoint_info]) - - if shared.opts.sd_checkpoint_cache > 0: + # clean up cache if limit is reached + if cache_enabled: while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU -- cgit v1.2.1 From eebf49592ad2c0933e58b06a098b92e48d47e4fe Mon Sep 17 00:00:00 2001 From: cluder <1590330+cluder@users.noreply.github.com> Date: Wed, 9 Nov 2022 07:17:09 +0100 Subject: restore #4035 behavior - if checkpoint cache is set to 1, keep 2 models in cache (current +1 more) --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 720c2a96..80addf03 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -213,7 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # clean up cache if limit is reached if cache_enabled: - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model checkpoints_loaded.popitem(last=False) # LRU model.sd_model_hash = sd_model_hash -- cgit v1.2.1 From 81f2575df91a50e4aa9ca816e02e3f77342eedc8 Mon Sep 17 00:00:00 2001 From: Liam Date: Wed, 9 Nov 2022 15:24:31 -0500 Subject: updating the displayed generation info when user clicks images in the gallery. feature request 4415 --- modules/ui.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 7ea1177f..756499d1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -566,6 +566,17 @@ def apply_setting(key, value): return value +def update_generation_info(args): + generation_info, html_info, img_index = args + try: + generation_info = json.loads(generation_info) + return plaintext_to_html(generation_info["infotexts"][img_index]) + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info + + def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def refresh(): refresh_method() @@ -638,6 +649,15 @@ Requested path was: {f} with gr.Group(): html_info = gr.HTML() generation_info = gr.Textbox(visible=False) + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="(x, y) => [x, y, selected_gallery_index()]", + inputs=[generation_info, html_info], + outputs=[html_info], + preprocess=False + ) save.click( fn=wrap_gradio_call(save_files), -- cgit v1.2.1 From 893191cab24cc3511135495d6d2c8d81f5ec63a3 Mon Sep 17 00:00:00 2001 From: Tong Zeng Date: Thu, 10 Nov 2022 10:34:03 +0800 Subject: fix a bug in list_files_with_name --- modules/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 637b2329..22d8908b 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -140,7 +140,7 @@ def list_files_with_name(filename): continue path = os.path.join(dirpath, filename) - if os.path.isfile(filename): + if os.path.isfile(path): res.append(path) return res -- cgit v1.2.1 From 2505f39e28177452a92426f3b60d8edbe6ed1b14 Mon Sep 17 00:00:00 2001 From: JingShing Date: Thu, 10 Nov 2022 20:39:20 +0800 Subject: Add username and password in ngrok. --- modules/ngrok.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ngrok.py b/modules/ngrok.py index 5c5f349a..e506accb 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -1,14 +1,22 @@ from pyngrok import ngrok, conf, exception - def connect(token, port, region): if token == None: token = 'None' + else: + if ':' in token: + # token = authtoken:username:password + account = token.split(':')[1] + ':' + token.split(':')[-1] + token = token.split(':')[0] + config = conf.PyngrokConfig( auth_token=token, region=region ) try: - public_url = ngrok.connect(port, pyngrok_config=config).public_url + if account: + public_url = ngrok.connect(port, pyngrok_config=config, auth=account).public_url + else: + public_url = ngrok.connect(port, pyngrok_config=config).public_url except exception.PyngrokNgrokError: print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') -- cgit v1.2.1 From 1a01191e27545e9dae5255d59c920b6da5b236f4 Mon Sep 17 00:00:00 2001 From: JingShing Date: Thu, 10 Nov 2022 20:42:41 +0800 Subject: Add username and password in ngrok. --- modules/ngrok.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ngrok.py b/modules/ngrok.py index e506accb..10d2179f 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -1,6 +1,7 @@ from pyngrok import ngrok, conf, exception def connect(token, port, region): + account = None if token == None: token = 'None' else: @@ -13,10 +14,10 @@ def connect(token, port, region): auth_token=token, region=region ) try: - if account: - public_url = ngrok.connect(port, pyngrok_config=config, auth=account).public_url - else: + if account == None: public_url = ngrok.connect(port, pyngrok_config=config).public_url + else: + public_url = ngrok.connect(port, pyngrok_config=config, auth=account).public_url except exception.PyngrokNgrokError: print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') -- cgit v1.2.1 From b98740129c435f04a060369bd071fc4bafe021f5 Mon Sep 17 00:00:00 2001 From: Liam Date: Thu, 10 Nov 2022 13:07:41 -0500 Subject: added event listener for the image gallery modal; moved js to separate file --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 756499d1..5dce7f3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -570,6 +570,8 @@ def update_generation_info(args): generation_info, html_info, img_index = args try: generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info return plaintext_to_html(generation_info["infotexts"][img_index]) except Exception: pass -- cgit v1.2.1 From 6f8a807fe4eb41f6eb355c80fe96cd60b8e8a5a9 Mon Sep 17 00:00:00 2001 From: KyuSeok Jung Date: Fri, 11 Nov 2022 09:22:49 +0900 Subject: Update shared.py --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 89f4d5ee..82da5ce0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -321,7 +321,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), - "shuffle_tags": OptionInfo(False, "Shuffleing tags by "," when create texts."), + "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), -- cgit v1.2.1 From 13a2f1dca32980339e1fb4d1995cde428db798c5 Mon Sep 17 00:00:00 2001 From: KyuSeok Jung Date: Fri, 11 Nov 2022 10:29:55 +0900 Subject: adding tag drop out option --- modules/textual_inversion/dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index df278dc2..a95c7835 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -98,12 +98,12 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) + tags = filename_text.split(',') + if shared.opt.tag_drop_out != 0: + tags = [t for t in tags if random.random() > shared.opt.tag_drop_out] if shared.opts.shuffle_tags: - tags = filename_text.split(',') random.shuffle(tags) - text = text.replace("[filewords]", ','.join(tags)) - else: - text = text.replace("[filewords]", filename_text) + text = text.replace("[filewords]", ','.join(tags)) return text def __len__(self): -- cgit v1.2.1 From 0959907f87314cbee8a80036ec8ae24c65888f7f Mon Sep 17 00:00:00 2001 From: KyuSeok Jung Date: Fri, 11 Nov 2022 10:31:14 +0900 Subject: adding tag dropout option --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 82da5ce0..f2ea3baa 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -322,6 +322,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."), + "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), -- cgit v1.2.1 From b19af67d29356f97fea5cccfdfa12583f605243f Mon Sep 17 00:00:00 2001 From: KyuSeok Jung Date: Fri, 11 Nov 2022 10:54:19 +0900 Subject: Update dataset.py --- modules/textual_inversion/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index a95c7835..e2cb8428 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -99,7 +99,7 @@ class PersonalizedBase(Dataset): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) tags = filename_text.split(',') - if shared.opt.tag_drop_out != 0: + if shared.opts.tag_drop_out != 0: tags = [t for t in tags if random.random() > shared.opt.tag_drop_out] if shared.opts.shuffle_tags: random.shuffle(tags) -- cgit v1.2.1 From c556d34523e8764bd66bf6a7bf97d06add420020 Mon Sep 17 00:00:00 2001 From: NoCrypt <57245077+NoCrypt@users.noreply.github.com> Date: Fri, 11 Nov 2022 08:54:51 +0700 Subject: Forcing HTTPS instead of HTTP for ngrok For security reason. --- modules/ngrok.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ngrok.py b/modules/ngrok.py index 5c5f349a..25c53af8 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -8,7 +8,7 @@ def connect(token, port, region): auth_token=token, region=region ) try: - public_url = ngrok.connect(port, pyngrok_config=config).public_url + public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url except exception.PyngrokNgrokError: print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') -- cgit v1.2.1 From a1e271207dfc3e89b1286ba41d96b459f210c4b2 Mon Sep 17 00:00:00 2001 From: KyuSeok Jung Date: Fri, 11 Nov 2022 10:56:53 +0900 Subject: Update dataset.py --- modules/textual_inversion/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index e2cb8428..eb75c376 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -100,7 +100,7 @@ class PersonalizedBase(Dataset): text = text.replace("[name]", self.placeholder_token) tags = filename_text.split(',') if shared.opts.tag_drop_out != 0: - tags = [t for t in tags if random.random() > shared.opt.tag_drop_out] + tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] if shared.opts.shuffle_tags: random.shuffle(tags) text = text.replace("[filewords]", ','.join(tags)) -- cgit v1.2.1 From 7ba3923d5b494b7756d0b12f33acb3716d830b9a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 11 Nov 2022 18:20:18 +0300 Subject: move DDIM/PLMS fix for OSX out of the file with inpainting code. --- modules/sd_hijack.py | 23 +++++++++++++++++++++++ modules/sd_hijack_inpainting.py | 18 +----------------- 2 files changed, 24 insertions(+), 17 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bc49d235..75b2d22d 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -14,6 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity @@ -406,3 +408,24 @@ def add_circular_option_to_conv_2d(): model_hijack = StableDiffusionModelHijack() + + +def register_buffer(self, name, attr): + """ + Fix register buffer bug for Mac OS. + """ + + if type(attr) == torch.Tensor: + if attr.device != devices.device: + + # would this not break cuda when torch adds has_mps() to main version? + if getattr(torch, 'has_mps', False): + attr = attr.to(device="mps", dtype=torch.float32) + else: + attr = attr.to(devices.device) + + setattr(self, name, attr) + + +ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer +ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 202b42cf..46714a4f 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -1,5 +1,4 @@ import torch -import modules.devices as devices from einops import repeat from omegaconf import ListConfig @@ -315,20 +314,6 @@ class LatentInpaintDiffusion(LatentDiffusion): self.masked_image_key = masked_image_key assert self.masked_image_key in concat_keys self.concat_keys = concat_keys - - -# ================================================================================================= -# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here -# ================================================================================================= -def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - optimal_type = devices.get_optimal_device() - if attr.device != optimal_type: - if getattr(torch, 'has_mps', False): - attr = attr.to(device="mps", dtype=torch.float32) - else: - attr = attr.to(optimal_type) - setattr(self, name, attr) def should_hijack_inpainting(checkpoint_info): @@ -341,8 +326,7 @@ def do_inpainting_hijack(): ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim - ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms - ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer + -- cgit v1.2.1 From 76ab31e18898d4c2aacb9725cfbe25b230bff974 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 12 Nov 2022 11:02:40 +0800 Subject: Fix wrong mps selection below MasOS 12.3 --- modules/devices.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 7511e1dc..9a3d29d7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,8 +3,15 @@ import contextlib import torch from modules import errors -# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility -has_mps = getattr(torch, 'has_mps', False) +# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. +# check `getattr` and try it for compatibility +def has_mps() -> bool: + if getattr(torch, 'has_mps', False): return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False cpu = torch.device("cpu") @@ -25,7 +32,7 @@ def get_optimal_device(): else: return torch.device("cuda") - if has_mps: + if has_mps(): return torch.device("mps") return cpu -- cgit v1.2.1 From 1130d5df669911a5c67696be90bccca3ecf5f487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 12 Nov 2022 11:09:28 +0800 Subject: Update devices.py --- modules/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 9a3d29d7..bd3e4ffb 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -6,7 +6,7 @@ from modules import errors # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. # check `getattr` and try it for compatibility def has_mps() -> bool: - if getattr(torch, 'has_mps', False): return False + if not getattr(torch, 'has_mps', False): return False try: torch.zeros(1).to(torch.device("mps")) return True -- cgit v1.2.1 From c62d17aee36b5f4ca24f9cfa7bf6d7aca0c923f8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 12 Nov 2022 10:00:22 +0300 Subject: use the new devices.has_mps() function in register_buffer for DDIM/PLMS fix for OSX --- modules/sd_hijack.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 75b2d22d..97979d05 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -418,8 +418,7 @@ def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != devices.device: - # would this not break cuda when torch adds has_mps() to main version? - if getattr(torch, 'has_mps', False): + if devices.has_mps(): attr = attr.to(device="mps", dtype=torch.float32) else: attr = attr.to(devices.device) -- cgit v1.2.1 From 0ab0a50f9ae14bd7ce7ec518323ebd31c7971155 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 12 Nov 2022 10:00:49 +0300 Subject: change formatting to match the main program in devices.py --- modules/devices.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index bd3e4ffb..67165bf6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,23 +3,27 @@ import contextlib import torch from modules import errors + # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. # check `getattr` and try it for compatibility def has_mps() -> bool: - if not getattr(torch, 'has_mps', False): return False + if not getattr(torch, 'has_mps', False): + return False try: torch.zeros(1).to(torch.device("mps")) return True except Exception: return False -cpu = torch.device("cpu") def extract_device_id(args, name): for x in range(len(args)): - if name in args[x]: return args[x+1] + if name in args[x]: + return args[x + 1] + return None + def get_optimal_device(): if torch.cuda.is_available(): from modules import shared @@ -52,10 +56,12 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") +cpu = torch.device("cpu") device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 + def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. if device.type == 'mps': @@ -89,6 +95,11 @@ def autocast(disable=False): return torch.autocast("cuda") + # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor -def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device) +def mps_contiguous(input_tensor, device): + return input_tensor.contiguous() if device.type == 'mps' else input_tensor + + +def mps_contiguous_to(input_tensor, device): + return mps_contiguous(input_tensor, device).to(device) -- cgit v1.2.1 From a1a376331c9ecbbee77b86daeaba44587cc56557 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 12 Nov 2022 10:56:06 +0300 Subject: make existing script loading and new preload code use same code for loading modules limit extension preload scripts to just one file named preload.py --- modules/extensions.py | 21 --------------------- modules/script_loading.py | 34 ++++++++++++++++++++++++++++++++++ modules/scripts.py | 46 +++++++++++++++++----------------------------- modules/shared.py | 5 ++--- 4 files changed, 53 insertions(+), 53 deletions(-) create mode 100644 modules/script_loading.py (limited to 'modules') diff --git a/modules/extensions.py b/modules/extensions.py index 544f3580..94ce479a 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,7 +1,6 @@ import os import sys import traceback -from importlib.machinery import SourceFileLoader import git @@ -85,23 +84,3 @@ def list_extensions(): extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) extensions.append(extension) - -def preload_extensions(parser): - if not os.path.isdir(extensions_dir): - return - - for dirname in sorted(os.listdir(extensions_dir)): - path = os.path.join(extensions_dir, dirname) - if not os.path.isdir(path): - continue - for file in os.listdir(path): - if "preload.py" in file: - full_file = os.path.join(path, file) - print(f"Got preload file: {full_file}") - - try: - ext = SourceFileLoader("preload", full_file).load_module() - parser = ext.preload(parser) - except Exception as e: - print(f"Exception preloading script: {e}") - return parser \ No newline at end of file diff --git a/modules/script_loading.py b/modules/script_loading.py new file mode 100644 index 00000000..f93f0951 --- /dev/null +++ b/modules/script_loading.py @@ -0,0 +1,34 @@ +import os +import sys +import traceback +from types import ModuleType + + +def load_module(path): + with open(path, "r", encoding="utf8") as file: + text = file.read() + + compiled = compile(text, path, 'exec') + module = ModuleType(os.path.basename(path)) + exec(compiled, module.__dict__) + + return module + + +def preload_extensions(extensions_dir, parser): + if not os.path.isdir(extensions_dir): + return + + for dirname in sorted(os.listdir(extensions_dir)): + preload_script = os.path.join(extensions_dir, dirname, "preload.py") + if not os.path.isfile(preload_script): + continue + + try: + module = load_module(preload_script) + if hasattr(module, 'preload'): + module.preload(parser) + + except Exception: + print(f"Error running preload() for {preload_script}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) diff --git a/modules/scripts.py b/modules/scripts.py index 22d8908b..986b1914 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -6,7 +6,7 @@ from collections import namedtuple import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks, extensions +from modules import shared, paths, script_callbacks, extensions, script_loading AlwaysVisible = object() @@ -161,13 +161,7 @@ def load_scripts(): sys.path = [scriptfile.basedir] + sys.path current_basedir = scriptfile.basedir - with open(scriptfile.path, "r", encoding="utf8") as file: - text = file.read() - - from types import ModuleType - compiled = compile(text, scriptfile.path, 'exec') - module = ModuleType(scriptfile.filename) - exec(compiled, module.__dict__) + module = script_loading.load_module(scriptfile.path) for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): @@ -328,27 +322,21 @@ class ScriptRunner: def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): - with open(script.filename, "r", encoding="utf8") as file: - args_from = script.args_from - args_to = script.args_to - filename = script.filename - text = file.read() - - from types import ModuleType - - module = cache.get(filename, None) - if module is None: - compiled = compile(text, filename, 'exec') - module = ModuleType(script.filename) - exec(compiled, module.__dict__) - cache[filename] = module - - for key, script_class in module.__dict__.items(): - if type(script_class) == type and issubclass(script_class, Script): - self.scripts[si] = script_class() - self.scripts[si].filename = filename - self.scripts[si].args_from = args_from - self.scripts[si].args_to = args_to + args_from = script.args_from + args_to = script.args_to + filename = script.filename + + module = cache.get(filename, None) + if module is None: + module = script_loading.load_module(script.filename) + cache[filename] = module + + for key, script_class in module.__dict__.items(): + if type(script_class) == type and issubclass(script_class, Script): + self.scripts[si] = script_class() + self.scripts[si].filename = filename + self.scripts[si].args_from = args_from + self.scripts[si].args_to = args_to scripts_txt2img = ScriptRunner() diff --git a/modules/shared.py b/modules/shared.py index 17132e42..6936cbe0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -3,7 +3,6 @@ import datetime import json import os import sys -from collections import OrderedDict import time import gradio as gr @@ -15,7 +14,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization, sd_vae, extensions +from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) -extensions.preload_extensions(parser) +script_loading.preload_extensions(extensions.extensions_dir, parser) cmd_opts = parser.parse_args() -- cgit v1.2.1 From 98947d173e3f1667eba29c904f681047dea9de90 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 12 Nov 2022 11:11:47 +0300 Subject: run installers for newly installed extensions --- modules/ui_extensions.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 02ab9643..6671cb60 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -134,6 +134,9 @@ def install_extension_from_url(dirname, url): os.rename(tmpdir, target_dir) + import launch + launch.run_extension_installer(target_dir) + extensions.list_extensions() return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")] finally: -- cgit v1.2.1 From f4a488f585c09b420dc05199240e68f8fb74337f Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 7 Nov 2022 20:12:31 -0500 Subject: Set device for facelib/facexlib and gfpgan * FaceXLib/FaceLib doesn't pass the device argument to RetinaFace but instead chooses one itself and sets it to a global - in order to use a device other than its internally chosen default it is necessary to manually replace the default value * The GFPGAN constructor needs the device argument to work with MPS or a CUDA device ID that differs from the default --- modules/codeformer_model.py | 3 +++ modules/gfpgan_model.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index e6d9fa4f..ab40d842 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -36,6 +36,7 @@ def setup_model(dirname): from basicsr.utils.download_util import load_file_from_url from basicsr.utils import imwrite, img2tensor, tensor2img from facelib.utils.face_restoration_helper import FaceRestoreHelper + from facelib.detection.retinaface import retinaface from modules.shared import cmd_opts net_class = CodeFormer @@ -65,6 +66,8 @@ def setup_model(dirname): net.load_state_dict(checkpoint) net.eval() + if hasattr(retinaface, 'device'): + retinaface.device = devices.device_codeformer face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) self.net = net diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index a9452dce..1e2dbc32 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -36,7 +36,9 @@ def gfpgann(): else: print("Unable to load gfpgan model!") return None - model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) + if hasattr(facexlib.detection.retinaface, 'device'): + facexlib.detection.retinaface.device = devices.device_gfpgan + model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) loaded_gfpgan_model = model return model -- cgit v1.2.1 From d671d1d45dfab61292ed788fd7778a33a82212ee Mon Sep 17 00:00:00 2001 From: Mrau Hu Date: Sat, 12 Nov 2022 21:44:42 +0300 Subject: Fix: `error: Your local changes to the following files would be overwritten by merge` when run `pull()` method, because WSL2 Docker set 755 file permissions instead of 644, this results to the error. Updated `Extension` class: replaced `pull()` with `fetch_and_reset_hard()` method. Updated `apply_and_restart()` function: replaced `ext.pull()` with `ext.fetch_and_reset_hard()` function. --- modules/extensions.py | 7 +++++-- modules/ui_extensions.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/extensions.py b/modules/extensions.py index 94ce479a..db9c4200 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -65,9 +65,12 @@ class Extension: self.can_update = False self.status = "latest" - def pull(self): + def fetch_and_reset_hard(self): repo = git.Repo(self.path) - repo.remotes.origin.pull() + # Fix: `error: Your local changes to the following files would be overwritten by merge`, + # because WSL2 Docker set 755 file permissions instead of 644, this results to the error. + repo.git.fetch('--all') + repo.git.reset('--hard', 'origin') def list_extensions(): diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 6671cb60..030f011e 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -36,9 +36,9 @@ def apply_and_restart(disable_list, update_list): continue try: - ext.pull() + ext.fetch_and_reset_hard() except Exception: - print(f"Error pulling updates for {ext.name}:", file=sys.stderr) + print(f"Error getting updates for {ext.name}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) shared.opts.disabled_extensions = disabled -- cgit v1.2.1 From d20dbe47e06de7f6c0e65242a04c9bb1410ef7cb Mon Sep 17 00:00:00 2001 From: Xu Cuijie <975114697@qq.com> Date: Sun, 13 Nov 2022 10:31:03 +0800 Subject: fix the model name error of Real-ESRGAN in the opts default value --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 6936cbe0..c46c29f7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -299,7 +299,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), - "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), -- cgit v1.2.1 From 6fa891b934ba854efa87315baffc4ff458ab2539 Mon Sep 17 00:00:00 2001 From: KEV Date: Mon, 14 Nov 2022 00:25:38 +1000 Subject: Add 'Inpainting strength' to the 'generation_params' dictionary of 'infotext' which is saved into the 'params.txt' or png chunks. Value appears only if 'Denoising strength' appears too. --- modules/processing.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 03c9143d..01d7cbdc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -399,6 +399,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), + "Inpainting strength": (None if getattr(p, 'denoising_strength', None) is None else shared.opts.inpainting_mask_weight), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, -- cgit v1.2.1 From 671c0e42b4167f4b7ff93e3b96922bf130c12718 Mon Sep 17 00:00:00 2001 From: Ryan Voots Date: Sun, 13 Nov 2022 13:39:41 -0500 Subject: Fix docker tmp/ and extensions/ handling for docker. might also work for symlinks --- modules/ui_extensions.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 6671cb60..95b63f24 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -9,6 +9,8 @@ import git import gradio as gr import html +import shutil +import errno from modules import extensions, shared, paths @@ -132,7 +134,18 @@ def install_extension_from_url(dirname, url): repo = git.Repo.clone_from(url, tmpdir) repo.remote().fetch() - os.rename(tmpdir, target_dir) + try: + os.rename(tmpdir, target_dir) + except OSError as err: + # TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it + # Shouldn't cause any new issues at least but we probably want to handle it there too. + if err.errno == errno.EXDEV: + # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems + # Since we can't use a rename, do the slower but more versitile shutil.move() + shutil.move(tmpdir, target_dir) + else: + # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled. + raise(err) import launch launch.run_extension_installer(target_dir) -- cgit v1.2.1 From 9a1aff645a4bea745145c57c96950fbd3fcca27c Mon Sep 17 00:00:00 2001 From: parasi Date: Sun, 13 Nov 2022 13:44:27 -0600 Subject: resolve [name] after resolving [filewords] in training --- modules/textual_inversion/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index eb75c376..06f271f9 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -97,13 +97,13 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text): text = random.choice(self.lines) - text = text.replace("[name]", self.placeholder_token) tags = filename_text.split(',') if shared.opts.tag_drop_out != 0: tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] if shared.opts.shuffle_tags: random.shuffle(tags) text = text.replace("[filewords]", ','.join(tags)) + text = text.replace("[name]", self.placeholder_token) return text def __len__(self): -- cgit v1.2.1 From 40ae95d53218b3b8f12fca50b5e4e98a1e50af4b Mon Sep 17 00:00:00 2001 From: KEV Date: Mon, 14 Nov 2022 18:05:59 +1000 Subject: Fix retrieving value for 'x/y plot' script. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 01d7cbdc..2fc9fe13 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -399,7 +399,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), - "Inpainting strength": (None if getattr(p, 'denoising_strength', None) is None else shared.opts.inpainting_mask_weight), + "Inpainting strength": (None if getattr(p, 'denoising_strength', None) is None else getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, -- cgit v1.2.1 From 3405acc6a4dcef2b73782a04924a9a12422e54f0 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Mon, 14 Nov 2022 14:07:13 -0600 Subject: Give --server-name priority over --listen and add check for --server-name in addition to --share and --listen --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 6936cbe0..c628b580 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -106,7 +106,7 @@ restricted_opts = { "outdir_save", } -cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen) and not cmd_opts.enable_insecure_extension_access +cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) -- cgit v1.2.1 From 8f2ff861d31972d12de278075ea9c0c0deef99de Mon Sep 17 00:00:00 2001 From: Maiko Sinkyaet Tan Date: Tue, 15 Nov 2022 16:12:34 +0800 Subject: feat: add http basic authentication for api --- modules/api/api.py | 61 ++++++++++++++++++++++++++++++++++++------------------ modules/shared.py | 1 + 2 files changed, 42 insertions(+), 20 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 596a6616..6bb01603 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -5,6 +5,9 @@ import uvicorn from threading import Lock from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, FastAPI, HTTPException +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from secrets import compare_digest + import modules.shared as shared from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -57,29 +60,47 @@ def encode_pil_to_base64(image): class Api: def __init__(self, app: FastAPI, queue_lock: Lock): + if shared.cmd_opts.api_auth: + self.credenticals = dict() + for auth in shared.cmd_opts.api_auth.split(","): + user, password = auth.split(":") + self.credenticals[user] = password + self.router = APIRouter() self.app = app self.queue_lock = queue_lock - self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) - self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) - self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) - self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) - self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) - self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) - self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) - self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) - self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) - self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) - self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) - self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) - self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) - self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) - self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) - self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) - self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) - self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) - self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) + self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) + self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) + self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) + self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) + self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) + self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) + self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) + self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) + self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) + self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) + self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) + self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) + self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) + self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) + self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) + self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem]) + self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) + self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) + + def add_api_route(self, path: str, endpoint, **kwargs): + if shared.cmd_opts.api_auth: + return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) + return self.app.add_api_route(path, endpoint, **kwargs) + + def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())): + if credenticals.username in self.credenticals: + if compare_digest(credenticals.password, self.credenticals[credenticals.username]): + return True + + raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) diff --git a/modules/shared.py b/modules/shared.py index 6936cbe0..62d526fd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -81,6 +81,7 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") +parser.add_argument("--api-auth", type=str, help='Set authentication for api like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) -- cgit v1.2.1 From 72b52fbb77360f848cfa296b0c79d2bc0a1060f2 Mon Sep 17 00:00:00 2001 From: dtlnor Date: Wed, 16 Nov 2022 13:08:03 +0900 Subject: add css override --- modules/ui.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 5dce7f3b..5e2a992f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -69,8 +69,11 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None css_hide_progressbar = """ .wrap .m-12 svg { display:none!important; } .wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } .progress-bar { display:none!important; } .meta-text { display:none!important; } +.meta-text-center { display:none!important; } """ # Using constants for these since the variation selector isn't visible. -- cgit v1.2.1 From 9bbe1e3c2e54f64283bb333ebb648d8f40f5d4ee Mon Sep 17 00:00:00 2001 From: Llewellyn Pritchard Date: Wed, 16 Nov 2022 19:19:00 +0200 Subject: Fix unbounded prompt growth scripts that loop --- modules/processing.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 03c9143d..2fd12288 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -450,6 +450,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: modules.sd_hijack.model_hijack.clear_comments() comments = {} + prompt_tmp = p.prompt + negative_prompt_tmp = p.negative_prompt shared.prompt_styles.apply_styles(p) @@ -596,6 +598,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.postprocess(p, res) + p.prompt = prompt_tmp + p.negative_prompt = negative_prompt_tmp + return res -- cgit v1.2.1 From abfa22c16fb3d9b1ed8d049c7b68e94d1cca5b82 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 7 Nov 2022 19:25:43 -0500 Subject: Revert "MPS Upscalers Fix" This reverts commit 768b95394a8500da639b947508f78296524f1836. --- modules/devices.py | 9 --------- modules/esrgan_model.py | 2 +- modules/scunet_model.py | 3 ++- modules/swinir_model.py | 2 +- 4 files changed, 4 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 67165bf6..a87d0d4c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -94,12 +94,3 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -def mps_contiguous(input_tensor, device): - return input_tensor.contiguous() if device.type == 'mps' else input_tensor - - -def mps_contiguous_to(input_tensor, device): - return mps_contiguous(input_tensor, device).to(device) diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index c61669b4..9a9c38f1 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -199,7 +199,7 @@ def upscale_without_tiling(model, img): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan) + img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/scunet_model.py b/modules/scunet_model.py index 59532274..36a996bf 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -54,8 +54,9 @@ class UpscalerScuNET(modules.upscaler.Upscaler): img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), device) + img = img.unsqueeze(0).to(device) + img = img.to(device) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 4253b66d..facd262d 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -111,7 +111,7 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) + img = img.unsqueeze(0).to(devices.device_swinir) with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old -- cgit v1.2.1 From a5106a7cdc24153332e4eb1d28e66ea1d7f1ef79 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 7 Nov 2022 19:44:27 -0500 Subject: Remove extra .to(device) --- modules/scunet_model.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/scunet_model.py b/modules/scunet_model.py index 36a996bf..52360241 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -56,7 +56,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): img = torch.from_numpy(img).float() img = img.unsqueeze(0).to(device) - img = img.to(device) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() -- cgit v1.2.1 From c8c40c8a643f2d20e3475e4d9ae7aae6d36c7e85 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 17 Nov 2022 18:03:57 -0800 Subject: Add interrupt button to preprocessing --- modules/textual_inversion/ui.py | 2 +- modules/ui.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index d679e6f4..35c4feef 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -18,7 +18,7 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old): def preprocess(*args): modules.textual_inversion.preprocess.preprocess(*args) - return "Preprocessing finished.", "" + return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", "" def train_embedding(*args): diff --git a/modules/ui.py b/modules/ui.py index 5dce7f3b..88e3c827 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1249,7 +1249,9 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - run_preprocess = gr.Button(value="Preprocess", variant='primary') + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt") + run_preprocess = gr.Button(value="Preprocess", variant='primary') process_split.change( fn=lambda show: gr_show(show), @@ -1422,6 +1424,12 @@ def create_ui(wrap_gradio_gpu_call): outputs=[], ) + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + def create_setting_component(key, is_quicksettings=False): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default -- cgit v1.2.1 From 17e44328204a09653bb89eea18b7b489cc118703 Mon Sep 17 00:00:00 2001 From: killfrenzy96 Date: Fri, 18 Nov 2022 21:22:55 +1100 Subject: cleanly undo circular hijack #4818 --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 97979d05..eaedac13 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -96,8 +96,8 @@ class StableDiffusionModelHijack: if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + self.apply_circular(False) self.layers = None - self.circular_enabled = False self.clip = None def apply_circular(self, enable): -- cgit v1.2.1 From 8ab4927452b04dcd30847eaf92ea7a9f3b9c74e1 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Wed, 2 Nov 2022 22:54:09 +0700 Subject: Fix model wasn't restored even when choosing "None" --- modules/sd_vae.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 71e7a6e6..7a79239f 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -50,8 +50,8 @@ def delete_base_vae(): def restore_base_vae(model): - global base_vae, checkpoint_info if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: + print("Restoring base VAE") load_vae_dict(model, base_vae) delete_base_vae() @@ -143,6 +143,7 @@ def load_vae(model, vae_file=None): vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} load_vae_dict(model, vae_dict_1) + store_base_vae(model) # If vae used is not in dict, update it # It will be removed on refresh though @@ -150,6 +151,9 @@ def load_vae(model, vae_file=None): if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file vae_list.append(vae_opt) + # shared.opts.data['sd_vae'] = vae_opt + else: + restore_base_vae(model) loaded_vae_file = vae_file @@ -166,12 +170,8 @@ def load_vae(model, vae_file=None): # don't call this from outside -def load_vae_dict(model, vae_dict_1=None): - if vae_dict_1: - store_base_vae(model) - model.first_stage_model.load_state_dict(vae_dict_1) - else: - restore_base_vae() +def load_vae_dict(model, vae_dict_1): + model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.1 From abc1e79a5da24a1ea0f4bceedcdf225f32010aa8 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Thu, 3 Nov 2022 11:10:53 +0700 Subject: Fix base VAE caching was done after loading VAE, also add safeguard --- modules/sd_models.py | 1 + modules/sd_vae.py | 19 ++++++++----------- 2 files changed, 9 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..e4dba62c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.clear_loaded_vae() sd_vae.load_vae(model, vae_file) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 7a79239f..dd69a5e6 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -15,7 +15,7 @@ vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} -default_vae_dict = {"auto": "auto", "None": "None"} +default_vae_dict = {"auto": "auto", "None": None, None: None} default_vae_list = ["auto", "None"] @@ -39,6 +39,7 @@ def get_base_vae(model): def store_base_vae(model): global base_vae, checkpoint_info if checkpoint_info != model.sd_checkpoint_info: + assert not loaded_vae_file, "Trying to store non-base VAE!" base_vae = model.first_stage_model.state_dict().copy() checkpoint_info = model.sd_checkpoint_info @@ -50,9 +51,11 @@ def delete_base_vae(): def restore_base_vae(model): + global loaded_vae_file if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: print("Restoring base VAE") load_vae_dict(model, base_vae) + loaded_vae_file = None delete_base_vae() @@ -140,10 +143,10 @@ def load_vae(model, vae_file=None): if vae_file: print(f"Loading VAE weights from: {vae_file}") + store_base_vae(model) vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} load_vae_dict(model, vae_dict_1) - store_base_vae(model) # If vae used is not in dict, update it # It will be removed on refresh though @@ -157,15 +160,6 @@ def load_vae(model, vae_file=None): loaded_vae_file = vae_file - """ - # Save current VAE to VAE settings, maybe? will it work? - if save_settings: - if vae_file is None: - vae_opt = "None" - - # shared.opts.sd_vae = vae_opt - """ - first_load = False @@ -174,6 +168,9 @@ def load_vae_dict(model, vae_dict_1): model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) +def clear_loaded_vae(): + global loaded_vae_file + loaded_vae_file = None def reload_vae_weights(sd_model=None, vae_file="auto"): from modules import lowvram, devices, sd_hijack -- cgit v1.2.1 From c7be83bf0240498d9382e2afeaa3f0677d26c7f6 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 11:11:14 +0700 Subject: Misc Misc --- modules/sd_models.py | 1 + modules/sd_vae.py | 3 +-- modules/shared.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index e4dba62c..cd7fe37a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() sd_vae.load_vae(model, vae_file) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index dd69a5e6..13bf3d31 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -154,8 +154,7 @@ def load_vae(model, vae_file=None): if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file vae_list.append(vae_opt) - # shared.opts.data['sd_vae'] = vae_opt - else: + elif loaded_vae_file: restore_base_vae(model) loaded_vae_file = vae_file diff --git a/modules/shared.py b/modules/shared.py index 17132e42..a9daf800 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -335,7 +335,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), + "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -- cgit v1.2.1 From 9fdc343dcaee70f1a0ff15c0cc668dbd487abc61 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Thu, 17 Nov 2022 18:04:10 +0700 Subject: Fix model caching requiring deepcopy --- modules/sd_vae.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 13bf3d31..5b4709b5 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -4,6 +4,7 @@ from collections import namedtuple from modules import shared, devices, script_callbacks from modules.paths import models_path import glob +from copy import deepcopy model_dir = "Stable-diffusion" @@ -40,7 +41,7 @@ def store_base_vae(model): global base_vae, checkpoint_info if checkpoint_info != model.sd_checkpoint_info: assert not loaded_vae_file, "Trying to store non-base VAE!" - base_vae = model.first_stage_model.state_dict().copy() + base_vae = deepcopy(model.first_stage_model.state_dict()) checkpoint_info = model.sd_checkpoint_info -- cgit v1.2.1 From 028b67b6357b5a00ccbd6ea72d2f244a6664162b Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 19 Nov 2022 01:27:54 +0700 Subject: Use underscore naming for "private" functions in sd_vae --- modules/sd_vae.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 5b4709b5..d82a7bad 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -55,7 +55,7 @@ def restore_base_vae(model): global loaded_vae_file if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: print("Restoring base VAE") - load_vae_dict(model, base_vae) + _load_vae_dict(model, base_vae) loaded_vae_file = None delete_base_vae() @@ -147,7 +147,7 @@ def load_vae(model, vae_file=None): store_base_vae(model) vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - load_vae_dict(model, vae_dict_1) + _load_vae_dict(model, vae_dict_1) # If vae used is not in dict, update it # It will be removed on refresh though @@ -164,7 +164,7 @@ def load_vae(model, vae_file=None): # don't call this from outside -def load_vae_dict(model, vae_dict_1): +def _load_vae_dict(model, vae_dict_1): model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.1 From 0663706d4405b4f76ce653097f4f8989ee8b8684 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Thu, 3 Nov 2022 13:47:03 +0700 Subject: Option to use selected VAE as default fallback instead of primary option --- modules/sd_vae.py | 25 ++++++++++++++++--------- modules/shared.py | 1 + 2 files changed, 17 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 71e7a6e6..0b5f0213 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -83,7 +83,19 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): return vae_list -def resolve_vae(checkpoint_file, vae_file="auto"): +def get_vae_from_settings(vae_file="auto"): + # else, we load from settings, if not set to be default + if vae_file == "auto" and shared.opts.sd_vae is not None: + # if saved VAE settings isn't recognized, fallback to auto + vae_file = vae_dict.get(shared.opts.sd_vae, "auto") + # if VAE selected but not found, fallback to auto + if vae_file not in default_vae_values and not os.path.isfile(vae_file): + vae_file = "auto" + print("Selected VAE doesn't exist") + return vae_file + + +def resolve_vae(checkpoint_file=None, vae_file="auto"): global first_load, vae_dict, vae_list # if vae_file argument is provided, it takes priority, but not saved @@ -98,14 +110,9 @@ def resolve_vae(checkpoint_file, vae_file="auto"): shared.opts.data['sd_vae'] = get_filename(vae_file) else: print("VAE provided as command line argument doesn't exist") - # else, we load from settings - if vae_file == "auto" and shared.opts.sd_vae is not None: - # if saved VAE settings isn't recognized, fallback to auto - vae_file = vae_dict.get(shared.opts.sd_vae, "auto") - # if VAE selected but not found, fallback to auto - if vae_file not in default_vae_values and not os.path.isfile(vae_file): - vae_file = "auto" - print("Selected VAE doesn't exist") + # fallback to selector in settings, if vae selector not set to act as default fallback + if not shared.opts.sd_vae_as_default: + vae_file = get_vae_from_settings(vae_file) # vae-path cmd arg takes priority for auto if vae_file == "auto" and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): diff --git a/modules/shared.py b/modules/shared.py index 17132e42..b84767f0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -336,6 +336,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), + "sd_vae_as_default": OptionInfo(False, "Use selected VAE as default fallback instead"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -- cgit v1.2.1 From 2c5ca706a7e624d268545ba3318ba230b7b33477 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 10:55:47 +0700 Subject: Remove no longer necessary parts and add vae_file safeguard --- modules/sd_models.py | 10 ++-------- modules/sd_vae.py | 1 + 2 files changed, 3 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..c59151e0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): cache_enabled = shared.opts.sd_checkpoint_cache > 0 - if cache_enabled: - sd_vae.restore_base_vae(model) - - vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - if cache_enabled and checkpoint_info in checkpoints_loaded: # use checkpoint cache - vae_name = sd_vae.get_filename(vae_file) if vae_file else None - vae_message = f" with {vae_name} VAE" if vae_name else "" - print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") + print(f"Loading weights [{sd_model_hash}] from cache") model.load_state_dict(checkpoints_loaded[checkpoint_info]) else: # load from file @@ -220,6 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) sd_vae.load_vae(model, vae_file) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 71e7a6e6..8bdb2c17 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -139,6 +139,7 @@ def load_vae(model, vae_file=None): # save_settings = False if vae_file: + assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} -- cgit v1.2.1 From 271fd2d700a59e80d9dc9f23ad3ef08c988e8b24 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 10:58:15 +0700 Subject: More verbose messages --- modules/sd_vae.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 8bdb2c17..fa8de905 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -89,15 +89,15 @@ def resolve_vae(checkpoint_file, vae_file="auto"): # if vae_file argument is provided, it takes priority, but not saved if vae_file and vae_file not in default_vae_list: if not os.path.isfile(vae_file): + print(f"VAE provided as function argument doesn't exist: {vae_file}") vae_file = "auto" - print("VAE provided as function argument doesn't exist") # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported if first_load and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): vae_file = shared.cmd_opts.vae_path shared.opts.data['sd_vae'] = get_filename(vae_file) else: - print("VAE provided as command line argument doesn't exist") + print(f"VAE provided as command line argument doesn't exist: {vae_file}") # else, we load from settings if vae_file == "auto" and shared.opts.sd_vae is not None: # if saved VAE settings isn't recognized, fallback to auto @@ -105,25 +105,25 @@ def resolve_vae(checkpoint_file, vae_file="auto"): # if VAE selected but not found, fallback to auto if vae_file not in default_vae_values and not os.path.isfile(vae_file): vae_file = "auto" - print("Selected VAE doesn't exist") + print(f"Selected VAE doesn't exist: {vae_file}") # vae-path cmd arg takes priority for auto if vae_file == "auto" and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): vae_file = shared.cmd_opts.vae_path - print("Using VAE provided as command line argument") + print(f"Using VAE provided as command line argument: {vae_file}") # if still not found, try look for ".vae.pt" beside model model_path = os.path.splitext(checkpoint_file)[0] if vae_file == "auto": vae_file_try = model_path + ".vae.pt" if os.path.isfile(vae_file_try): vae_file = vae_file_try - print("Using VAE found beside selected model") + print(f"Using VAE found similar to selected model: {vae_file}") # if still not found, try look for ".vae.ckpt" beside model if vae_file == "auto": vae_file_try = model_path + ".vae.ckpt" if os.path.isfile(vae_file_try): vae_file = vae_file_try - print("Using VAE found beside selected model") + print(f"Using VAE found similar to selected model: {vae_file}") # No more fallbacks for auto if vae_file == "auto": vae_file = None -- cgit v1.2.1 From c8f7b5cdd73969d3d5027ceb71cbbd83d557702b Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 11:11:14 +0700 Subject: Misc Misc --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 17132e42..a9daf800 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -335,7 +335,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), + "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -- cgit v1.2.1 From d9fd4525a5d684100997130cc4132736bab1e4d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 11:09:44 +0300 Subject: change text for sd_vae_as_default that makes more sense to me --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 5528ab15..1c42641d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -335,7 +335,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), - "sd_vae_as_default": OptionInfo(False, "Use selected VAE as default fallback instead"), + "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -- cgit v1.2.1 From cdc8020d13c5eef099c609b0a911ccf3568afc0d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 12:01:51 +0300 Subject: change StableDiffusionProcessing to internally use sampler name instead of sampler index --- modules/api/api.py | 26 ++++++++--------------- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/images.py | 2 +- modules/img2img.py | 4 ++-- modules/processing.py | 29 +++++++++++--------------- modules/sd_samplers.py | 13 +++++++++--- modules/textual_inversion/textual_inversion.py | 4 ++-- modules/txt2img.py | 3 ++- modules/ui.py | 2 +- 9 files changed, 41 insertions(+), 46 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 596a6616..0eccccbb 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -6,9 +6,9 @@ from threading import Lock from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, FastAPI, HTTPException import modules.shared as shared +from modules import sd_samplers from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo from PIL import PngImagePlugin from modules.sd_models import checkpoints_list @@ -25,8 +25,12 @@ def upscaler_to_index(name: str): raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") -sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) +def validate_sampler_name(name): + config = sd_samplers.all_samplers_map.get(name, None) + if config is None: + raise HTTPException(status_code=404, detail="Sampler not found") + return name def setUpscalers(req: dict): reqDict = vars(req) @@ -82,14 +86,9 @@ class Api: self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - sampler_index = sampler_to_index(txt2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True } @@ -109,12 +108,6 @@ class Api: return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): - sampler_index = sampler_to_index(img2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - - init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -123,10 +116,9 @@ class Api: if mask: mask = decode_base64_to_image(mask) - populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, "mask": mask @@ -272,7 +264,7 @@ class Api: return vars(shared.cmd_opts) def get_samplers(self): - return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] + return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): upscalers = [] diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7f182712..fbb87dd1 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared +from modules import devices, processing, sd_models, shared, sd_samplers from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps - p.sampler_index = preview_sampler_index + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width diff --git a/modules/images.py b/modules/images.py index ae705cbd..26d5b7a9 100644 --- a/modules/images.py +++ b/modules/images.py @@ -303,7 +303,7 @@ class FilenameGenerator: 'width': lambda self: self.image.width, 'height': lambda self: self.image.height, 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), - 'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False), + 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime], [datetime