Diffstat (limited to 'modules')
-rw-r--r--  modules/prompt_parser.py |  87
-rw-r--r--  modules/sd_hijack.py     | 102
-rw-r--r--  modules/shared.py        |   3
3 files changed, 186 insertions(+), 6 deletions(-)
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index a6a25b28..f3c50adf 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -126,5 +126,90 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
return res
-#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
+re_attention = re.compile(r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""", re.X)
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \( - literal character '('
+ \[ - literal character '['
+ \) - literal character ')'
+ \] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+
+ Example:
+
+ 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
+
+ produces:
+
+ [
+ ['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]
+ ]
+ """
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith('\\'):
+ res.append([text[1:], 1.0])
+ elif text == '(':
+ round_brackets.append(len(res))
+ elif text == '[':
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ')' and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == ']' and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ return res
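For illustration, a brief usage sketch of the parser added above, assuming the webui's dependencies are installed and the repository root is on the import path; the weights shown are the ones produced by the multipliers in this diff:

    from modules import prompt_parser

    prompt_parser.parse_prompt_attention("a (cat:1.5) on a mat")
    # [['a ', 1.0], ['cat', 1.5], [' on a mat', 1.0]]

    prompt_parser.parse_prompt_attention(r"a \(test\) photo")
    # escaped brackets come through as literal text with the default weight:
    # [['a ', 1.0], ['(', 1.0], ['test', 1.0], [')', 1.0], [' photo', 1.0]]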
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index bfbd07f9..2848a251 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -6,6 +6,7 @@ import torch
import numpy as np
from torch import einsum
+from modules import prompt_parser
from modules.shared import opts, device, cmd_opts
from ldm.util import default
@@ -211,6 +212,7 @@ class StableDiffusionModelHijack:
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
@@ -236,7 +238,7 @@ class StableDiffusionModelHijack:
print(traceback.format_exc(), file=sys.stderr)
continue
- print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
+ print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@@ -275,6 +277,7 @@ class StableDiffusionModelHijack:
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, max_length
+
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
@@ -300,7 +303,92 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
- def process_text(self, text):
+
+ def tokenize_line(self, line, used_custom_terms, hijack_comments):
+ id_start = self.wrapped.tokenizer.bos_token_id
+ id_end = self.wrapped.tokenizer.eos_token_id
+ maxlen = self.wrapped.max_length
+
+ if opts.enable_emphasis:
+ parsed = prompt_parser.parse_prompt_attention(line)
+ else:
+ parsed = [[line, 1.0]]
+
+ tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
+
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ possible_matches = self.hijack.ids_lookup.get(token, None)
+
+ if possible_matches is None:
+ remade_tokens.append(token)
+ multipliers.append(weight)
+ else:
+ found = False
+ for ids, word in possible_matches:
+ if tokens[i:i + len(ids)] == ids:
+ emb_len = int(self.hijack.word_embeddings[word].shape[0])
+ fixes.append((len(remade_tokens), word))
+ remade_tokens += [0] * emb_len
+ multipliers += [weight] * emb_len
+ i += len(ids) - 1
+ found = True
+ used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
+ break
+
+ if not found:
+ remade_tokens.append(token)
+ multipliers.append(weight)
+ i += 1
+
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+ token_count = len(remade_tokens)
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+ return remade_tokens, fixes, multipliers, token_count
+
+ def process_text(self, texts):
+ used_custom_terms = []
+ remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_multipliers = []
+ for line in texts:
+ if line in cache:
+ remade_tokens, fixes, multipliers = cache[line]
+ else:
+ remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+
+ cache[line] = (remade_tokens, fixes, multipliers)
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+ def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length
@@ -376,12 +464,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+
+ if opts.use_old_emphasis_implementation:
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
+ else:
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+
+
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens)
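The part of forward() that consumes the per-token multipliers is not shown in this diff, so the following is only a rough sketch of the idea, with a hypothetical helper name: each token's CLIP hidden state is scaled by its multiplier, then the result is renormalized so the overall mean stays roughly unchanged.

    import torch

    def apply_multipliers(z, batch_multipliers):
        # z: [batch, tokens, dim] hidden states; batch_multipliers: [batch, tokens]
        original_mean = z.mean()
        z = z * batch_multipliers.unsqueeze(-1)   # emphasize / de-emphasize tokens
        z = z * (original_mean / z.mean())        # keep the overall magnitude stable
        return z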
diff --git a/modules/shared.py b/modules/shared.py
index ec1e569b..f88c2b02 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -195,7 +195,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
- "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),