Diffstat (limited to 'modules/sd_hijack_clip.py')
-rw-r--r--  modules/sd_hijack_clip.py | 31
1 file changed, 27 insertions(+), 4 deletions(-)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 3b5a7666..b3771909 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
self.chunk_length = 75
+ self.is_trainable = getattr(wrapped, 'is_trainable', False)
+ self.input_key = getattr(wrapped, 'input_key', 'txt')
+ self.legacy_ucg_val = None
+
def empty_chunk(self):
"""creates an empty PromptChunk and returns it"""
@@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
"""
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
- be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
An example shape returned by this function can be: (2, 77, 768).
+ For SDXL, instead of returning one tensor as above, it returns a tuple of two: the second has shape (B, 1280) and contains pooled values.
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
"""
@@ -233,7 +238,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
- return torch.hstack(zs)
+ if getattr(self.wrapped, 'return_pooled', False):
+ return torch.hstack(zs), zs[0].pooled
+ else:
+ return torch.hstack(zs)
def process_tokens(self, remade_batch_tokens, batch_multipliers):
"""
@@ -256,9 +264,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z = z * (original_mean / new_mean)
+ z *= (original_mean / new_mean)
return z
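
The switch to in-place multiplication above does not change the emphasis logic: each token vector is scaled by its attention multiplier, then the whole tensor is rescaled so its mean matches the pre-scaling mean. A standalone sketch of that calculation with made-up shapes:

import torch

z = torch.randn(1, 77, 768)               # token embeddings (B, T, C)
batch_multipliers = torch.ones(1, 77)     # per-token emphasis weights
batch_multipliers[:, 5:10] = 1.1          # e.g. tokens inside (word:1.1)

original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean             # restore the original mean to limit artifacts
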
@@ -315,3 +323,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded
+
+
+class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ def encode_with_transformers(self, tokens):
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
+
+ if self.wrapped.layer == "last":
+ z = outputs.last_hidden_state
+ else:
+ z = outputs.hidden_states[self.wrapped.layer_idx]
+
+ return z
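
The new class only changes how features are pulled out of the wrapped OpenCLIP transformer: the final hidden state when layer is "last", or a chosen intermediate layer when layer is "hidden". A rough standalone illustration of the same selection using Hugging Face's CLIP text model (illustrative only; it is not the exact module SDXL wraps):

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer(["a photo of a cat"], return_tensors="pt").input_ids

layer = "hidden"    # mimics self.wrapped.layer
layer_idx = -2      # mimics self.wrapped.layer_idx (e.g. a "clip skip" of 2)

with torch.no_grad():
    outputs = model(input_ids=tokens, output_hidden_states=(layer == "hidden"))

if layer == "last":
    z = outputs.last_hidden_state
else:
    z = outputs.hidden_states[layer_idx]
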