Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--  modules/textual_inversion/dataset.py            | 58
-rw-r--r--  modules/textual_inversion/preprocess.py         | 14
-rw-r--r--  modules/textual_inversion/textual_inversion.py  | 24
3 files changed, 74 insertions, 22 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index d31963d4..af9fbcf2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
- def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
+ def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
self.filename = filename
self.filename_text = filename_text
+ self.weight = weight
self.latent_dist = latent_dist
self.latent_sample = latent_sample
self.cond = cond
@@ -30,7 +31,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
@@ -56,10 +57,16 @@ class PersonalizedBase(Dataset):
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
+ alpha_channel = None
if shared.state.interrupted:
raise Exception("interrupted")
try:
- image = Image.open(path).convert('RGB')
+ image = Image.open(path)
+ #Currently does not work for single color transparency
+ #We would need to read image.info['transparency'] for that
+ if use_weight and 'A' in image.getbands():
+ alpha_channel = image.getchannel('A')
+ image = image.convert('RGB')
if not varsize:
image = image.resize((width, height), PIL.Image.BICUBIC)
except Exception:
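
A minimal standalone sketch of the alpha-channel extraction this hunk adds: the 'A' band is pulled off before convert('RGB') would discard it, and palette-based single-color transparency (image.info['transparency']) is deliberately not handled, matching the comment above. The helper name is hypothetical, not repository code.

    from PIL import Image

    def load_rgb_and_alpha(path, use_weight=True):
        image = Image.open(path)
        alpha_channel = None
        if use_weight and 'A' in image.getbands():
            # grab the alpha band as a single-channel 'L' image (values 0..255)
            alpha_channel = image.getchannel('A')
        return image.convert('RGB'), alpha_channel
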
@@ -87,17 +94,35 @@ class PersonalizedBase(Dataset):
with devices.autocast():
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
- if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
- latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
- latent_sampling_method = "once"
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
- elif latent_sampling_method == "deterministic":
- # Works only for DiagonalGaussianDistribution
- latent_dist.std = 0
- latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
- elif latent_sampling_method == "random":
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
+ #Perform latent sampling, even for random sampling.
+ #We need the sample dimensions for the weights
+ if latent_sampling_method == "deterministic":
+ if isinstance(latent_dist, DiagonalGaussianDistribution):
+ # Works only for DiagonalGaussianDistribution
+ latent_dist.std = 0
+ else:
+ latent_sampling_method = "once"
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+
+ if use_weight and alpha_channel is not None:
+ channels, *latent_size = latent_sample.shape
+ weight_img = alpha_channel.resize(latent_size)
+ npweight = np.array(weight_img).astype(np.float32)
+ #Repeat for every channel in the latent sample
+ weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
+ #Normalize the weight to a minimum of 0 and a mean of 1 so that the loss stays comparable to the default.
+ weight -= weight.min()
+ weight /= weight.mean()
+ elif use_weight:
+ #If an image does not have an alpha channel, add an all-ones weight map anyway so we can stack it later
+ weight = torch.ones(latent_sample.shape)
+ else:
+ weight = None
+
+ if latent_sampling_method == "random":
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
+ else:
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
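
A hedged sketch of the weight-map construction introduced above, under the diff's assumptions: the alpha channel is resized to the latent resolution, repeated across the latent channels, then shifted to a minimum of 0 and scaled to a mean of 1 so the weighted loss stays on the same scale as the unweighted one. Names follow the diff, but this is an illustration, not the repository code (note that PIL's resize() expects a (width, height) tuple).

    import numpy as np
    import torch

    def build_weight_map(alpha_channel, latent_sample, use_weight=True):
        if not use_weight:
            return None
        if alpha_channel is None:
            # images without an alpha channel get a uniform map so batches can still be stacked
            return torch.ones(latent_sample.shape)
        channels, height, width = latent_sample.shape
        weight_img = alpha_channel.resize((width, height))        # PIL wants (w, h)
        npweight = np.array(weight_img, dtype=np.float32)
        # repeat the single-channel map across every latent channel
        weight = torch.from_numpy(npweight).expand(channels, height, width).clone()
        weight -= weight.min()     # minimum of 0
        weight /= weight.mean()    # mean of 1, so loss magnitude matches the default
        return weight
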
@@ -110,6 +135,7 @@ class PersonalizedBase(Dataset):
del torchdata
del latent_dist
del latent_sample
+ del weight
self.length = len(self.dataset)
self.groups = list(groups.values())
@@ -195,6 +221,10 @@ class BatchLoader:
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+ if all(entry.weight is not None for entry in data):
+ self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
+ else:
+ self.weight = None
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
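
The BatchLoader change only exposes a batch weight tensor when every entry in the batch carries one; mixed batches fall back to unweighted training. A small sketch of that rule (stack_weights is a hypothetical helper, not the repository code):

    import torch

    def stack_weights(entries):
        # only expose weights when every entry has one; otherwise train unweighted
        if all(e.weight is not None for e in entries):
            # squeeze(1) mirrors how latent_sample is stacked in BatchLoader
            return torch.stack([e.weight for e in entries]).squeeze(1)
        return None
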
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 2239cb84..4a29151d 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -11,7 +11,7 @@ from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
@@ -19,7 +19,7 @@ def preprocess(id_task, process_src, process_dst, process_width, process_height,
if process_caption_deepbooru:
deepbooru.model.start()
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
finally:
@@ -131,7 +131,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr
return wh and center_crop(image, *wh)
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -161,7 +161,9 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
params.subindex = 0
filename = os.path.join(src, imagefile)
try:
- img = Image.open(filename).convert("RGB")
+ img = Image.open(filename)
+ img = ImageOps.exif_transpose(img)
+ img = img.convert("RGB")
except Exception:
continue
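
The loader in preprocess.py now applies the EXIF orientation before converting to RGB, so photos from rotated phones and cameras are cropped the way they are displayed. A short sketch of the same pattern (ImageOps must be imported alongside Image; the helper name is hypothetical):

    from PIL import Image, ImageOps

    def open_upright_rgb(filename):
        img = Image.open(filename)
        img = ImageOps.exif_transpose(img)  # honour the EXIF Orientation tag
        return img.convert("RGB")
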
@@ -223,6 +225,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
process_default_resize = False
+ if process_keep_original_size:
+ save_pic(img, index, params, existing_caption=existing_caption)
+ process_default_resize = False
+
if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, params, existing_caption=existing_caption)
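
With the new process_keep_original_size flag, the image is written out at its original resolution (after the EXIF fix and RGB conversion) and the default resize step is skipped. A condensed, hedged sketch of that decision; the `save` callback stands in for the existing save_pic helper and plain PIL resize stands in for images.resize_image:

    from typing import Callable
    from PIL import Image

    def finalize(img: Image.Image, save: Callable[[Image.Image], None],
                 keep_original_size: bool, width: int, height: int) -> None:
        if keep_original_size:
            save(img)                        # untouched resolution
            return                           # default resize is skipped
        save(img.resize((width, height)))    # default behaviour: resize, then save
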
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index a1a406c2..379df243 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -152,7 +152,11 @@ class EmbeddingDatabase:
name = data.get('name', name)
else:
data = extract_image_data_embed(embed_image)
- name = data.get('name', name)
+ if data:
+ name = data.get('name', name)
+ else:
+ # if data is None, this is not an embedding, just a preview image
+ return
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:
@@ -229,6 +233,12 @@ class EmbeddingDatabase:
self.load_from_dir(embdir)
embdir.update()
+ # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
+ # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
+ sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
+ self.word_embeddings.clear()
+ self.word_embeddings.update(sorted_word_embeddings)
+
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings
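
The re-sort deliberately keeps the existing dict object alive: clear()/update() preserves any references other modules already hold to word_embeddings, whereas rebinding the attribute would not. A small illustration of the idiom:

    word_embeddings = {"Zeta": 1, "alpha": 2, "Beta": 3}
    alias = word_embeddings  # e.g. a reference held by another module
    ordered = dict(sorted(word_embeddings.items(), key=lambda kv: kv[0].lower()))
    word_embeddings.clear()
    word_embeddings.update(ordered)
    assert list(alias) == ["alpha", "Beta", "Zeta"]  # same object, now sorted case-insensitively
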
@@ -351,7 +361,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
template_file = textual_inversion_templates.get(template_filename, None)
@@ -410,7 +420,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
pin_memory = shared.opts.pin_memory
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
if shared.opts.save_training_settings_to_txt:
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
@@ -480,6 +490,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ if use_weight:
+ w = batch.weight.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
if is_training_inpainting_model:
@@ -490,7 +502,11 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
else:
cond = c
- loss = shared.sd_model(x, cond)[0] / gradient_step
+ if use_weight:
+ loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
+ del w
+ else:
+ loss = shared.sd_model.forward(x, cond)[0] / gradient_step
del x
_loss_step += loss.item()
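
weighted_forward itself is added outside this diff (the diffstat only covers the textual_inversion modules), so its exact implementation is not shown here. Conceptually, the weight map scales the per-element reconstruction error before it is reduced, which is why the dataset normalizes it to a mean of 1. A hedged, self-contained sketch of that idea, as an illustration rather than the repository implementation:

    import torch

    def weighted_mse(pred: torch.Tensor, target: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # w has the same shape as the latent batch and a mean of ~1, so this stays on the
        # same scale as the plain MSE used when use_weight is off.
        return (w * (pred - target) ** 2).mean()
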