aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_hijack_unet.py7
-rw-r--r--modules/ui_extra_networks.py9
2 files changed, 12 insertions, 4 deletions
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 88c94e54..a6ee577c 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -36,8 +36,11 @@ th = TorchHijackForUnet()
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
- for y in cond.keys():
- cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
+ if isinstance(cond, dict):
+ for y in cond.keys():
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 8b4f97f8..c6ff889a 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -117,8 +117,13 @@ def create_ui(container, button, tabname):
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
- button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
- button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
+ def toggle_visibility(is_visible):
+ is_visible = not is_visible
+ return is_visible, gr.update(visible=is_visible)
+
+ state_visible = gr.State(value=False)
+ button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
+ button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
def refresh():
res = []