aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks/hypernetwork.py
diff options
context:
space:
mode:
authorMuhammad Rizqi Nur <rizqinur2010@gmail.com>2022-10-31 13:53:22 +0700
committerMuhammad Rizqi Nur <rizqinur2010@gmail.com>2022-10-31 13:53:22 +0700
commit4123be632a98f70cda06e14c2f556f7ad38cd436 (patch)
tree2d7ac1c59f95a1509d31c05cfff01813f6410164 /modules/hypernetworks/hypernetwork.py
parent840307f23738c38f7ac3ad636e53ccec66e71f8b (diff)
Fix merge conflicts
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--modules/hypernetworks/hypernetwork.py17
1 files changed, 6 insertions, 11 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 65a584bb..207808ee 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -373,6 +373,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ clip_grad_mode_value = clip_grad_mode == "value"
+ clip_grad_mode_norm = clip_grad_mode == "norm"
+ clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
+ if clip_grad_enabled:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
+
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
@@ -389,21 +395,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
- last_saved_file = "<none>"
- last_saved_image = "<none>"
- forced_filename = "<none>"
-
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
return hypernetwork, filename
- clip_grad_mode_value = clip_grad_mode == "value"
- clip_grad_mode_norm = clip_grad_mode == "norm"
- clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
- if clip_grad_enabled:
- clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
-
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
weights = hypernetwork.weights()
for weight in weights: