aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks/hypernetwork.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--modules/hypernetworks/hypernetwork.py44
1 files changed, 25 insertions, 19 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 83cbb4f0..3aebefa8 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -12,7 +12,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -225,7 +225,7 @@ class Hypernetwork:
torch.save(state_dict, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
- optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+ optimizer_saved_dict['hash'] = self.shorthash()
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
@@ -237,32 +237,33 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
- print(self.layer_structure)
- optional_info = state_dict.get('optional_info', None)
- if optional_info is not None:
- print(f"INFO:\n {optional_info}\n")
- self.optional_info = optional_info
+ self.optional_info = state_dict.get('optional_info', None)
self.activation_func = state_dict.get('activation_func', None)
- print(f"Activation function is {self.activation_func}")
self.weight_init = state_dict.get('weight_initialization', 'Normal')
- print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
- print(f"Layer norm is set to {self.add_layer_norm}")
self.dropout_structure = state_dict.get('dropout_structure', None)
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
- print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
- print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
if self.dropout_structure is None:
- print("Using previous dropout structure")
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
- print(f"Dropout structure is set to {self.dropout_structure}")
- optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+ if shared.opts.print_hypernet_extra:
+ if self.optional_info is not None:
+ print(f" INFO:\n {self.optional_info}\n")
- if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+ print(f" Layer structure: {self.layer_structure}")
+ print(f" Activation function: {self.activation_func}")
+ print(f" Weight initialization: {self.weight_init}")
+ print(f" Layer norm: {self.add_layer_norm}")
+ print(f" Dropout usage: {self.use_dropout}" )
+ print(f" Activate last layer: {self.activate_output}")
+ print(f" Dropout structure: {self.dropout_structure}")
+
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+
+ if self.shorthash() == optimizer_saved_dict.get('hash', None):
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
else:
self.optimizer_state_dict = None
@@ -289,6 +290,11 @@ class Hypernetwork:
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
self.eval()
+ def shorthash(self):
+ sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
+
+ return sha256[0:10]
+
def list_hypernetworks(path):
res = {}
@@ -296,7 +302,7 @@ def list_hypernetworks(path):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
- res[name + f"({sd_models.model_hash(filename)})"] = filename
+ res[name] = filename
return res
@@ -509,7 +515,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.opts.save_training_settings_to_txt:
saved_params = dict(
- model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
)
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
@@ -737,7 +743,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
try:
- hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
hypernetwork.name = hypernetwork_name
hypernetwork.save(filename)