aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/extras.py44
-rw-r--r--modules/sd_models.py11
-rw-r--r--modules/ui.py4
3 files changed, 55 insertions, 4 deletions
diff --git a/modules/extras.py b/modules/extras.py
index d8ece955..77d88592 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,6 +1,7 @@
import os
import re
import shutil
+import json
import torch
@@ -71,7 +72,7 @@ def to_half(tensor, enable):
return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -241,13 +242,52 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
+ metadata = {"format": "pt", "models": {}, "merge_recipe": None}
+
+ if save_metadata:
+ merge_recipe = {
+ "primary_model_hash": primary_model_info.sha256,
+ "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
+ "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
+ "interp_method": interp_method,
+ "multiplier": multiplier,
+ "save_as_half": save_as_half,
+ "custom_name": custom_name,
+ "config_source": config_source,
+ "bake_in_vae": bake_in_vae,
+ "discard_weights": discard_weights,
+ "is_inpainting": result_is_inpainting_model,
+ "is_instruct_pix2pix": result_is_instruct_pix2pix_model
+ }
+ metadata["merge_recipe"] = json.dumps(merge_recipe)
+
+ def add_model_metadata(checkpoint_info):
+ metadata["models"][checkpoint_info.sha256] = {
+ "name": checkpoint_info.name,
+ "legacy_hash": checkpoint_info.hash,
+ "merge_recipe": checkpoint_info.metadata.get("merge_recipe", None)
+ }
+
+ metadata["models"].update(checkpoint_info.metadata.get("models", {}))
+
+ add_model_metadata(primary_model_info)
+ if secondary_model_info:
+ add_model_metadata(secondary_model_info)
+ if tertiary_model_info:
+ add_model_metadata(tertiary_model_info)
+
+ metadata["models"] = json.dumps(metadata["models"])
+
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
else:
torch.save(theta_0, output_modelname)
sd_models.list_models()
+ created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
+ if created_model:
+ created_model.calculate_shorthash()
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6ea874df..4f7613a1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -52,6 +52,15 @@ class CheckpointInfo:
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+ self.metadata = {}
+
+ _, ext = os.path.splitext(self.filename)
+ if ext.lower() == ".safetensors":
+ try:
+ self.metadata = read_metadata_from_safetensors(filename)
+ except Exception as e:
+ errors.display(e, f"reading checkpoint metadata: {filename}")
+
def register(self):
checkpoints_list[self.title] = self
for id in self.ids:
@@ -544,4 +553,4 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
- return sd_model \ No newline at end of file
+ return sd_model
diff --git a/modules/ui.py b/modules/ui.py
index 627fbe0b..64fb93c3 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1019,8 +1019,9 @@ def create_ui():
interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
with FormRow():
- checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+ checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
+ save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
with FormRow():
with gr.Column():
@@ -1658,6 +1659,7 @@ def create_ui():
config_source,
bake_in_vae,
discard_weights,
+ save_metadata,
],
outputs=[
primary_model_name,