aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks
diff options
context:
space:
mode:
authortimntorres <timothynarcisotorres@gmail.com>2022-10-21 02:14:02 -0700
committerAUTOMATIC1111 <16777216c@gmail.com>2022-10-21 16:52:24 +0300
commit19818f023cfafc472c6c241cab0b72896a168481 (patch)
treeda91951c068afb4412e9ef07bb175768fc51fa40 /modules/hypernetworks
parent51e3dc9ccad157d7161b697a246e26c868d46a7c (diff)
Match hypernet name with filename in all cases.
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r--modules/hypernetworks/hypernetwork.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b1a5d0c7..6d392be4 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -340,7 +340,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar.set_description(f"loss: {mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
+ temp = hypernetwork.name
+ # Before saving, change name to match current checkpoint.
+ hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
@@ -405,6 +408,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
+ hypernetwork.name = hypernetwork_name
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(filename)
return hypernetwork, filename