path: root/modules/textual_inversion/textual_inversion.py
author     AUTOMATIC1111 <16777216c@gmail.com>  2023-12-16 09:58:07 +0300
committer  AUTOMATIC1111 <16777216c@gmail.com>  2023-12-16 09:58:07 +0300
commit     cf2772fab0af5573da775e7437e6acdca424f26e (patch)
tree       2ad13a0cf77bc189a8c9097bd507f9674f993da6 /modules/textual_inversion/textual_inversion.py
parent     4afaaf8a020c1df457bcf7250cb1c7f609699fa7 (diff)
parent     0dfffe53ec11b2ee097d55efc479f8e707015db9 (diff)
Merge branch 'release_candidate'
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--  modules/textual_inversion/textual_inversion.py  |  78
1 file changed, 42 insertions(+), 36 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index aa79dc09..04dda585 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -181,40 +181,7 @@ class EmbeddingDatabase:
         else:
             return
 
-
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, name)
-        embedding.step = data.get('step', None)
-        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
-        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vectors
-        embedding.shape = shape
-        embedding.filename = path
-        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
+        embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
 
         if self.expected_shape == -1 or self.expected_shape == embedding.shape:
             self.register_embedding(embedding, shared.sd_model)
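Note: the hunk above collapses thirty-odd lines of inline parsing in EmbeddingDatabase.load_from_file into a single call to the new create_embedding_from_data helper. For orientation, a minimal sketch of the three on-disk layouts that parser distinguishes; the key names follow the diff, but the tensor sizes and the '<my-concept>' term are illustrative stand-ins, not taken from any real file:

    import torch

    # A1111 textual inversion: 'string_to_param' maps a single term to a (vectors, width) tensor
    ti_data = {'string_to_param': {'*': torch.zeros(2, 768)}}

    # SDXL embedding: one tensor per text encoder; the reported shape is the
    # sum of the two widths (1280 + 768 here)
    sdxl_data = {'clip_g': torch.zeros(2, 1280), 'clip_l': torch.zeros(2, 768)}

    # diffusers concept: a single tensor keyed by the term; a 1-D tensor
    # gets unsqueezed to (1, width)
    diffuser_data = {'<my-concept>': torch.zeros(768)}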
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     return fn
 
 
+def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
+    if 'string_to_param' in data:  # textual inversion embeddings
+        param_dict = data['string_to_param']
+        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
+        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+        emb = next(iter(param_dict.items()))[1]
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+        vectors = data['clip_g'].shape[0]
+    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
+        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+        emb = next(iter(data.values()))
+        if len(emb.shape) == 1:
+            emb = emb.unsqueeze(0)
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    else:
+        raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+    embedding = Embedding(vec, name)
+    embedding.step = data.get('step', None)
+    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+    embedding.vectors = vectors
+    embedding.shape = shape
+
+    if filepath:
+        embedding.filename = filepath
+        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
+
+    return embedding
+
+
 def write_loss(log_directory, filename, step, epoch_len, values):
     if shared.opts.training_write_csv_every == 0:
         return
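For reference, a hedged usage sketch of the helper added above, mirroring what load_from_file now does for .PT files; it only runs inside the webui process where this module's imports resolve, and the path and embedding name here are hypothetical. torch.load with map_location='cpu' matches how the caller loads such files elsewhere in this module:

    import torch

    path = 'embeddings/my-style.pt'  # hypothetical embedding file
    data = torch.load(path, map_location='cpu')

    # filename only feeds the error message; passing filepath additionally
    # sets embedding.filename and computes the sha256-based hash
    embedding = create_embedding_from_data(data, 'my-style', filename='my-style.pt', filepath=path)
    print(embedding.vectors, embedding.shape)  # number of vectors, per-vector width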
@@ -386,7 +392,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
     assert log_directory, "Log directory is empty"
 
 
-def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height):
     from modules import processing
 
     save_embedding_every = save_embedding_every or 0
@@ -590,7 +596,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                         p.prompt = preview_prompt
                         p.negative_prompt = preview_negative_prompt
                         p.steps = preview_steps
-                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+                        p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
                         p.cfg_scale = preview_cfg_scale
                         p.seed = preview_seed
                         p.width = preview_width
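The final hunk swaps a positional lookup into sd_samplers.samplers for a lookup by name, which is why the train_embedding signature changed from preview_sampler_index to preview_sampler_name in the hunk before it. A rough before/after fragment, assuming samplers_map maps lower-cased sampler names (and aliases) to canonical names, as the diff implies:

    # before: tied to list order, so it breaks if samplers are added or reordered
    p.sampler_name = sd_samplers.samplers[preview_sampler_index].name

    # after: resolved by name, made case-insensitive with .lower()
    p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]

Identifying the sampler by name keeps stored training settings valid across webui versions, since names are stable while list positions are not.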