author    AUTOMATIC <16777216c@gmail.com>    2022-09-03 01:01:58 +0300
committer AUTOMATIC <16777216c@gmail.com>    2022-09-03 01:01:58 +0300
commit    0dca0db7ebeeb2e250bf0c443f1f5521846050a4
tree      c305ede3ca6118261dc314e10115a0212edce674
parent    4cafad66d202433bc358d9c4b8291d593b6e4df8
Update to support embedding with length greater than 1.
-rw-r--r--  webui.bat |  2
-rw-r--r--  webui.py  | 15
2 files changed, 10 insertions(+), 7 deletions(-)
diff --git a/webui.bat b/webui.bat
index 681aa983..5dd1f09d 100644
--- a/webui.bat
+++ b/webui.bat
@@ -7,7 +7,7 @@ set VENV_DIR=venv
mkdir tmp 2>NUL
-set TORCH_COMMAND=pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
+set TORCH_COMMAND=pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
set REQS_FILE=requirements_versions.txt
%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
diff --git a/webui.py b/webui.py
index f7a52107..503cc1e4 100644
--- a/webui.py
+++ b/webui.py
@@ -746,9 +746,9 @@ class StableDiffusionModelHijack:
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1].reshape(768)
- self.word_embeddings[name] = emb
- self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
+ emb = next(iter(param_dict.items()))[1]
+ self.word_embeddings[name] = emb.detach()
+ self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
@@ -838,9 +838,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
found = False
for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids:
+ emb_len = int(self.hijack.word_embeddings[word].shape[0])
fixes.append((len(remade_tokens), word))
- remade_tokens.append(777)
- multipliers.append(mult)
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
@@ -903,7 +904,9 @@ class EmbeddingsWithFixes(nn.Module):
if batch_fixes is not None:
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, word in fixes:
- tensor[offset] = self.embeddings.word_embeddings[word]
+ emb = self.embeddings.word_embeddings[word]
+ emb_len = min(tensor.shape[0]-offset, emb.shape[0])
+ tensor[offset:offset+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
return inputs_embeds
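
What the webui.py hunks accomplish, as a minimal standalone sketch (this is not code from the repository; the function names, the placeholder token id 0, and the toy tensors are illustrative assumptions): a textual-inversion embedding may now be a tensor of shape (n, 768) rather than a single 768-dim vector, so the prompt reserves n placeholder token positions at the match site and the embedding's rows are later spliced into inputs_embeds at the recorded offset, clamped so they cannot overrun the sequence.

# Minimal sketch of the multi-vector embedding handling this commit adds.
# Not the repository's code: function names, the placeholder token id 0,
# and the toy tensors below are illustrative assumptions.
import torch

def reserve_slots(tokens, multipliers, fixes, word, emb, mult=1.0):
    """Reserve one placeholder token per embedding row and remember the offset."""
    emb_len = int(emb.shape[0])
    fixes.append((len(tokens), word))
    tokens += [0] * emb_len          # one placeholder id per embedding vector
    multipliers += [mult] * emb_len  # same attention multiplier for every slot
    return tokens, multipliers

def apply_fixes(inputs_embeds, fixes, word_embeddings):
    """Overwrite the reserved slots with the embedding's rows, clamped to fit."""
    for offset, word in fixes:
        emb = word_embeddings[word]
        emb_len = min(inputs_embeds.shape[0] - offset, emb.shape[0])
        inputs_embeds[offset:offset + emb_len] = emb[0:emb_len]
    return inputs_embeds

# Toy usage: a 3-vector embedding spliced into a 10-position sequence at offset 2.
word_embeddings = {"my-style": torch.randn(3, 768)}
tokens, multipliers, fixes = [101, 202], [1.0, 1.0], []
tokens, multipliers = reserve_slots(tokens, multipliers, fixes, "my-style",
                                    word_embeddings["my-style"])
inputs_embeds = apply_fixes(torch.zeros(10, 768), fixes, word_embeddings)
assert torch.equal(inputs_embeds[2:5], word_embeddings["my-style"])

The min() clamp mirrors the guard the patch adds in EmbeddingsWithFixes, so an embedding placed near the end of the prompt cannot write past the last token position.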