aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py80
1 files changed, 78 insertions, 2 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index c81722a0..6d5196fe 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -9,11 +9,14 @@ from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
-from modules.shared import opts, device, cmd_opts
+from modules.shared import opts, device, cmd_opts, aesthetic_embeddings
from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+from transformers import CLIPVisionModel, CLIPModel
+import torch.optim as optim
+import copy
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
@@ -109,13 +112,29 @@ class StableDiffusionModelHijack:
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
+def slerp(low, high, val):
+ low_norm = low/torch.norm(low, dim=1, keepdim=True)
+ high_norm = high/torch.norm(high, dim=1, keepdim=True)
+ omega = torch.acos((low_norm*high_norm).sum(1))
+ so = torch.sin(omega)
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+ return res
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
+ self.clipModel = CLIPModel.from_pretrained(
+ self.wrapped.transformer.name_or_path
+ )
+ del self.clipModel.vision_model
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
+ # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
+ self.image_embs_name = None
+ self.image_embs = None
+ self.load_image_embs(None)
+
self.token_mults = {}
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
@@ -136,6 +155,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
+ def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None,
+ aesthetic_slerp=True):
+ self.slerp = aesthetic_slerp
+ self.aesthetic_lr = aesthetic_lr
+ self.aesthetic_weight = aesthetic_weight
+ self.aesthetic_steps = aesthetic_steps
+ self.load_image_embs(image_embs_name)
+
+ def load_image_embs(self, image_embs_name):
+ if image_embs_name is None or len(image_embs_name) == 0:
+ image_embs_name = None
+ if image_embs_name is not None and self.image_embs_name != image_embs_name:
+ self.image_embs_name = image_embs_name
+ self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device)
+ self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
+ self.image_embs.requires_grad_(False)
+
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_end = self.wrapped.tokenizer.eos_token_id
@@ -333,7 +369,47 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
+
+ if len(text[
+ 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None:
+ if not opts.use_old_emphasis_implementation:
+ remade_batch_tokens = [
+ [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
+ remade_batch_tokens]
+
+ tokens = torch.asarray(remade_batch_tokens).to(device)
+ with torch.enable_grad():
+ model = copy.deepcopy(self.clipModel).to(device)
+ model.requires_grad_(True)
+
+ # We optimize the model to maximize the similarity
+ optimizer = optim.Adam(
+ model.text_model.parameters(), lr=self.aesthetic_lr
+ )
+
+ for i in range(self.aesthetic_steps):
+ text_embs = model.get_text_features(input_ids=tokens)
+ text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
+ sim = text_embs @ self.image_embs.T
+ loss = -sim
+ optimizer.zero_grad()
+ loss.mean().backward()
+ optimizer.step()
+
+ zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+ if opts.CLIP_stop_at_last_layers > 1:
+ zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
+ zn = model.text_model.final_layer_norm(zn)
+ else:
+ zn = zn.last_hidden_state
+ model.cpu()
+ del model
+
+ if self.slerp:
+ z = slerp(z, zn, self.aesthetic_weight)
+ else:
+ z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
+
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1