aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py57
1 files changed, 33 insertions, 24 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 5b30539f..5d93f7f6 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -8,7 +8,7 @@ from torch import einsum
from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
from modules.shared import opts, device, cmd_opts
import ldm.modules.attention
@@ -18,8 +18,9 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
-
def apply_optimizations():
+ undo_optimizations()
+
ldm.modules.diffusionmodules.model.nonlinearity = silu
if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available:
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
@@ -32,11 +33,18 @@ def apply_optimizations():
def undo_optimizations():
- ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+def get_target_prompt_token_count(token_count):
+ if token_count < 75:
+ return 75
+
+ return math.ceil(token_count / 10) * 10
+
+
class StableDiffusionModelHijack:
fixes = None
comments = []
@@ -82,10 +90,12 @@ class StableDiffusionModelHijack:
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
layer.padding_mode = 'circular' if enable else 'zeros'
+ def clear_comments(self):
+ self.comments = []
+
def tokenize(self, text):
- max_length = self.clip.max_length - 2
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
- return remade_batch_tokens[0], token_count, max_length
+ return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
@@ -94,7 +104,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.wrapped = wrapped
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
- self.max_length = wrapped.max_length
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
@@ -116,7 +125,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
- maxlen = self.wrapped.max_length
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
@@ -148,19 +156,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+ prompt_target_length = get_target_prompt_token_count(token_count)
+ tokens_to_add = prompt_target_length - len(remade_tokens) + 1
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+ remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add
+ multipliers = [1.0] + multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
@@ -177,7 +178,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
- remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
@@ -191,7 +193,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
- maxlen = self.wrapped.max_length
+ maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
overflowing_words = []
@@ -263,17 +265,24 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.fixes = hijack_fixes
- self.hijack.comments = hijack_comments
+ self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
- tokens = torch.asarray(remade_batch_tokens).to(device)
- outputs = self.wrapped.transformer(input_ids=tokens)
+ target_token_count = get_target_prompt_token_count(token_count) + 2
+
+ position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76]
+ position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1))
+
+ remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
+ tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
+ outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers = torch.asarray(batch_multipliers).to(device)
+ batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]
+ batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()