aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--modules/textual_inversion/dataset.py4
-rw-r--r--modules/textual_inversion/learn_schedule.py2
-rw-r--r--modules/textual_inversion/textual_inversion.py24
3 files changed, 15 insertions, 15 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 5b1c5002..8bb00d27 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -86,12 +86,12 @@ class PersonalizedBase(Dataset):
assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
- self.initial_indexes = np.arange(len(self.dataset))
+ self.dataset_length = len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
+ self.indexes = np.random.permutation(self.dataset_length)
def create_text(self, filename_text):
text = random.choice(self.lines)
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 2062726a..3a736065 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -52,7 +52,7 @@ class LearnRateScheduler:
self.finished = False
def apply(self, optimizer, step_number):
- if step_number <= self.end_step:
+ if step_number < self.end_step:
return
try:
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ff002d3e..17dfb223 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -184,9 +184,8 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
- if step % shared.opts.training_write_csv_every != 0:
+ if (step + 1) % shared.opts.training_write_csv_every != 0:
return
-
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
@@ -196,11 +195,11 @@ def write_loss(log_directory, filename, step, epoch_len, values):
csv_writer.writeheader()
epoch = step // epoch_len
- epoch_step = step - epoch * epoch_len
+ epoch_step = step % epoch_len
csv_writer.writerow({
"step": step + 1,
- "epoch": epoch + 1,
+ "epoch": epoch,
"epoch_step": epoch_step + 1,
**values,
})
@@ -282,15 +281,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
loss.backward()
optimizer.step()
+ steps_done = embedding.step + 1
epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step - (epoch_num * len(ds)) + 1
+ epoch_step = embedding.step % len(ds)
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
- if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{embedding.step}'
+ embedding.name = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
@@ -300,8 +300,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
"learn_rate": scheduler.learn_rate
})
- if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
- forced_filename = f'{embedding_name}-{embedding.step}'
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -334,7 +334,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
@@ -350,7 +350,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, embedding.step)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)