aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py30
1 files changed, 15 insertions, 15 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 984b35c4..1f8587d1 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
def apply_optimizations():
undo_optimizations()
@@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
-
+
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
-
+
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
@@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
+
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
@@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
+
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
+
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
@@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
-
+
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
@@ -332,20 +332,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
+ z1 = shared.aesthetic_clip(z1, remade_batch_tokens)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
+
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
-
+
return z
-
-
+
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
-
+
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
@@ -385,8 +385,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
- emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
- tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor)