aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorNerogar <nerogar@arcor.de>2022-10-23 14:05:25 +0200
committerNerogar <nerogar@arcor.de>2022-11-01 21:02:07 +0100
commitcffc240a7327ae60671ff533469fc4ed4bf605de (patch)
tree4441193674eb53bb3a78ac6d9855148e84cb8624 /modules
parent198a1ffcfc963a3d74674fad560e87dbebf7949f (diff)
fixed textual inversion training with inpainting models
Diffstat (limited to 'modules')
-rw-r--r--modules/textual_inversion/textual_inversion.py27
1 files changed, 26 insertions, 1 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 0aeb0459..2630c7c9 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -224,6 +224,26 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
+def create_dummy_mask(x, width=None, height=None):
+ if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ else:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ return image_conditioning
+
+
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
@@ -286,6 +306,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
+ img_c = None
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
@@ -299,8 +320,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries])
+ if img_c is None:
+ img_c = create_dummy_mask(c, training_width, training_height)
+
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ loss = shared.sd_model(x, cond)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()