aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/hypernetwork.py55
-rw-r--r--modules/sd_hijack_optimizations.py17
-rw-r--r--modules/shared.py9
3 files changed, 78 insertions, 3 deletions
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
new file mode 100644
index 00000000..9ed1eed9
--- /dev/null
+++ b/modules/hypernetwork.py
@@ -0,0 +1,55 @@
+import glob
+import os
+import torch
+from modules import devices
+
+
+class HypernetworkModule(torch.nn.Module):
+ def __init__(self, dim, state_dict):
+ super().__init__()
+
+ self.linear1 = torch.nn.Linear(dim, dim * 2)
+ self.linear2 = torch.nn.Linear(dim * 2, dim)
+
+ self.load_state_dict(state_dict, strict=True)
+ self.to(devices.device)
+
+ def forward(self, x):
+ return x + (self.linear2(self.linear1(x)))
+
+
+class Hypernetwork:
+ filename = None
+ name = None
+
+ def __init__(self, filename):
+ self.filename = filename
+ self.name = os.path.splitext(os.path.basename(filename))[0]
+ self.layers = {}
+
+ state_dict = torch.load(filename, map_location='cpu')
+ for size, sd in state_dict.items():
+ self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+
+
+def load_hypernetworks(path):
+ res = {}
+
+ for filename in glob.iglob(path + '**/*.pt', recursive=True):
+ hn = Hypernetwork(filename)
+ res[hn.name] = hn
+
+ return res
+
+def apply(self, x, context=None, mask=None, original=None):
+
+
+ if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
+ if context.shape[1] == 77 and CrossAttention.noise_cond:
+ context = context + (torch.randn_like(context) * 0.1)
+ h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
+ k = self.to_k(h_k(context))
+ v = self.to_v(h_v(context))
+ else:
+ k = self.to_k(context)
+ v = self.to_v(context)
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index ea4cfdfc..d9cca485 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -5,6 +5,8 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
+from modules import shared
+
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
@@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- k_in = self.to_k(context) * self.scale
- v_in = self.to_v(context)
+
+ hypernetwork = shared.selected_hypernetwork()
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is not None:
+ k_in = self.to_k(hypernetwork_layers[0](context))
+ v_in = self.to_v(hypernetwork_layers[1](context))
+ else:
+ k_in = self.to_k(context)
+ v_in = self.to_v(context)
+
+ k_in *= self.scale
+
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
diff --git a/modules/shared.py b/modules/shared.py
index 25bb6e6c..879d8424 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers
+from modules import sd_samplers, hypernetwork
from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -76,6 +76,12 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
config_filename = cmd_opts.ui_settings_file
+hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks'))
+
+
+def selected_hypernetwork():
+ return hypernetworks.get(opts.sd_hypernetwork, None)
+
class State:
interrupted = False
@@ -206,6 +212,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
+ "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),