author    AUTOMATIC <16777216c@gmail.com> 2022-12-31 13:02:28 +0300
committer AUTOMATIC <16777216c@gmail.com> 2022-12-31 13:02:28 +0300
commit    3f401cdb644066fd43abf6642d2e53be53c73668 (patch)
tree      c5ac536017fb4bc1708e7c54385cd320bdb918a9 /modules/sd_hijack_clip.py
parent    fef98723b2b1c7a9893ead41bbefcb36192babd6 (diff)
parent    9a5c689c4960259f32cf627384ef5691ded5c017 (diff)

Merge remote-tracking branch 'baai-open-internal/master' into alt-diffusion
Diffstat (limited to 'modules/sd_hijack_clip.py')
 modules/sd_hijack_clip.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index b451d1cf..9ea6e1ce 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -4,7 +4,7 @@ import torch
 from modules import prompt_parser, devices
 from modules.shared import opts
-
+import modules.shared as shared
 
 def get_target_prompt_token_count(token_count):
     return math.ceil(max(token_count, 1) / 75) * 75
@@ -177,6 +177,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
 
     def forward(self, text):
+        if shared.text_model_name == "XLMR-Large":
+            return self.wrapped.encode(text)
+
         use_old = opts.use_old_emphasis_implementation
         if use_old:
             batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
@@ -254,7 +257,10 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
     def __init__(self, wrapped, hijack):
         super().__init__(wrapped, hijack)
         self.tokenizer = wrapped.tokenizer
-        self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+        if shared.text_model_name == "XLMR-Large":
+            self.comma_token = None
+        else :
+            self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
 
         self.token_mults = {}
         tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
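
For readers skimming the patch, here is a minimal runnable sketch of the two behaviours it touches: the 75-token chunk rounding done by get_target_prompt_token_count (copied verbatim from the first hunk), and the new early-return guard that forward() takes when the text model is XLMR-Large. The DummyWrapped and DummyEmbedder classes are hypothetical stand-ins, not the real webui classes; only the rounding function and the guard condition come from the patch itself.

import math

def get_target_prompt_token_count(token_count):
    # Round a prompt's token count up to the nearest multiple of 75,
    # the chunk size prompts are padded to; even an empty prompt
    # still occupies one chunk.
    return math.ceil(max(token_count, 1) / 75) * 75

assert get_target_prompt_token_count(1) == 75
assert get_target_prompt_token_count(75) == 75
assert get_target_prompt_token_count(76) == 150

class DummyWrapped:
    # Hypothetical stand-in for the wrapped XLM-R text encoder.
    def encode(self, text):
        return f"encoded({text!r})"

class DummyEmbedder:
    # Hypothetical stand-in for FrozenCLIPEmbedderWithCustomWordsBase,
    # reduced to just the control flow this patch adds.
    def __init__(self, wrapped, text_model_name):
        self.wrapped = wrapped
        self.text_model_name = text_model_name

    def forward(self, text):
        # Mirrors the new guard: XLMR-Large prompts bypass the CLIP
        # emphasis/tokenization pipeline and go straight to the
        # wrapped encoder.
        if self.text_model_name == "XLMR-Large":
            return self.wrapped.encode(text)
        raise NotImplementedError("normal CLIP path elided in this sketch")

print(DummyEmbedder(DummyWrapped(), "XLMR-Large").forward("a photo of a cat"))

Setting comma_token to None for XLMR-Large in __init__ is consistent with the same bypass: the ',</w>' entry looked up there is specific to CLIP's BPE vocabulary, so the lookup has nothing to match against an XLM-R tokenizer.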