From a07f054c86f33360ff620d6a3fffdee366ab2d99 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:49:29 +0700 Subject: Add missing info on hypernetwork/embedding model log Mentioned here: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/1528#discussioncomment-3991513 Also group the saving into one --- modules/hypernetworks/hypernetwork.py | 31 +++++++++++++------- modules/textual_inversion/textual_inversion.py | 39 +++++++++++++++++--------- 2 files changed, 47 insertions(+), 23 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 38f35c58..86daf825 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -361,6 +361,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log images_dir = None hypernetwork = shared.loaded_hypernetwork + checkpoint = sd_models.select_checkpoint() ititial_step = hypernetwork.step or 0 if ititial_step > steps: @@ -449,9 +450,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. - hypernetwork.name = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(last_saved_file) + hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{previous_mean_loss:.7f}", @@ -512,13 +513,23 @@ Last saved image: {html.escape(last_saved_image)}
""" report_statistics(loss_dict) - checkpoint = sd_models.select_checkpoint() - hypernetwork.sd_checkpoint = checkpoint.hash - hypernetwork.sd_checkpoint_name = checkpoint.model_name - # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention). - hypernetwork.name = hypernetwork_name - filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(filename) + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) return hypernetwork, filename + +def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): + old_hypernetwork_name = hypernetwork.name + old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None + old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None + try: + hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint_name = checkpoint.model_name + hypernetwork.name = hypernetwork_name + hypernetwork.save(filename) + except: + hypernetwork.sd_checkpoint = old_sd_checkpoint + hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name + hypernetwork.name = old_hypernetwork_name + raise diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 44f06443..ee9917ce 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -119,7 +119,7 @@ class EmbeddingDatabase: vec = emb.detach().to(devices.device, dtype=torch.float32) embedding = Embedding(vec, name) embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('hash', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) self.register_embedding(embedding, shared.sd_model) @@ -259,6 +259,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc hijack = sd_hijack.model_hijack embedding = hijack.embedding_db.word_embeddings[embedding_name] + checkpoint = sd_models.select_checkpoint() ititial_step = embedding.step or 0 if ititial_step > steps: @@ -314,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') - embedding.save(last_saved_file) + embedding_name_every = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') + save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) embedding_yet_to_be_embedded = True write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { @@ -397,14 +398,26 @@ Last saved image: {html.escape(last_saved_image)}

""" - checkpoint = sd_models.select_checkpoint() - - embedding.sd_checkpoint = checkpoint.hash - embedding.sd_checkpoint_name = checkpoint.model_name - embedding.cached_checksum = None - # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention). - embedding.name = embedding_name - filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt') - embedding.save(filename) + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) return embedding, filename + +def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True): + old_embedding_name = embedding.name + old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None + old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None + old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None + try: + embedding.sd_checkpoint = checkpoint.hash + embedding.sd_checkpoint_name = checkpoint.model_name + if remove_cached_checksum: + embedding.cached_checksum = None + embedding.name = embedding_name + embedding.save(filename) + except: + embedding.sd_checkpoint = old_sd_checkpoint + embedding.sd_checkpoint_name = old_sd_checkpoint_name + embedding.name = old_embedding_name + embedding.cached_checksum = old_cached_checksum + raise -- cgit v1.2.1