aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack_clip.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack_clip.py')
-rw-r--r--modules/sd_hijack_clip.py21
1 files changed, 13 insertions, 8 deletions
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 8f29057a..98350ac4 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -3,7 +3,7 @@ from collections import namedtuple
import torch
-from modules import prompt_parser, devices, sd_hijack
+from modules import prompt_parser, devices, sd_hijack, sd_emphasis
from modules.shared import opts
@@ -88,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
Returns the list and the total number of tokens in the prompt.
"""
- if opts.enable_emphasis:
+ if opts.emphasis != "None":
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
@@ -249,6 +249,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
+ if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
+ self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
+
if getattr(self.wrapped, 'return_pooled', False):
return torch.hstack(zs), zs[0].pooled
else:
@@ -274,12 +277,14 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
pooled = getattr(z, 'pooled', None)
- # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
- original_mean = z.mean()
- z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
- new_mean = z.mean()
- z = z * (original_mean / new_mean)
+ emphasis = sd_emphasis.get_current_option(opts.emphasis)()
+ emphasis.tokens = remade_batch_tokens
+ emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device)
+ emphasis.z = z
+
+ emphasis.after_transformers()
+
+ z = emphasis.z
if pooled is not None:
z.pooled = pooled