Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--  modules/textual_inversion/dataset.py            36
-rw-r--r--  modules/textual_inversion/preprocess.py          5
-rw-r--r--  modules/textual_inversion/textual_inversion.py  20
-rw-r--r--  modules/textual_inversion/ui.py                  2
4 files changed, 41 insertions(+), 22 deletions(-)
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index bcf772d2..f61f40d3 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -8,14 +8,14 @@ from torchvision import transforms
 import random
 import tqdm
-from modules import devices
+from modules import devices, shared
 import re

 re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")


 class PersonalizedBase(Dataset):
-    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
+    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
         self.placeholder_token = placeholder_token
@@ -32,12 +32,15 @@ class PersonalizedBase(Dataset):
         assert data_root, 'dataset directory not specified'

+        cond_model = shared.sd_model.cond_stage_model
+
         self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
         print("Preparing dataset...")
         for path in tqdm.tqdm(self.image_paths):
-            image = Image.open(path)
-            image = image.convert('RGB')
-            image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+            try:
+                image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+            except Exception:
+                continue

             filename = os.path.basename(path)
             filename_tokens = os.path.splitext(filename)[0]
@@ -52,7 +55,13 @@ class PersonalizedBase(Dataset):
             init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
             init_latent = init_latent.to(devices.cpu)

-            self.dataset.append((init_latent, filename_tokens))
+            if include_cond:
+                text = self.create_text(filename_tokens)
+                cond = cond_model([text]).to(devices.cpu)
+            else:
+                cond = None
+
+            self.dataset.append((init_latent, filename_tokens, cond))

         self.length = len(self.dataset) * repeats
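
Aside: the include_cond path above trades startup time and RAM for VRAM. Each caption goes through the text encoder once at dataset-build time instead of on every training step, and the resulting conditioning tensor is parked on the CPU. A minimal sketch of the same caching idea, with encode_text standing in for shared.sd_model.cond_stage_model (the helper names here are hypothetical):

    def precompute_conds(captions, encode_text, include_cond=True):
        # encode_text: callable mapping a list of strings to a tensor,
        # playing the role of the webui's cond_stage_model here.
        entries = []
        for caption in captions:
            # Encode once up front; keep the result on the CPU so the GPU
            # does not have to hold conditioning for the whole dataset.
            cond = encode_text([caption]).to("cpu") if include_cond else None
            entries.append((caption, cond))
        return entries
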
@@ -63,6 +72,12 @@ class PersonalizedBase(Dataset):
     def shuffle(self):
         self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]

+    def create_text(self, filename_tokens):
+        text = random.choice(self.lines)
+        text = text.replace("[name]", self.placeholder_token)
+        text = text.replace("[filewords]", ' '.join(filename_tokens))
+        return text
+
     def __len__(self):
         return self.length
@@ -71,10 +86,7 @@ class PersonalizedBase(Dataset):
             self.shuffle()

         index = self.indexes[i % len(self.indexes)]
-        x, filename_tokens = self.dataset[index]
-
-        text = random.choice(self.lines)
-        text = text.replace("[name]", self.placeholder_token)
-        text = text.replace("[filewords]", ' '.join(filename_tokens))
+        x, filename_tokens, cond = self.dataset[index]

-        return x, text
+        text = self.create_text(filename_tokens)
+        return x, text, cond
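
Net effect on __getitem__: each dataset entry now carries an optional cached conditioning, and the caption templating is factored into create_text so the constructor and __getitem__ share it. A rough standalone sketch of that substitution (the template lines below are illustrative, not from the repo):

    import random

    def create_text(lines, placeholder_token, filename_tokens):
        # Mirror of PersonalizedBase.create_text: pick a random template
        # line and fill in the [name] and [filewords] placeholders.
        text = random.choice(lines)
        text = text.replace("[name]", placeholder_token)
        text = text.replace("[filewords]", ' '.join(filename_tokens))
        return text

    lines = ["a photo of [name]", "a rendering of [name], [filewords]"]
    print(create_text(lines, "*", ["portrait", "studio"]))
    # e.g. "a rendering of *, portrait studio"
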
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index d7efdef2..1a672725 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -46,7 +46,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
     for index, imagefile in enumerate(tqdm.tqdm(files)):
         subindex = [0]
         filename = os.path.join(src, imagefile)
-        img = Image.open(filename).convert("RGB")
+        try:
+            img = Image.open(filename).convert("RGB")
+        except Exception:
+            continue

         if shared.state.interrupted:
             break
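
This is the same skip-on-failure pattern as in dataset.py: anything Pillow cannot decode (stray text files, truncated downloads) is dropped instead of aborting the whole preprocessing run. A small self-contained version of the loader (helper name hypothetical; a narrower OSError would cover Pillow's usual failures, but bare Exception matches the diff):

    from PIL import Image

    def load_rgb_or_none(path, size=None):
        try:
            img = Image.open(path).convert("RGB")
        except Exception:
            # Not a decodable image; caller should skip this file.
            return None
        if size is not None:
            img = img.resize(size, Image.BICUBIC)
        return img
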
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c64a4598..47a27faf 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
     return fn


-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
     assert embedding_name, 'embedding not selected'

     shared.state.textinfo = "Initializing textual inversion training..."
@@ -208,7 +208,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
     optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)

     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
-    for i, (x, text) in pbar:
+    for i, (x, text, _) in pbar:
         embedding.step = i + ititial_step

         if embedding.step > end_step:
@@ -236,10 +236,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
         loss.backward()
         optimizer.step()

-        epoch_num = embedding.step // epoch_len
-        epoch_step = embedding.step - (epoch_num * epoch_len) + 1
+        epoch_num = embedding.step // len(ds)
+        epoch_step = embedding.step - (epoch_num * len(ds)) + 1

-        pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
+        pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")

         if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
             last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
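
To see what switching the progress display from the stale epoch_len to len(ds) computes, here is the arithmetic worked through for a hypothetical 12-image dataset:

    ds_len = 12  # stands in for len(ds)

    for step in (0, 11, 12, 25):
        epoch_num = step // ds_len                    # completed epochs
        epoch_step = step - (epoch_num * ds_len) + 1  # 1-based position in epoch
        print(f"[Epoch {epoch_num}: {epoch_step}/{ds_len}]")
    # [Epoch 0: 1/12]
    # [Epoch 0: 12/12]
    # [Epoch 1: 1/12]
    # [Epoch 2: 2/12]
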
@@ -248,12 +248,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
         if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
             last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')

+            preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
-                prompt=text,
+                prompt=preview_text,
                 steps=20,
-                    height=training_height,
-                    width=training_width,
+                height=training_height,
+                width=training_width,
                 do_not_save_grid=True,
                 do_not_save_samples=True,
             )
@@ -264,7 +266,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
             shared.state.current_image = image
             image.save(last_saved_image)

-            last_saved_image += f", prompt: {text}"
+            last_saved_image += f", prompt: {preview_text}"

         shared.state.job_no = embedding.step
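
The preview prompt is a plain empty-string fallback: leave the UI field blank and previews keep using the most recent training caption; fill it in and every preview uses that fixed prompt. The same decision in isolation (function name hypothetical):

    def pick_preview_prompt(training_text, preview_image_prompt):
        # Empty UI field means "reuse the last training caption".
        return training_text if preview_image_prompt == "" else preview_image_prompt

    assert pick_preview_prompt("a photo of *", "") == "a photo of *"
    assert pick_preview_prompt("a photo of *", "a * at the beach") == "a * at the beach"
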
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index f19ac5e0..70f47343 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -23,6 +23,8 @@ def preprocess(*args):


 def train_embedding(*args):
+    assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible'
+
     try:
         sd_hijack.undo_optimizations()
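
The new assert is a fail-fast guard: under --lowvram/--medvram the webui shuffles model parts between CPU and GPU to save memory, which does not mix with training, so refusing up front beats failing somewhere mid-run. The same guard-clause shape in isolation (cmd_opts stands in for the parsed webui flags):

    def ensure_trainable(cmd_opts):
        # Refuse early, with a readable message, rather than erroring
        # deep inside the training loop.
        assert not cmd_opts.lowvram and not cmd_opts.medvram, \
            'Training models with lowvram or medvram is not possible'
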