aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py324
1 files changed, 51 insertions, 273 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index fa7eaeb8..fd57e5c5 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -6,244 +6,41 @@ import torch
import numpy as np
from torch import einsum
-from modules import prompt_parser
+import modules.textual_inversion.textual_inversion
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
-from ldm.util import default
-from einops import rearrange
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
+diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
-# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
- h = self.heads
- q = self.to_q(x)
- context = default(context, x)
- k = self.to_k(context)
- v = self.to_v(context)
- del context, x
+def apply_optimizations():
+ if cmd_opts.opt_split_attention_v1:
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
- for i in range(0, q.shape[0], 2):
- end = i + 2
- s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
- s1 *= self.scale
+def undo_optimizations():
+ ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+ ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
- s2 = s1.softmax(dim=-1)
- del s1
-
- r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
- del s2
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-
-# taken from https://github.com/Doggettx/stable-diffusion
-def split_cross_attention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q_in = self.to_q(x)
- context = default(context, x)
- k_in = self.to_k(context) * self.scale
- v_in = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
-
- gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
- modifier = 3 if q.element_size() == 2 else 2.5
- mem_required = tensor_size * modifier
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
- # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
- # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-
- if steps > 64:
- max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
- raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
- f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
- s2 = s1.softmax(dim=-1, dtype=q.dtype)
- del s1
-
- r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
- del s2
-
- del q, k, v
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-def nonlinearity_hijack(x):
- # swish
- t = torch.sigmoid(x)
- x *= t
- del t
-
- return x
-
-def cross_attention_attnblock_forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q1 = self.q(h_)
- k1 = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q1.shape
-
- q2 = q1.reshape(b, c, h*w)
- del q1
-
- q = q2.permute(0, 2, 1) # b,hw,c
- del q2
-
- k = k1.reshape(b, c, h*w) # b,c,hw
- del k1
-
- h_ = torch.zeros_like(k, device=q.device)
-
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
-
- tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
- mem_required = tensor_size * 2.5
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
-
- w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w2 = w1 * (int(c)**(-0.5))
- del w1
- w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
- del w2
-
- # attend to values
- v1 = v.reshape(b, c, h*w)
- w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
- del w3
-
- h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- del v1, w4
-
- h2 = h_.reshape(b, c, h, w)
- del h_
-
- h3 = self.proj_out(h2)
- del h2
-
- h3 += x
-
- return h3
class StableDiffusionModelHijack:
- ids_lookup = {}
- word_embeddings = {}
- word_embeddings_checksums = {}
fixes = None
comments = []
- dir_mtime = None
layers = None
circular_enabled = False
clip = None
- def load_textual_inversion_embeddings(self, dirname, model):
- mt = os.path.getmtime(dirname)
- if self.dir_mtime is not None and mt <= self.dir_mtime:
- return
-
- self.dir_mtime = mt
- self.ids_lookup.clear()
- self.word_embeddings.clear()
-
- tokenizer = model.cond_stage_model.tokenizer
-
- def const_hash(a):
- r = 0
- for v in a:
- r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
- return r
-
- def process_file(path, filename):
- name = os.path.splitext(filename)[0]
-
- data = torch.load(path, map_location="cpu")
-
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
-
- self.word_embeddings[name] = emb.detach().to(device)
- self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
-
- ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
-
- first_id = ids[0]
- if first_id not in self.ids_lookup:
- self.ids_lookup[first_id] = []
- self.ids_lookup[first_id].append((ids, name))
-
- for fn in os.listdir(dirname):
- try:
- fullfn = os.path.join(dirname, fn)
-
- if os.stat(fullfn).st_size == 0:
- continue
-
- process_file(fullfn, fn)
- except Exception:
- print(f"Error loading emedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- continue
-
- print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@@ -253,12 +50,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model
- if cmd_opts.opt_split_attention_v1:
- ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
- ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
- ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+ apply_optimizations()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
@@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
- self.hijack = hijack
+ self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
@@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
-
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- possible_matches = self.hijack.ids_lookup.get(token, None)
+ embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
- if possible_matches is None:
+ if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
+ i += 1
else:
- found = False
- for ids, word in possible_matches:
- if tokens[i:i + len(ids)] == ids:
- emb_len = int(self.hijack.word_embeddings[word].shape[0])
- fixes.append((len(remade_tokens), word))
- remade_tokens += [0] * emb_len
- multipliers += [weight] * emb_len
- i += len(ids) - 1
- found = True
- used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
- break
-
- if not found:
- remade_tokens.append(token)
- multipliers.append(weight)
- i += 1
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [weight] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += emb_len
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- possible_matches = self.hijack.ids_lookup.get(token, None)
+ embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
- elif possible_matches is None:
+ i += 1
+ elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
+ i += 1
else:
- found = False
- for ids, word in possible_matches:
- if tokens[i:i+len(ids)] == ids:
- emb_len = int(self.hijack.word_embeddings[word].shape[0])
- fixes.append((len(remade_tokens), word))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- i += len(ids) - 1
- found = True
- used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
- break
-
- if not found:
- remade_tokens.append(token)
- multipliers.append(mult)
-
- i += 1
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += emb_len
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
@@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
@@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
inputs_embeds = self.wrapped(input_ids)
- if batch_fixes is not None:
- for fixes, tensor in zip(batch_fixes, inputs_embeds):
- for offset, word in fixes:
- emb = self.embeddings.word_embeddings[word]
- emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
- tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
+ if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
+ return inputs_embeds
+
+ vecs = []
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
+ for offset, embedding in fixes:
+ emb = embedding.vec
+ emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+
+ vecs.append(tensor)
- return inputs_embeds
+ return torch.stack(vecs)
def add_circular_option_to_conv_2d():