From c7e50425f63c07242068f8dcccce70a4ef28a17f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 09:25:37 +0300 Subject: add progress bar to modelmerger --- modules/extras.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) (limited to 'modules/extras.py') diff --git a/modules/extras.py b/modules/extras.py index 367c15cc..034f28e4 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename) -def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): shared.state.begin() shared.state.job = 'model-merge' + shared.state.job_count = 1 def fail(message): shared.state.textinfo = message shared.state.end() - return [message, *[gr.update() for _ in range(4)]] + return [*[gr.update() for _ in range(4)], message] def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) @@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') if theta_func1: + shared.state.job_count += 1 + print(f"Loading {tertiary_model_info.filename}...") theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + shared.state.sampling_steps = len(theta_1.keys()) for key in tqdm.tqdm(theta_1.keys()): if 'model' in key: if key in theta_2: @@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam theta_1[key] = theta_func1(theta_1[key], t2) else: theta_1[key] = torch.zeros_like(theta_1[key]) + + shared.state.sampling_step += 1 del theta_2 + shared.state.nextjob() + shared.state.textinfo = f"Loading {primary_model_info.filename}..." print(f"Loading {primary_model_info.filename}...") theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') @@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + shared.state.sampling_steps = len(theta_0.keys()) for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: @@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam if save_as_half: theta_0[key] = theta_0[key].half() + shared.state.sampling_step += 1 + # I believe this part should be discarded, but I'll leave it for now until I am sure for key in theta_1.keys(): if 'model' in key and key not in theta_0: @@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam output_modelname = os.path.join(ckpt_dir, filename) + shared.state.nextjob() shared.state.textinfo = f"Saving to {output_modelname}..." print(f"Saving to {output_modelname}...") @@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.end() - return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] -- cgit v1.2.1