aboutsummaryrefslogtreecommitdiff
path: root/modules/prompt_parser.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-29 11:31:48 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-09-29 11:31:48 +0300
commitc1c27dad3ba371a5ae344b267c760aa51e77f193 (patch)
tree015617002c0a7660af6d97d96e560c1ca9b82030 /modules/prompt_parser.py
parent29ce8a687df4143f4fd7f5e22de7fb4c6baa1392 (diff)
new implementation for attention/emphasis
Diffstat (limited to 'modules/prompt_parser.py')
-rw-r--r--modules/prompt_parser.py87
1 files changed, 86 insertions, 1 deletions
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index a6a25b28..f3c50adf 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -126,5 +126,90 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
return res
+re_attention = re.compile(r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""", re.X)
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \( - literal character '('
+ \[ - literal character '['
+ \) - literal character ')'
+ \] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+
+ Example:
+
+ 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
+
+ produces:
+
+ [
+ ['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]
+ ]
+ """
-#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith('\\'):
+ res.append([text[1:], 1.0])
+ elif text == '(':
+ round_brackets.append(len(res))
+ elif text == '[':
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ')' and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == ']' and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ return res