From 5034f7d7597685aaa4779296983be0f49f4f991f Mon Sep 17 00:00:00 2001 From: Liam Date: Tue, 27 Sep 2022 15:56:18 -0400 Subject: added token counter next to txt2img and img2img prompts --- modules/sd_hijack.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 7b2030d4..4d799ac0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -180,6 +180,7 @@ class StableDiffusionModelHijack: dir_mtime = None layers = None circular_enabled = False + clip = None def load_textual_inversion_embeddings(self, dirname, model): mt = os.path.getmtime(dirname) @@ -242,6 +243,7 @@ class StableDiffusionModelHijack: model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + self.clip = m.cond_stage_model if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 @@ -268,6 +270,11 @@ class StableDiffusionModelHijack: for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]: layer.padding_mode = 'circular' if enable else 'zeros' + def tokenize(self, text): + max_length = self.clip.max_length - 2 + _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) + return {"tokens": remade_batch_tokens[0], "token_count":token_count, "max_length":max_length} + class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): @@ -294,14 +301,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def forward(self, text): - self.hijack.fixes = [] - self.hijack.comments = [] - remade_batch_tokens = [] + def process_text(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id maxlen = self.wrapped.max_length used_custom_terms = [] + remade_batch_tokens = [] + overflowing_words = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 cache = {} batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] @@ -353,9 +362,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ovf = remade_tokens[maxlen - 2:] overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - - self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] cache[tuple_tokens] = (remade_tokens, fixes, multipliers) @@ -364,8 +372,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] remade_batch_tokens.append(remade_tokens) - self.hijack.fixes.append(fixes) + hijack_fixes.append(fixes) batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + def forward(self, text): + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + self.hijack.fixes = hijack_fixes + self.hijack.comments = hijack_comments if len(used_custom_terms) > 0: self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) -- cgit v1.2.1 From e5707b66d6db2c019bfccf66f9ba53e3daaea40b Mon Sep 17 00:00:00 2001 From: Liam Date: Tue, 27 Sep 2022 19:29:53 -0400 Subject: switched the token counter to use hidden buttons instead of api call --- modules/sd_hijack.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 4d799ac0..bfbd07f9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -273,8 +273,7 @@ class StableDiffusionModelHijack: def tokenize(self, text): max_length = self.clip.max_length - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) - return {"tokens": remade_batch_tokens[0], "token_count":token_count, "max_length":max_length} - + return remade_batch_tokens[0], token_count, max_length class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): -- cgit v1.2.1 From c1c27dad3ba371a5ae344b267c760aa51e77f193 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 29 Sep 2022 11:31:48 +0300 Subject: new implementation for attention/emphasis --- modules/sd_hijack.py | 102 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 98 insertions(+), 4 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bfbd07f9..2848a251 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -6,6 +6,7 @@ import torch import numpy as np from torch import einsum +from modules import prompt_parser from modules.shared import opts, device, cmd_opts from ldm.util import default @@ -211,6 +212,7 @@ class StableDiffusionModelHijack: param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 assert len(param_dict) == 1, 'embedding file has multiple terms in it' emb = next(iter(param_dict.items()))[1] + # diffuser concepts elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: assert len(data.keys()) == 1, 'embedding file has multiple terms in it' @@ -236,7 +238,7 @@ class StableDiffusionModelHijack: print(traceback.format_exc(), file=sys.stderr) continue - print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.") + print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") def hijack(self, m): model_embeddings = m.cond_stage_model.transformer.text_model.embeddings @@ -275,6 +277,7 @@ class StableDiffusionModelHijack: _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, max_length + class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() @@ -300,7 +303,92 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def process_text(self, text): + + def tokenize_line(self, line, used_custom_terms, hijack_comments): + id_start = self.wrapped.tokenizer.bos_token_id + id_end = self.wrapped.tokenizer.eos_token_id + maxlen = self.wrapped.max_length + + if opts.enable_emphasis: + parsed = prompt_parser.parse_prompt_attention(line) + else: + parsed = [[line, 1.0]] + + tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"] + + fixes = [] + remade_tokens = [] + multipliers = [] + + for tokens, (text, weight) in zip(tokenized, parsed): + i = 0 + while i < len(tokens): + token = tokens[i] + + possible_matches = self.hijack.ids_lookup.get(token, None) + + if possible_matches is None: + remade_tokens.append(token) + multipliers.append(weight) + else: + found = False + for ids, word in possible_matches: + if tokens[i:i + len(ids)] == ids: + emb_len = int(self.hijack.word_embeddings[word].shape[0]) + fixes.append((len(remade_tokens), word)) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + i += len(ids) - 1 + found = True + used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) + break + + if not found: + remade_tokens.append(token) + multipliers.append(weight) + i += 1 + + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + + token_count = len(remade_tokens) + remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + + return remade_tokens, fixes, multipliers, token_count + + def process_text(self, texts): + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_multipliers = [] + for line in texts: + if line in cache: + remade_tokens, fixes, multipliers = cache[line] + else: + remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + + cache[line] = (remade_tokens, fixes, multipliers) + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + + def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id maxlen = self.wrapped.max_length @@ -376,12 +464,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count def forward(self, text): - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + + if opts.use_old_emphasis_implementation: + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) + else: + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + + self.hijack.fixes = hijack_fixes self.hijack.comments = hijack_comments if len(used_custom_terms) > 0: - self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) tokens = torch.asarray(remade_batch_tokens).to(device) outputs = self.wrapped.transformer(input_ids=tokens) -- cgit v1.2.1 From c715ef04d1edb1a112a602639ed3bb292fdeb0e2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 29 Sep 2022 15:40:28 +0300 Subject: fix for incorrect model weight loading for #814 --- modules/sd_hijack.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2848a251..5945b7c2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -245,6 +245,7 @@ class StableDiffusionModelHijack: model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + self.clip = m.cond_stage_model if cmd_opts.opt_split_attention_v1: @@ -263,6 +264,14 @@ class StableDiffusionModelHijack: self.layers = flatten(m) + def undo_hijack(self, m): + if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords: + m.cond_stage_model = m.cond_stage_model.wrapped + + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: + model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + def apply_circular(self, enable): if self.circular_enabled == enable: return -- cgit v1.2.1