From 47033afa5c08e72b622348b0bcfd71fd1a66e2cb Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Tue, 5 Sep 2023 22:38:02 +0900 Subject: Fix preview for textual inversion training --- modules/textual_inversion/textual_inversion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index aa79dc09..401a0a2a 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -386,7 +386,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert log_directory, "Log directory is empty" -def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height): from modules import processing save_embedding_every = save_embedding_every or 0 @@ -590,7 +590,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()] p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width -- cgit v1.2.1 From a8cbe50c9fa324ed887089e4333452ecc4355c92 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 14 Oct 2023 12:14:56 +0300 Subject: remove duplicated code --- modules/textual_inversion/textual_inversion.py | 74 ++++++++++++++------------ 1 file changed, 40 insertions(+), 34 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 401a0a2a..04dda585 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -181,40 +181,7 @@ class EmbeddingDatabase: else: return - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding - vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} - shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] - vectors = data['clip_g'].shape[0] - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - else: - raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") - - embedding = Embedding(vec, name) - embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('sd_checkpoint', None) - embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - embedding.vectors = vectors - embedding.shape = shape - embedding.filename = path - embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') + embedding = create_embedding_from_data(data, name, filename=filename, filepath=path) if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) @@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): return fn +def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None): + if 'string_to_param' in data: # textual inversion embeddings + param_dict = data['string_to_param'] + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding + vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + vectors = data['clip_g'].shape[0] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vectors + embedding.shape = shape + + if filepath: + embedding.filename = filepath + embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '') + + return embedding + + def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return -- cgit v1.2.1