aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-08-27 11:17:55 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-08-27 11:17:55 +0300
commit9597b265ec07e8ec6dab7487152459046585c1f9 (patch)
tree5eafd5fc597851b68b0beeeeba800916f7f91858
parenta51bedfb5ae2e1adeb7406b183305b2ea7530eac (diff)
implementation for attention using [] and ()
-rw-r--r--README.md6
-rw-r--r--images/attention-3.jpgbin0 -> 966484 bytes
-rw-r--r--webui.py79
3 files changed, 62 insertions, 23 deletions
diff --git a/README.md b/README.md
index 0c49d6f2..63f8d000 100644
--- a/README.md
+++ b/README.md
@@ -188,3 +188,9 @@ and put it into `embeddings` dir and use Usada Pekora in prompt.
A tab with settings, allowing you to use UI to edit more than half of parameters that previously
were commandline. Settings are saved to config.js file. Settings that remain as commandline
options are ones that are required at startup.
+
+### Attention
+Using `()` in prompt decreases model's attention to enclosed words, and `[]` increases it. You can combine
+multiple modifiers:
+
+![](images/attention-3.jpg)
diff --git a/images/attention-3.jpg b/images/attention-3.jpg
new file mode 100644
index 00000000..7c7ef0d3
--- /dev/null
+++ b/images/attention-3.jpg
Binary files differ
diff --git a/webui.py b/webui.py
index b3375e98..a0fa23c4 100644
--- a/webui.py
+++ b/webui.py
@@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir):
print(traceback.format_exc(), file=sys.stderr)
-class TextInversionEmbeddings:
+class StableDiffuionModelHijack:
ids_lookup = {}
word_embeddings = {}
word_embeddings_checksums = {}
- fixes = []
+ fixes = None
used_custom_terms = []
dir_mtime = None
- def load(self, dir, model):
+ def load_textual_inversion_embeddings(self, dir, model):
mt = os.path.getmtime(dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
@@ -469,6 +469,7 @@ class TextInversionEmbeddings:
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
+
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
@@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.embeddings = embeddings
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
+ self.token_mults = {}
+
+ tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ for text, ident in tokens_with_parens:
+ mult = 1.0
+ for c in text:
+ if c == '[':
+ mult /= 1.1
+ if c == ']':
+ mult *= 1.1
+ if c == '(':
+ mult *= 1.1
+ if c == ')':
+ mult /= 1.1
+
+ if mult != 1.0:
+ self.token_mults[ident] = mult
def forward(self, text):
self.embeddings.fixes = []
@@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
+ batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
- remade_tokens, fixes = cache[tuple_tokens]
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
+ multipliers = []
+ mult = 1.0
i = 0
while i < len(tokens):
@@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
possible_matches = self.embeddings.ids_lookup.get(token, None)
- if possible_matches is None:
+ mult_change = self.token_mults.get(token)
+ if mult_change is not None:
+ mult *= mult_change
+ elif possible_matches is None:
remade_tokens.append(token)
+ multipliers.append(mult)
else:
found = False
for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids:
fixes.append((len(remade_tokens), word))
remade_tokens.append(777)
+ multipliers.append(mult)
i += len(ids) - 1
found = True
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
@@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if not found:
remade_tokens.append(token)
+ multipliers.append(mult)
i += 1
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
- cache[tuple_tokens] = (remade_tokens, fixes)
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
self.embeddings.fixes.append(fixes)
+ batch_multipliers.append(multipliers)
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
outputs = self.wrapped.transformer(input_ids=tokens)
z = outputs.last_hidden_state
+
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+ batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
+ original_mean = z.mean()
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ new_mean = z.mean()
+ z *= original_mean / new_mean
+
return z
@@ -562,22 +601,17 @@ class EmbeddingsWithFixes(nn.Module):
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
- self.embeddings.fixes = []
+ self.embeddings.fixes = None
inputs_embeds = self.wrapped(input_ids)
- for fixes, tensor in zip(batch_fixes, inputs_embeds):
- for offset, word in fixes:
- tensor[offset] = self.embeddings.word_embeddings[word]
-
- return inputs_embeds
+ if batch_fixes is not None:
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
+ for offset, word in fixes:
+ tensor[offset] = self.embeddings.word_embeddings[word]
-def get_learned_conditioning_with_embeddings(model, prompts):
- if os.path.exists(cmd_opts.embeddings_dir):
- text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
-
- return model.get_learned_conditioning(prompts)
+ return inputs_embeds
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
@@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
if os.path.exists(cmd_opts.embeddings_dir):
- text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
+ model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
output_images = []
with torch.no_grad(), autocast("cuda"), model.ema_scope():
@@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
uc = model.get_learned_conditioning(len(prompts) * [""])
c = model.get_learned_conditioning(prompts)
- if len(text_inversion_embeddings.used_custom_terms) > 0:
- comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
+ if len(model_hijack.used_custom_terms) > 0:
+ comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
@@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if cmd_opts.no_half else model.half()).to(device)
-text_inversion_embeddings = TextInversionEmbeddings()
-if os.path.exists(cmd_opts.embeddings_dir):
- text_inversion_embeddings.hijack(model)
+model_hijack = StableDiffuionModelHijack()
+model_hijack.hijack(model)
demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces],