aboutsummaryrefslogtreecommitdiff
path: root/modules/ui_checkpoint_merger.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-01 08:27:54 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-01 08:27:54 +0300
commit07be13caa357b14f6afa247566d53339522b8e66 (patch)
tree4a9e2329ad89f4a7aa26cfc338cd0ccf579f3fc0 /modules/ui_checkpoint_merger.py
parent6d3a0c950626e887f20bfc9946b84f9685303bab (diff)
add metadata to checkpoint merger
Diffstat (limited to 'modules/ui_checkpoint_merger.py')
-rw-r--r--modules/ui_checkpoint_merger.py20
1 files changed, 18 insertions, 2 deletions
diff --git a/modules/ui_checkpoint_merger.py b/modules/ui_checkpoint_merger.py
index 8e72258a..4863d861 100644
--- a/modules/ui_checkpoint_merger.py
+++ b/modules/ui_checkpoint_merger.py
@@ -51,7 +51,6 @@ class UiCheckpointMerger:
with FormRow():
self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
- self.save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
with FormRow():
with gr.Column():
@@ -65,16 +64,30 @@ class UiCheckpointMerger:
with FormRow():
self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
- with gr.Row():
+ with gr.Accordion("Metadata", open=False) as metadata_editor:
+ with FormRow():
+ self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
+ self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
+ self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
+
+ self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
+ self.read_metadata = gr.Button("Read metadata from selected checkpoints")
+
+ with FormRow():
self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
with gr.Group(elem_id="modelmerger_results_panel"):
self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
+ self.metadata_editor = metadata_editor
self.blocks = modelmerger_interface
def setup_ui(self, dummy_component, sd_model_checkpoint_component):
+ self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
+
+ self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
+
self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
self.modelmerger_merge.click(
fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
@@ -93,6 +106,9 @@ class UiCheckpointMerger:
self.bake_in_vae,
self.discard_weights,
self.save_metadata,
+ self.add_merge_recipe,
+ self.copy_metadata_fields,
+ self.metadata_json,
],
outputs=[
self.primary_model_name,