aboutsummaryrefslogtreecommitdiff
path: root/modules/extras.py
diff options
context:
space:
mode:
authorspace-nuko <24979496+space-nuko@users.noreply.github.com>2023-04-02 17:41:55 -0500
committerspace-nuko <24979496+space-nuko@users.noreply.github.com>2023-04-02 17:41:55 -0500
commitd132481058f8a827cd407f2121f128a2bb862f7a (patch)
tree289f5462847c1fafeec3b8218028b0b813dbe8af /modules/extras.py
parent22bcc7be428c94e9408f589966c2040187245d81 (diff)
Embed model merge metadata in .safetensors file
Diffstat (limited to 'modules/extras.py')
-rw-r--r--modules/extras.py44
1 files changed, 42 insertions, 2 deletions
diff --git a/modules/extras.py b/modules/extras.py
index d8ece955..77d88592 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,6 +1,7 @@
import os
import re
import shutil
+import json
import torch
@@ -71,7 +72,7 @@ def to_half(tensor, enable):
return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -241,13 +242,52 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
+ metadata = {"format": "pt", "models": {}, "merge_recipe": None}
+
+ if save_metadata:
+ merge_recipe = {
+ "primary_model_hash": primary_model_info.sha256,
+ "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
+ "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
+ "interp_method": interp_method,
+ "multiplier": multiplier,
+ "save_as_half": save_as_half,
+ "custom_name": custom_name,
+ "config_source": config_source,
+ "bake_in_vae": bake_in_vae,
+ "discard_weights": discard_weights,
+ "is_inpainting": result_is_inpainting_model,
+ "is_instruct_pix2pix": result_is_instruct_pix2pix_model
+ }
+ metadata["merge_recipe"] = json.dumps(merge_recipe)
+
+ def add_model_metadata(checkpoint_info):
+ metadata["models"][checkpoint_info.sha256] = {
+ "name": checkpoint_info.name,
+ "legacy_hash": checkpoint_info.hash,
+ "merge_recipe": checkpoint_info.metadata.get("merge_recipe", None)
+ }
+
+ metadata["models"].update(checkpoint_info.metadata.get("models", {}))
+
+ add_model_metadata(primary_model_info)
+ if secondary_model_info:
+ add_model_metadata(secondary_model_info)
+ if tertiary_model_info:
+ add_model_metadata(tertiary_model_info)
+
+ metadata["models"] = json.dumps(metadata["models"])
+
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
else:
torch.save(theta_0, output_modelname)
sd_models.list_models()
+ created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
+ if created_model:
+ created_model.calculate_shorthash()
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)