aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/textual_inversion.py
diff options
context:
space:
mode:
authorFampai <>2022-10-10 17:10:29 -0400
committerFampai <>2022-10-10 17:10:29 -0400
commit2536ecbb1790da2af0d61b6a26f38732cba665cd (patch)
tree98952174ed80c0aa376c433d5cd6f4a500b4b18f /modules/textual_inversion/textual_inversion.py
parentce37fdd30e9fc0fe0bc5805a068ce8b11b42b5a3 (diff)
Refactored learning rate code
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--modules/textual_inversion/textual_inversion.py51
1 files changed, 47 insertions, 4 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5965c5a0..c64a4598 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -189,8 +189,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
-
losses = torch.zeros((32,))
last_saved_file = "<none>"
@@ -203,12 +201,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
epoch_len = (tr_img_len * num_repeats) + tr_img_len
+ scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(scheduleIter)
+ print(f'Training at rate of {learn_rate} until step {end_step}')
+
+ optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
- if embedding.step > steps:
- break
+ if embedding.step > end_step:
+ try:
+ (learn_rate, end_step) = next(scheduleIter)
+ except:
+ break
+ tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+ for pg in optimizer.param_groups:
+ pg['lr'] = learn_rate
if shared.state.interrupted:
break
@@ -277,3 +287,36 @@ Last saved image: {html.escape(last_saved_image)}<br/>
return embedding, filename
+class LearnSchedule:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ for i, pair in enumerate(pairs):
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration