aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorMuhammad Rizqi Nur <rizqinur2010@gmail.com>2022-10-29 15:37:24 +0700
committerMuhammad Rizqi Nur <rizqinur2010@gmail.com>2022-10-29 15:37:24 +0700
commita5f3adbdd7d9b8245f7782216ac48913660e6bb5 (patch)
tree78ced0ec61f5bd738c11a1d10cfbe7cf3fc509f8 /modules
parent35c45df28b303a05d56a13cb56d4046f08cf8c25 (diff)
Allow trailing comma in learning rate
Diffstat (limited to 'modules')
-rw-r--r--modules/textual_inversion/learn_schedule.py33
1 files changed, 20 insertions, 13 deletions
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 3a736065..76e611b6 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -11,23 +11,30 @@ class LearnScheduleIterator:
self.rates = []
self.it = 0
self.maxit = 0
- for i, pair in enumerate(pairs):
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
+ try:
+ for i, pair in enumerate(pairs):
+ if not pair.strip():
+ continue
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
return
- elif step == -1:
+ else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
+ assert self.rates
+ except (ValueError, AssertionError):
+ raise Exception("Invalid learning rate schedule")
+
def __iter__(self):
return self