import os import sys import traceback import inspect from collections import namedtuple import torch import tqdm import html import datetime import csv import safetensors.torch from PIL import Image, PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay from modules.textual_inversion.logging import save_settings_to_file TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"]) textual_inversion_templates = {} def list_textual_inversion_templates(): textual_inversion_templates.clear() for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir): for fn in fns: path = os.path.join(root, fn) textual_inversion_templates[fn] = TextualInversionTemplate(fn, path) return textual_inversion_templates class Embedding: def __init__(self, vec, name, step=None): self.vec = vec self.name = name self.step = step self.shape = None self.vectors = 0 self.cached_checksum = None self.sd_checkpoint = None self.sd_checkpoint_name = None self.optimizer_state_dict = None def save(self, filename): embedding_data = { "string_to_token": {"*": 265}, "string_to_param": {"*": self.vec}, "name": self.name, "step": self.step, "sd_checkpoint": self.sd_checkpoint, "sd_checkpoint_name": self.sd_checkpoint_name, } torch.save(embedding_data, filename) if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None: optimizer_saved_dict = { 'hash': self.checksum(), 'optimizer_state_dict': self.optimizer_state_dict, } torch.save(optimizer_saved_dict, filename + '.optim') def checksum(self): if self.cached_checksum is not None: return self.cached_checksum def const_hash(a): r = 0 for v in a: r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF return r self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' return self.cached_checksum class DirWithTextualInversionEmbeddings: def __init__(self, path): self.path = path self.mtime = None def has_changed(self): if not os.path.isdir(self.path): return False mt = os.path.getmtime(self.path) if self.mtime is None or mt > self.mtime: return True def update(self): if not os.path.isdir(self.path): return self.mtime = os.path.getmtime(self.path) class EmbeddingDatabase: def __init__(self): self.ids_lookup = {} self.word_embeddings = {} self.skipped_embeddings = {} self.expected_shape = -1 self.embedding_dirs = {} def add_embedding_dir(self, path): self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) def clear_embedding_dirs(self): self.embedding_dirs.clear() def register_embedding(self, embedding, model): self.word_embeddings[embedding.name] = embedding ids = model.cond_stage_model.tokenize([embedding.name])[0] first_id = ids[0] if first_id not in self.ids_lookup: self.ids_lookup[first_id] = [] self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True) return embedding def get_expected_shape(self): vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) return vec.shape[1] def load_from_file(self, path, filename): name, ext = os.path.splitext(filename) ext = ext.upper() if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: _, second_ext = os.path.splitext(name) if second_ext.upper() == '.PREVIEW': return embed_image = Image.open(path) if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: data = embedding_from_b64(embed_image.text['sd-ti-embedding']) name = data.get('name', name) else: data = extract_image_data_embed(embed_image) name = data.get('name', name) elif ext in ['.BIN', '.PT']: data = torch.load(path, map_location="cpu") elif ext in ['.SAFETENSORS']: data = safetensors.torch.load_file(path, device="cpu") else: return # textual inversion embeddings if 'string_to_param' in data: param_dict = data['string_to_param'] if hasattr(param_dict, '_parameters'): param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 assert len(param_dict) == 1, 'embedding file has multiple terms in it' emb = next(iter(param_dict.items()))[1] # diffuser concepts elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: assert len(data.keys()) == 1, 'embedding file has multiple terms in it' emb = next(iter(data.values())) if len(emb.shape) == 1: emb = emb.unsqueeze(0) else: raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") vec = emb.detach().to(devices.device, dtype=torch.float32) embedding = Embedding(vec, name) embedding.step = data.get('step', None) embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) else: self.skipped_embeddings[name] = embedding def load_from_dir(self, embdir): if not os.path.isdir(embdir.path): return for root, dirs, fns in os.walk(embdir.path): for fn in fns: try: fullfn = os.path.join(root, fn) if os.stat(fullfn).st_size == 0: continue self.load_from_file(fullfn, fn) except Exception: print(f"Error loading embedding {fn}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) continue def load_textual_inversion_embeddings(self, force_reload=False): if not force_reload: need_reload = False for path, embdir in self.embedding_dirs.items(): if embdir.has_changed(): need_reload = True break if not need_reload: return self.ids_lookup.clear() self.word_embeddings.clear() self.skipped_embeddings.clear() self.expected_shape = self.get_expected_shape() for path, embdir in self.embedding_dirs.items(): self.load_from_dir(embdir) embdir.update() print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if len(self.skipped_embeddings) > 0: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") def find_embedding_at_position(self, tokens, offset): token = tokens[offset] possible_matches = self.ids_lookup.get(token, None) if possible_matches is None: return None, None for ids, embedding in possible_matches: if tokens[offset:offset + len(ids)] == ids: return embedding, len(ids) return None, None def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): cond_model = shared.sd_model.cond_stage_model with devices.autocast(): cond_model([""]) # will send cond model to GPU if lowvram/medvram is active embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) for i in range(num_vectors_per_token): vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") if not overwrite_old: assert not os.path.exists(fn), f"file {fn} already exists" embedding = Embedding(vec, name) embedding.step = 0 embedding.save(fn) return fn def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return if step % shared.opts.training_write_csv_every != 0: return write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True with open(os.path.join(log_directory, filename), "a+", newline='') as fout: csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) if write_csv_header: csv_writer.writeheader() epoch = (step - 1) // epoch_len epoch_step = (step - 1) % epoch_len csv_writer.writerow({ "step": step, "epoch": epoch, "epoch_step": epoch_step, **values, }) def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" assert isinstance(batch_size, int), "Batch size must be integer" assert batch_size > 0, "Batch size must be positive" assert isinstance(gradient_step, int), "Gradient accumulation step must be integer" assert gradient_step > 0, "Gradient accumulation step must be positive" assert data_root, "Dataset directory is empty" assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" assert template_filename, "Prompt template file not selected" assert template_file, f"Prompt template file {template_filename} not found" assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist" assert steps, "Max steps is empty or 0" assert isinstance(steps, int), "Max steps must be integer" assert steps > 0, "Max steps must be positive" assert isinstance(save_model_every, int), "Save {name} must be integer" assert save_model_every >= 0, "Save {name} must be positive or 0" assert isinstance(create_image_every, int), "Create image must be integer" assert create_image_every >= 0, "Create image must be positive or 0" if save_model_every or create_image_every: assert log_directory, "Log directory is empty" def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 template_file = textual_inversion_templates.get(template_filename, None) validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding") template_file = template_file.path shared.state.job = "train-embedding" shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) unload = shared.opts.unload_models_when_training if save_embedding_every > 0: embedding_dir = os.path.join(log_directory, "embeddings") os.makedirs(embedding_dir, exist_ok=True) else: embedding_dir = None if create_image_every > 0: images_dir = os.path.join(log_directory, "images") os.makedirs(images_dir, exist_ok=True) else: images_dir = None if create_image_every > 0 and save_image_with_stored_embedding: images_embeds_dir = os.path.join(log_directory, "image_embeddings") os.makedirs(images_embeds_dir, exist_ok=True) else: images_embeds_dir = None hijack = sd_hijack.model_hijack embedding = hijack.embedding_db.word_embeddings[embedding_name] checkpoint = sd_models.select_checkpoint() initial_step = embedding.step or 0 if initial_step >= steps: shared.state.textinfo = "Model has already been trained beyond specified max steps" return embedding, filename scheduler = LearnRateScheduler(learn_rate, steps, initial_step) clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ None if clip_grad: clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) if unload: shared.parallel_processing_allowed = False shared.sd_model.first_stage_model.to(devices.cpu) embedding.vec.requires_grad = True optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) if shared.opts.save_optimizer_state: optimizer_state_dict = None if os.path.exists(filename + '.optim'): optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu') if embedding.checksum() == optimizer_saved_dict.get('hash', None): optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) if optimizer_state_dict is not None: optimizer.load_state_dict(optimizer_state_dict) print("Loaded existing optimizer from checkpoint") else: print("No saved optimizer exists in checkpoint") scaler = torch.cuda.amp.GradScaler() batch_size = ds.batch_size gradient_step = ds.gradient_step # n steps = batch_size * gradient_step * n image processed steps_per_epoch = len(ds) // batch_size // gradient_step max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step loss_step = 0 _loss_step = 0 #internal last_saved_file = "" last_saved_image = "" forced_filename = "" embedding_yet_to_be_embedded = False is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'} img_c = None pbar = tqdm.tqdm(total=steps - initial_step) try: for i in range((steps-initial_step) * gradient_step): if scheduler.finished: break if shared.state.interrupted: break for j, batch in enumerate(dl): # works as a drop_last=True for gradient accumulation if j == max_steps_per_epoch: break scheduler.apply(optimizer, embedding.step) if scheduler.finished: break if shared.state.interrupted: break if clip_grad: clip_grad_sched.step(embedding.step) with devices.autocast(): x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) c = shared.sd_model.cond_stage_model(batch.cond_text) if is_training_inpainting_model: if img_c is None: img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height) cond = {"c_concat": [img_c], "c_crossattn": [c]} else: cond = c loss = shared.sd_model(x, cond)[0] / gradient_step del x _loss_step += loss.item() scaler.scale(loss).backward() # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue if clip_grad: clip_grad(embedding.vec, clip_grad_sched.learn_rate) scaler.step(optimizer) scaler.update() embedding.step += 1 pbar.update() optimizer.zero_grad(set_to_none=True) loss_step = _loss_step _loss_step = 0 steps_done = embedding.step + 1 epoch_num = embedding.step // steps_per_epoch epoch_step = embedding.step % steps_per_epoch pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. embedding_name_every = f'{embedding_name}-{steps_done}' last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) embedding_yet_to_be_embedded = True write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { "loss": f"{loss_step:.7f}", "learn_rate": scheduler.learn_rate }) if images_dir is not None and steps_done % create_image_every == 0: forced_filename = f'{embedding_name}-{steps_done}' last_saved_image = os.path.join(images_dir, forced_filename) shared.sd_model.first_stage_model.to(devices.device) p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, do_not_save_grid=True, do_not_save_samples=True, do_not_reload_embeddings=True, ) if preview_from_txt2img: p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps p.sampler_name = sd_samplers.samplers[preview_sampler_index].name p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width p.height = preview_height else: p.prompt = batch.cond_text[0] p.steps = 20 p.width = training_width p.height = training_height preview_text = p.prompt processed = processing.process_images(p) image = processed.images[0] if len(processed.images) > 0 else None if unload: shared.sd_model.first_stage_model.to(devices.cpu) if image is not None: shared.state.current_image = image last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) info.add_text("sd-ti-embedding", embedding_to_b64(data)) title = "<{}>".format(data.get('name', '???')) try: vectorSize = list(data['string_to_param'].values())[0].shape[0] except Exception as e: vectorSize = '?' checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name footer_mid = '[{}]'.format(checkpoint.hash) footer_right = '{}v {}s'.format(vectorSize, steps_done) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) captioned_image = insert_image_data_embed(captioned_image, data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) embedding_yet_to_be_embedded = False last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = embedding.step shared.state.textinfo = f"""

Loss: {loss_step:.7f}
Step: {steps_done}
Last prompt: {html.escape(batch.cond_text[0])}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: print(traceback.format_exc(), file=sys.stderr) pass finally: pbar.leave = False pbar.close() shared.sd_model.first_stage_model.to(devices.device) shared.parallel_processing_allowed = old_parallel_processing_allowed return embedding, filename def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): old_embedding_name = embedding.name old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None try: embedding.sd_checkpoint = checkpoint.hash embedding.sd_checkpoint_name = checkpoint.model_name if remove_cached_checksum: embedding.cached_checksum = None embedding.name = embedding_name embedding.optimizer_state_dict = optimizer.state_dict() embedding.save(filename) except: embedding.sd_checkpoint = old_sd_checkpoint embedding.sd_checkpoint_name = old_sd_checkpoint_name embedding.name = old_embedding_name embedding.cached_checksum = old_cached_checksum raise