aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/hypernetworks/hypernetwork.py19
-rw-r--r--modules/textual_inversion/learn_schedule.py34
-rw-r--r--modules/textual_inversion/textual_inversion.py44
-rw-r--r--modules/ui.py2
4 files changed, 54 insertions, 45 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 5608e799..470659df 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -14,6 +14,7 @@ import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
class HypernetworkModule(torch.nn.Module):
@@ -202,8 +203,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
for weight in weights:
weight.requires_grad = True
- optimizer = torch.optim.AdamW(weights, lr=learn_rate)
-
losses = torch.zeros((32,))
last_saved_file = "<none>"
@@ -213,12 +212,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps:
return hypernetwork, filename
+ schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(schedules)
+ print(f'Training at rate of {learn_rate} until step {end_step}')
+
+ optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, (x, text, cond) in pbar:
hypernetwork.step = i + ititial_step
- if hypernetwork.step > steps:
- break
+ if hypernetwork.step > end_step:
+ try:
+ (learn_rate, end_step) = next(schedules)
+ except Exception:
+ break
+ tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+ for pg in optimizer.param_groups:
+ pg['lr'] = learn_rate
if shared.state.interrupted:
break
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
new file mode 100644
index 00000000..db720271
--- /dev/null
+++ b/modules/textual_inversion/learn_schedule.py
@@ -0,0 +1,34 @@
+
+class LearnSchedule:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ for i, pair in enumerate(pairs):
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 47a27faf..7717837d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -10,6 +10,7 @@ import datetime
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
class Embedding:
@@ -198,11 +199,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
- tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
- epoch_len = (tr_img_len * num_repeats) + tr_img_len
-
- scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(scheduleIter)
+ schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(schedules)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
@@ -213,7 +211,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > end_step:
try:
- (learn_rate, end_step) = next(scheduleIter)
+ (learn_rate, end_step) = next(schedules)
except:
break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
@@ -288,37 +286,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
embedding.save(filename)
return embedding, filename
-
-class LearnSchedule:
- def __init__(self, learn_rate, max_steps, cur_step=0):
- pairs = learn_rate.split(',')
- self.rates = []
- self.it = 0
- self.maxit = 0
- for i, pair in enumerate(pairs):
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
- return
- elif step == -1:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if self.it < self.maxit:
- self.it += 1
- return self.rates[self.it - 1]
- else:
- raise StopIteration
diff --git a/modules/ui.py b/modules/ui.py
index 2b688e32..1204eef7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1070,7 +1070,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
- learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03")
+ learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))