aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authoraria1th <35677394+aria1th@users.noreply.github.com>2022-11-03 14:49:26 +0900
committeraria1th <35677394+aria1th@users.noreply.github.com>2022-11-03 14:49:26 +0900
commit1764ac3c8bc482bd575987850e96630d9115e51a (patch)
tree2e90525f56fbbcfbb19f9e884e99559900ac9262 /modules
parent0b143c1163a96b193a4e8512be9c5831c661a50d (diff)
use hash to check valid optim
Diffstat (limited to 'modules')
-rw-r--r--modules/hypernetworks/hypernetwork.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 63c25de8..4230b8cf 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -177,12 +177,13 @@ class Hypernetwork:
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
if self.optimizer_name is not None:
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
+
+ torch.save(state_dict, filename)
if self.optimizer_state_dict:
+ optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
- torch.save(state_dict, filename)
-
def load(self, filename):
self.filename = filename
if self.name is None:
@@ -204,7 +205,10 @@ class Hypernetwork:
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
print(f"Optimizer name is {self.optimizer_name}")
- self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+ if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+ else:
+ self.optimizer_state_dict = None
if self.optimizer_state_dict:
print("Loaded existing optimizer from checkpoint")
else:
@@ -229,7 +233,7 @@ def list_hypernetworks(path):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
- res[name] = filename
+ res[name + f"({sd_models.model_hash(filename)})"] = filename
return res
@@ -375,6 +379,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else:
hypernetwork_dir = None
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)