aboutsummaryrefslogtreecommitdiff
path: root/modules/extras.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-11 09:10:07 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-01-11 09:10:07 +0300
commit954091697fce7a1b7997d5f3d73551f793f6bebc (patch)
treeaf9b1ce646d3a55b26438221f3a19aa5aacd4815 /modules/extras.py
parent3e20244b0fea10988cf5ad8a2fbe190ac47a5049 (diff)
add an option to copy config from one of models in checkpoint merger
Diffstat (limited to 'modules/extras.py')
-rw-r--r--modules/extras.py30
1 files changed, 29 insertions, 1 deletions
diff --git a/modules/extras.py b/modules/extras.py
index 7407bfe3..a03d558e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -3,6 +3,7 @@ import math
import os
import sys
import traceback
+import shutil
import numpy as np
from PIL import Image
@@ -248,7 +249,32 @@ def run_pnginfo(image):
return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+def create_config(ckpt_result, config_source, a, b, c):
+ def config(x):
+ return sd_models.find_checkpoint_config(x) if x else None
+
+ if config_source == 0:
+ cfg = config(a) or config(b) or config(c)
+ elif config_source == 1:
+ cfg = config(b)
+ elif config_source == 2:
+ cfg = config(c)
+ else:
+ cfg = None
+
+ if cfg is None:
+ return
+
+ filename, _ = os.path.splitext(ckpt_result)
+ checkpoint_filename = filename + ".yaml"
+
+ print("Copying config:")
+ print(" from:", cfg)
+ print(" to:", checkpoint_filename)
+ shutil.copyfile(cfg, checkpoint_filename)
+
+
+def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
sd_models.list_models()
+ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
+
print("Checkpoint saved.")
shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end()