From 79e39fae6110c20a3ee6255e2841c877f65e8cbd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 01:45:28 +0300 Subject: CLIP hijack rework --- modules/sd_hijack_clip_old.py | 81 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 modules/sd_hijack_clip_old.py (limited to 'modules/sd_hijack_clip_old.py') diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py new file mode 100644 index 00000000..6d9fbbe6 --- /dev/null +++ b/modules/sd_hijack_clip_old.py @@ -0,0 +1,81 @@ +from modules import sd_hijack_clip +from modules import shared + + +def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): + id_start = self.id_start + id_end = self.id_end + maxlen = self.wrapped.max_length # you get to stay at 77 + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_tokens = self.tokenize(texts) + batch_multipliers = [] + for tokens in batch_tokens: + tuple_tokens = tuple(tokens) + + if tuple_tokens in cache: + remade_tokens, fixes, multipliers = cache[tuple_tokens] + else: + fixes = [] + remade_tokens = [] + multipliers = [] + mult = 1.0 + + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None + if mult_change is not None: + mult *= mult_change + i += 1 + elif embedding is None: + remade_tokens.append(token) + multipliers.append(mult) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + + token_count = len(remade_tokens) + remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + cache[tuple_tokens] = (remade_tokens, fixes, multipliers) + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + +def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts) + + self.hijack.comments += hijack_comments + + if len(used_custom_terms) > 0: + self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) -- cgit v1.2.1