author     AngelBottomless <35677394+aria1th@users.noreply.github.com>  2022-10-26 22:24:33 +0900
committer  aria1th <35677394+aria1th@users.noreply.github.com>  2022-10-27 14:38:21 +0900
commit     85fcccc105aa50f1d78de559233eaa9f384608b5 (patch)
tree       7fedce39b4735f852c3db5cf8858a472fc7d96b4 /modules/hypernetworks/hypernetwork.py
parent     b6a8bb123bd519736306417399f6441e504f1e8b (diff)
Squashed commit of fixing dropout silently
fix dropouts for future hypernetworks
add kwargs for Hypernetwork class
hypernet UI for gradio input
add recommended options
remove as options
revert adding options in ui
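
For context on what "fixing dropout silently" means: with the pre-fix bound i < len(layer_structure) - 3, the default three-entry structure [1, 2, 1] never satisfied the condition, so no Dropout layer was ever added even when use_dropout was enabled. A minimal sketch comparing the two bounds (not part of the commit; dropout_indices is a hypothetical helper):

    # Hypothetical helper, illustration only: which loop indices i
    # (the builder loops over range(len(layer_structure) - 1)) receive Dropout.
    def dropout_indices(layer_structure, bound_offset):
        return [i for i in range(len(layer_structure) - 1)
                if i < len(layer_structure) - bound_offset]

    for structure in ([1, 2, 1], [1, 2, 2, 1]):
        old = dropout_indices(structure, 3)  # pre-fix bound:  i < len - 3
        new = dropout_indices(structure, 2)  # post-fix bound: i < len - 2
        print(structure, "old:", old, "new:", new)

    # [1, 2, 1]    -> old: []   new: [0]      (dropout was silently disabled before)
    # [1, 2, 2, 1] -> old: [0]  new: [0, 1]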
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--  modules/hypernetworks/hypernetwork.py | 25 +++++++++++++++++--------
1 file changed, 17 insertions(+), 8 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 2997cead..dd921153 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -34,7 +34,8 @@ class HypernetworkModule(torch.nn.Module):
     }
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+                 add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
         super().__init__()
         assert layer_structure is not None, "layer_structure must not be None"
@@ -60,7 +61,7 @@ class HypernetworkModule(torch.nn.Module):
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
             # Add dropout except last layer
-            if use_dropout and i < len(layer_structure) - 3:
+            if 'last_layer_dropout' in kwargs and kwargs['last_layer_dropout'] and use_dropout and i < len(layer_structure) - 2:
                 linears.append(torch.nn.Dropout(p=0.3))
         self.linear = torch.nn.Sequential(*linears)
@@ -126,7 +127,7 @@ class Hypernetwork:
     filename = None
     name = None
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -139,11 +140,14 @@ class Hypernetwork:
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
         self.activate_output = activate_output
+        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
             )
     def weights(self):
@@ -172,7 +176,8 @@ class Hypernetwork:
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
         state_dict['activate_output'] = self.activate_output
-
+        state_dict['last_layer_dropout'] = self.last_layer_dropout
+
         torch.save(state_dict, filename)
     def load(self, filename):
@@ -193,12 +198,16 @@ class Hypernetwork:
         self.use_dropout = state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}" )
         self.activate_output = state_dict.get('activate_output', True)
+        print(f"Activate last layer is set to {self.activate_output}")
+        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
                 )
         self.name = state_dict.get('name', self.name)
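
A hedged usage sketch of the new keyword path (assumptions: the import mirrors the file touched by this diff and resolves outside the webui process; the name and enable_sizes values are illustrative, not prescribed by the commit):

    # Sketch only: constructing a post-commit Hypernetwork with the new kwarg.
    from modules.hypernetworks.hypernetwork import Hypernetwork

    hn = Hypernetwork(
        name="example",              # hypothetical name
        enable_sizes=[320, 768],     # example context dimensions
        layer_structure=[1, 2, 1],
        weight_init="Normal",
        use_dropout=True,
        last_layer_dropout=True,     # consumed via **kwargs; defaults to True when omitted
    )

    # Backward compatibility: checkpoints saved before this commit lack the
    # 'last_layer_dropout' key, so load() falls back to False and older files
    # keep their original behavior.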