From 0dca0db7ebeeb2e250bf0c443f1f5521846050a4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 3 Sep 2022 01:01:58 +0300
Subject: Update to support embedding with length greater than 1.

---
 webui.bat |  2 +-
 webui.py  | 15 +++++++++------
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/webui.bat b/webui.bat
index 681aa983..5dd1f09d 100644
--- a/webui.bat
+++ b/webui.bat
@@ -7,7 +7,7 @@ set VENV_DIR=venv
 
 mkdir tmp 2>NUL
 
-set TORCH_COMMAND=pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
+set TORCH_COMMAND=pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
 set REQS_FILE=requirements_versions.txt
 
 %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
diff --git a/webui.py b/webui.py
index f7a52107..503cc1e4 100644
--- a/webui.py
+++ b/webui.py
@@ -746,9 +746,9 @@ class StableDiffusionModelHijack:
             if hasattr(param_dict, '_parameters'):
                 param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
             assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1].reshape(768)
-            self.word_embeddings[name] = emb
-            self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
+            emb = next(iter(param_dict.items()))[1]
+            self.word_embeddings[name] = emb.detach()
+            self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'
 
             ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
 
@@ -838,9 +838,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                         found = False
                         for ids, word in possible_matches:
                             if tokens[i:i+len(ids)] == ids:
+                                emb_len = int(self.hijack.word_embeddings[word].shape[0])
                                 fixes.append((len(remade_tokens), word))
-                                remade_tokens.append(777)
-                                multipliers.append(mult)
+                                remade_tokens += [0] * emb_len
+                                multipliers += [mult] * emb_len
                                 i += len(ids) - 1
                                 found = True
                                 used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
@@ -903,7 +904,9 @@ class EmbeddingsWithFixes(nn.Module):
         if batch_fixes is not None:
             for fixes, tensor in zip(batch_fixes, inputs_embeds):
                 for offset, word in fixes:
-                    tensor[offset] = self.embeddings.word_embeddings[word]
+                    emb = self.embeddings.word_embeddings[word]
+                    emb_len = min(tensor.shape[0]-offset, emb.shape[0])
+                    tensor[offset:offset+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
 
         return inputs_embeds
 
--
cgit v1.2.1