author    AUTOMATIC1111 <16777216c@gmail.com>    2023-07-08 16:50:23 +0300
committer GitHub <noreply@github.com>    2023-07-08 16:50:23 +0300
commit    993dd9a8927407de8d19142cacb07e6f76686a67 (patch)
tree      89df31c33ecf054c9cb70aaf38fc3382f306f450 /modules/textual_inversion/textual_inversion.py
parent    ff6acd35d0807a4e0c3ee86cdb1520a4a3a11cdd (diff)
parent    d7d6e8cfc8b85a99a48f82975ee213d487783c28 (diff)
Merge branch 'dev' into patch-1
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--    modules/textual_inversion/textual_inversion.py    69
1 file changed, 38 insertions(+), 31 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4368eb63..bb6f211c 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,7 +1,4 @@
import os
-import sys
-import traceback
-import inspect
from collections import namedtuple
import torch
@@ -15,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -30,7 +27,7 @@ textual_inversion_templates = {}
def list_textual_inversion_templates():
textual_inversion_templates.clear()
- for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
for fn in fns:
path = os.path.join(root, fn)
@@ -121,16 +118,29 @@ class EmbeddingDatabase:
self.embedding_dirs.clear()
def register_embedding(self, embedding, model):
- self.word_embeddings[embedding.name] = embedding
-
- ids = model.cond_stage_model.tokenize([embedding.name])[0]
+ return self.register_embedding_by_name(embedding, model, embedding.name)
+ def register_embedding_by_name(self, embedding, model, name):
+ ids = model.cond_stage_model.tokenize([name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
-
- self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
-
+ if name in self.word_embeddings:
+ # remove old one from the lookup list
+ lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
+ else:
+ lookup = self.ids_lookup[first_id]
+ if embedding is not None:
+ lookup += [(ids, embedding)]
+ self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
+ if embedding is None:
+ # unregister embedding with specified name
+ if name in self.word_embeddings:
+ del self.word_embeddings[name]
+ if len(self.ids_lookup[first_id])==0:
+ del self.ids_lookup[first_id]
+ return None
+ self.word_embeddings[name] = embedding
return embedding
def get_expected_shape(self):
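Note on the hunk above: register_embedding now delegates to a new register_embedding_by_name, which keeps the old behaviour, additionally lets an embedding be registered under an arbitrary name, and removes an existing entry (from both word_embeddings and ids_lookup) when embedding is None. A minimal usage sketch with illustrative names only (db is an EmbeddingDatabase, sd_model a loaded model, emb an already-loaded embedding):

    db.register_embedding(emb, sd_model)                        # unchanged public behaviour
    db.register_embedding_by_name(emb, sd_model, "my-alias")    # register under a different name
    db.register_embedding_by_name(None, sd_model, "my-alias")   # passing None unregisters that name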
@@ -167,8 +177,7 @@ class EmbeddingDatabase:
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
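The change above collapses the two-line hasattr/getattr check into a single getattr call with the original dict as the default, keeping the torch 1.12.1 compatibility fix while dropping a line. Roughly equivalent forms, shown only to illustrate the idiom:

    # old: only replace param_dict when torch wrapped it in an object exposing _parameters
    if hasattr(param_dict, '_parameters'):
        param_dict = getattr(param_dict, '_parameters')
    # new: getattr with a default falls back to the plain dict written by older torch versions
    param_dict = getattr(param_dict, '_parameters', param_dict)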
@@ -199,7 +208,7 @@ class EmbeddingDatabase:
if not os.path.isdir(embdir.path):
return
- for root, dirs, fns in os.walk(embdir.path, followlinks=True):
+ for root, _, fns in os.walk(embdir.path, followlinks=True):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
@@ -209,14 +218,13 @@ class EmbeddingDatabase:
self.load_from_file(fullfn, fn)
except Exception:
- print(f"Error loading embedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading embedding {fn}", exc_info=True)
continue
def load_textual_inversion_embeddings(self, force_reload=False):
if not force_reload:
need_reload = False
- for path, embdir in self.embedding_dirs.items():
+ for embdir in self.embedding_dirs.values():
if embdir.has_changed():
need_reload = True
break
@@ -229,7 +237,7 @@ class EmbeddingDatabase:
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
- for path, embdir in self.embedding_dirs.items():
+ for embdir in self.embedding_dirs.values():
self.load_from_dir(embdir)
embdir.update()
@@ -243,7 +251,7 @@ class EmbeddingDatabase:
if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
- if len(self.skipped_embeddings) > 0:
+ if self.skipped_embeddings:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
@@ -325,16 +333,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
- tensorboard_writer.add_scalar(tag=tag,
+ tensorboard_writer.add_scalar(tag=tag,
scalar_value=value, global_step=step)
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
# Convert a pil image to a torch tensor
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
- img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
+ img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
len(pil_image.getbands()))
img_tensor = img_tensor.permute((2, 0, 1))
-
+
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
@@ -404,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
-
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
@@ -414,7 +422,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
-
+
if shared.opts.training_enable_tensorboard:
tensorboard_writer = tensorboard_setup(log_directory)
@@ -441,7 +449,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
-
+
if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint")
@@ -470,7 +478,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
try:
sd_hijack_checkpoint.add()
- for i in range((steps-initial_step) * gradient_step):
+ for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
@@ -487,7 +495,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if clip_grad:
clip_grad_sched.step(embedding.step)
-
+
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight:
@@ -515,7 +523,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
-
+
if clip_grad:
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
@@ -603,7 +611,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
+ except Exception:
vectorSize = '?'
checkpoint = sd_models.select_checkpoint()
@@ -634,8 +642,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
- pass
+ errors.report("Error training embedding", exc_info=True)
finally:
pbar.leave = False
pbar.close()
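Both exception handlers in this file now go through the shared errors module instead of printing tracebacks by hand, which is why the sys and traceback imports are dropped in the first hunk. The real helper lives in modules/errors.py and is not part of this diff; a rough sketch of what a report(message, exc_info=True) call of this shape implies, under that assumption:

    import sys
    import traceback

    def report(message, *, exc_info=False):
        # print the message and, when requested, the active exception's traceback to stderr
        print(message, file=sys.stderr)
        if exc_info:
            print(traceback.format_exc(), file=sys.stderr)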