From 5034f7d7597685aaa4779296983be0f49f4f991f Mon Sep 17 00:00:00 2001 From: Liam Date: Tue, 27 Sep 2022 15:56:18 -0400 Subject: added token counter next to txt2img and img2img prompts --- modules/sd_hijack.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 7b2030d4..4d799ac0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -180,6 +180,7 @@ class StableDiffusionModelHijack: dir_mtime = None layers = None circular_enabled = False + clip = None def load_textual_inversion_embeddings(self, dirname, model): mt = os.path.getmtime(dirname) @@ -242,6 +243,7 @@ class StableDiffusionModelHijack: model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + self.clip = m.cond_stage_model if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 @@ -268,6 +270,11 @@ class StableDiffusionModelHijack: for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]: layer.padding_mode = 'circular' if enable else 'zeros' + def tokenize(self, text): + max_length = self.clip.max_length - 2 + _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) + return {"tokens": remade_batch_tokens[0], "token_count":token_count, "max_length":max_length} + class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): @@ -294,14 +301,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def forward(self, text): - self.hijack.fixes = [] - self.hijack.comments = [] - remade_batch_tokens = [] + def process_text(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id maxlen = self.wrapped.max_length used_custom_terms = [] + remade_batch_tokens = [] + overflowing_words = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 cache = {} batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] @@ -353,9 +362,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ovf = remade_tokens[maxlen - 2:] overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - - self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] cache[tuple_tokens] = (remade_tokens, fixes, multipliers) @@ -364,8 +372,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] remade_batch_tokens.append(remade_tokens) - self.hijack.fixes.append(fixes) + hijack_fixes.append(fixes) batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + def forward(self, text): + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + self.hijack.fixes = hijack_fixes + self.hijack.comments = hijack_comments if len(used_custom_terms) > 0: self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) -- cgit v1.2.1