aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack_clip.py
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-01-04 00:40:16 -0500
committerbrkirch <brkirch@users.noreply.github.com>2023-01-06 00:14:20 -0500
commitf6ab5a39d762a7791573d1c52ae5a3024b10e8ed (patch)
treec3958d77a6dae42457b571dbe0f1efec7ce45dd2 /modules/sd_hijack_clip.py
parentd782a95967c9eea753df3333cd1954b6ec73eba0 (diff)
parent3e22e294135ed0327ce9d9738655ff03c53df3c0 (diff)
Merge branch 'AUTOMATIC1111:master' into sub-quad_attn_opt
Diffstat (limited to 'modules/sd_hijack_clip.py')
-rw-r--r--modules/sd_hijack_clip.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index b451d1cf..ca92b142 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -5,7 +5,6 @@ import torch
from modules import prompt_parser, devices
from modules.shared import opts
-
def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
@@ -254,10 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.tokenizer = wrapped.tokenizer
- self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+
+ vocab = self.tokenizer.get_vocab()
+
+ self.comma_token = vocab.get(',</w>', None)
self.token_mults = {}
- tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
@@ -296,6 +298,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.transformer.text_model.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded