aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks/ui.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-21 08:36:07 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-01-21 08:36:07 +0300
commit40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75 (patch)
tree75230fd3b1c9c53593aca65d7260b62b6df2d82b /modules/hypernetworks/ui.py
parente33cace2c2074ef342d027c1f31ffc4b3c3e877e (diff)
extra networks UI
rework of hypernets: rather than via settings, hypernets are added directly to prompt as <hypernet:name:weight>
Diffstat (limited to 'modules/hypernetworks/ui.py')
-rw-r--r--modules/hypernetworks/ui.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 81e3f519..76599f5a 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
@@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
def train_hypernetwork(*args):
-
- initial_hypernetwork = shared.loaded_hypernetwork
+ shared.loaded_hypernetworks = []
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
@@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)}
except Exception:
raise
finally:
- shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()