author     AUTOMATIC <16777216c@gmail.com>    2022-09-03 12:08:45 +0300
committer  AUTOMATIC <16777216c@gmail.com>    2022-09-03 12:08:45 +0300
commit     345028099d893f8a66726cfd13627d8cc1bcc724 (patch)
tree       acb1da553620b7e7139db840ef43accf71b786a8 /modules/sd_hijack.py
parent     d7b67d9b40e47ede766d3beb149b0c2b74651ece (diff)
split codebase into multiple files; to anyone this affects negatively: sorry
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--  modules/sd_hijack.py  208
1 file changed, 208 insertions, 0 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
new file mode 100644
index 00000000..6ee92e77
--- /dev/null
+++ b/modules/sd_hijack.py
@@ -0,0 +1,208 @@
+import os
+import sys
+import traceback
+import torch
+import numpy as np
+
+from modules.shared import opts, device
+
+
+class StableDiffusionModelHijack:
+    ids_lookup = {}
+    word_embeddings = {}
+    word_embeddings_checksums = {}
+    fixes = None
+    comments = []
+    dir_mtime = None
+
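+    # scan dirname for embedding files and register each one under its filename (without extension)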
+    def load_textual_inversion_embeddings(self, dirname, model):
+        mt = os.path.getmtime(dirname)
+        if self.dir_mtime is not None and mt <= self.dir_mtime:
+            return
+
+        self.dir_mtime = mt
+        self.ids_lookup.clear()
+        self.word_embeddings.clear()
+
+        tokenizer = model.cond_stage_model.tokenizer
+
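+        # deterministic hash of the embedding weights; its low 16 bits are kept as a short checksum shown alongside the embedding name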
+        def const_hash(a):
+            r = 0
+            for v in a:
+                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+            return r
+
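+        # an embedding file stores a 'string_to_param' dict with a single entry: placeholder name -> learned embedding tensor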
+        def process_file(path, filename):
+            name = os.path.splitext(filename)[0]
+
+            data = torch.load(path)
+            param_dict = data['string_to_param']
+            if hasattr(param_dict, '_parameters'):
+                param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
+            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+            emb = next(iter(param_dict.items()))[1]
+            self.word_embeddings[name] = emb.detach()
+            self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'
+
+            ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
+
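+            # index the embedding by the first token id of its tokenized name so prompt scanning can look it up quickly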
+            first_id = ids[0]
+            if first_id not in self.ids_lookup:
+                self.ids_lookup[first_id] = []
+            self.ids_lookup[first_id].append((ids, name))
+
+        for fn in os.listdir(dirname):
+            try:
+                process_file(os.path.join(dirname, fn), fn)
+            except Exception:
+                print(f"Error loading embedding {fn}:", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+                continue
+
+        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+
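+    # replace the model's CLIP text encoder and its token embedding layer with wrappers that understand custom embeddings and emphasis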
+    def hijack(self, m):
+        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+
+        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
+
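+# wraps FrozenCLIPEmbedder so prompts can reference textual inversion embeddings by name and use ()/[] attention emphasis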
+class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
+    def __init__(self, wrapped, hijack):
+        super().__init__()
+        self.wrapped = wrapped
+        self.hijack = hijack
+        self.tokenizer = wrapped.tokenizer
+        self.max_length = wrapped.max_length
+        self.token_mults = {}
+
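+        # precompute emphasis multipliers for vocab tokens containing parentheses/brackets: '(' and ']' scale by 1.1, ')' and '[' by 1/1.1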
+        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+        for text, ident in tokens_with_parens:
+            mult = 1.0
+            for c in text:
+                if c == '[':
+                    mult /= 1.1
+                if c == ']':
+                    mult *= 1.1
+                if c == '(':
+                    mult *= 1.1
+                if c == ')':
+                    mult /= 1.1
+
+            if mult != 1.0:
+                self.token_mults[ident] = mult
+
+    def forward(self, text):
+        self.hijack.fixes = []
+        self.hijack.comments = []
+        remade_batch_tokens = []
+        id_start = self.wrapped.tokenizer.bos_token_id
+        id_end = self.wrapped.tokenizer.eos_token_id
+        maxlen = self.wrapped.max_length  # total sequence length the text encoder expects, including BOS/EOS
+        used_custom_terms = []
+
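+        # identical prompts within a batch are remade only once and reused from this cache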
+        cache = {}
+        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
+        batch_multipliers = []
+        for tokens in batch_tokens:
+            tuple_tokens = tuple(tokens)
+
+            if tuple_tokens in cache:
+                remade_tokens, fixes, multipliers = cache[tuple_tokens]
+            else:
+                fixes = []
+                remade_tokens = []
+                multipliers = []
+                mult = 1.0
+
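+                # scan the prompt tokens: emphasis tokens adjust the running multiplier, embedding names are replaced by zero placeholders and recorded as (position, name) fixes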
+                i = 0
+                while i < len(tokens):
+                    token = tokens[i]
+
+                    possible_matches = self.hijack.ids_lookup.get(token, None)
+
+                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
+                    if mult_change is not None:
+                        mult *= mult_change
+                    elif possible_matches is None:
+                        remade_tokens.append(token)
+                        multipliers.append(mult)
+                    else:
+                        found = False
+                        for ids, word in possible_matches:
+                            if tokens[i:i+len(ids)] == ids:
+                                emb_len = int(self.hijack.word_embeddings[word].shape[0])
+                                fixes.append((len(remade_tokens), word))
+                                remade_tokens += [0] * emb_len
+                                multipliers += [mult] * emb_len
+                                i += len(ids) - 1
+                                found = True
+                                used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
+                                break
+
+                        if not found:
+                            remade_tokens.append(token)
+                            multipliers.append(mult)
+
+                    i += 1
+
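+                # the text encoder context is fixed at maxlen tokens (BOS/EOS included), so only maxlen - 2 prompt tokens fit; warn about and drop the rest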
+                if len(remade_tokens) > maxlen - 2:
+                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+                    ovf = remade_tokens[maxlen - 2:]
+                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
+                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+
+                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+                remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
+                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+            remade_batch_tokens.append(remade_tokens)
+            self.hijack.fixes.append(fixes)
+            batch_multipliers.append(multipliers)
+
+        if len(used_custom_terms) > 0:
+            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+        tokens = torch.asarray(remade_batch_tokens).to(device)
+        outputs = self.wrapped.transformer(input_ids=tokens)
+        z = outputs.last_hidden_state
+
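+        # weight each token's hidden state by its emphasis multiplier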
+        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+        batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
+        original_mean = z.mean()
+        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+        new_mean = z.mean()
+        z *= original_mean / new_mean
+
+        return z
+
+
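+# wraps the CLIP token embedding layer and swaps learned textual inversion vectors into the placeholder positions recorded during tokenization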
+class EmbeddingsWithFixes(torch.nn.Module):
+    def __init__(self, wrapped, embeddings):
+        super().__init__()
+        self.wrapped = wrapped
+        self.embeddings = embeddings
+
+    def forward(self, input_ids):
+        batch_fixes = self.embeddings.fixes
+        self.embeddings.fixes = None
+
+        inputs_embeds = self.wrapped(input_ids)
+
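+        # overwrite placeholder rows with the learned embedding vectors, clipping to whatever fits in the sequence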
+        if batch_fixes is not None:
+            for fixes, tensor in zip(batch_fixes, inputs_embeds):
+                for offset, word in fixes:
+                    emb = self.embeddings.word_embeddings[word]
+                    emb_len = min(tensor.shape[0] - offset, emb.shape[0])
+                    tensor[offset:offset + emb_len] = emb[0:emb_len]
+
+        return inputs_embeds
+
+
+model_hijack = StableDiffusionModelHijack()
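
For context, a minimal usage sketch of this module, assuming an ldm LatentDiffusion model whose cond_stage_model is a FrozenCLIPEmbedder and an 'embeddings' directory of .pt files; sd_model, the directory name, and the prompt below are illustrative assumptions rather than anything defined in this file:

    from modules.sd_hijack import model_hijack

    # sd_model: an already-loaded LatentDiffusion instance (assumed to exist)
    model_hijack.hijack(sd_model)
    model_hijack.load_textual_inversion_embeddings("embeddings", sd_model)

    # prompts can now reference loaded embeddings by name and use ()/[] emphasis
    cond = sd_model.get_learned_conditioning(["a photo of my-embedding, (detailed)"])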