about summary refs log tree commit diff
path: root/modules/textual_inversion
diff options
context:
space:
mode:
author MalumaDev <piano.lu92@gmail.com> 2022-10-14 10:56:41 +0200
committer MalumaDev <piano.lu92@gmail.com> 2022-10-14 10:56:41 +0200
commit bb57f30c2de46cfca5419ad01738a41705f96cc3 (patch)
tree 7e47bc282de81a8011ea140f8a850652253b0e18 /modules/textual_inversion
parent fdecb636855748e03efc40c846a0043800aadfcc (diff)
init
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r-- modules/textual_inversion/dataset.py 2
-rw-r--r-- modules/textual_inversion/textual_inversion.py 35
2 files changed, 26 insertions, 11 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 67e90afe..59b2b021 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -48,7 +48,7 @@ class PersonalizedBase(Dataset):
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC)
except Exception:
continue
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fa0e33a2..b12a8e6d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -172,7 +172,15 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def batched(dataset, total, n=1):
+ for ndx in range(0, total, n):
+ yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
+
+
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps,
+ create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding,
+ preview_image_prompt, batch_size=1,
+ gradient_accumulation=1):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -204,7 +212,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
+ height=training_height,
+ repeats=shared.opts.training_image_repeats_per_epoch,
+ placeholder_token=embedding_name, model=shared.sd_model,
+ device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -223,7 +235,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
- pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ pbar = tqdm.tqdm(enumerate(batched(ds, steps - ititial_step, batch_size)), total=steps - ititial_step)
for i, entry in pbar:
embedding.step = i + ititial_step
@@ -235,17 +247,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
break
with torch.autocast("cuda"):
- c = cond_model([entry.cond_text])
+ c = cond_model([e.cond_text for e in entry])
+
+ x = torch.stack([e.latent for e in entry]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
- x = entry.latent.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
- optimizer.zero_grad()
loss.backward()
- optimizer.step()
+ if ((i + 1) % gradient_accumulation == 0) or (i + 1 == steps - ititial_step):
+ optimizer.step()
+ optimizer.zero_grad()
+
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
@@ -259,7 +274,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
- preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry[0].cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -305,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
-Last prompt: {html.escape(entry.cond_text)}<br/>
+Last prompt: {html.escape(entry[-1].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>