 modules/hypernetwork/hypernetwork.py         | 266 ++++++++++++++++++++++++++
 modules/hypernetwork/ui.py                   |  42 +++++
 modules/sd_hijack.py                         |   4 +-
 modules/textual_inversion/ui.py              |   1 -
 modules/ui.py                                |  58 +++++-
 textual_inversion_templates/hypernetwork.txt |  27 +++
 textual_inversion_templates/none.txt         |   1 +
 webui.py                                     |   9 ++
 8 files changed, 399 insertions(+), 9 deletions(-)
diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py
new file mode 100644
index 00000000..a3d6a47e
--- /dev/null
+++ b/modules/hypernetwork/hypernetwork.py
@@ -0,0 +1,266 @@
+import datetime
+import glob
+import html
+import os
+import sys
+import traceback
+import tqdm
+
+import torch
+from torch import einsum
+from einops import rearrange, repeat
+
+from ldm.util import default
+from modules import devices, shared, processing, sd_models
+import modules.textual_inversion.dataset
+
+
+class HypernetworkModule(torch.nn.Module):
+    def __init__(self, dim, state_dict=None):
+        super().__init__()
+
+        self.linear1 = torch.nn.Linear(dim, dim * 2)
+        self.linear2 = torch.nn.Linear(dim * 2, dim)
+
+        if state_dict is not None:
+            self.load_state_dict(state_dict, strict=True)
+        else:
+            self.linear1.weight.data.fill_(0.0001)
+            self.linear1.bias.data.fill_(0.0001)
+            self.linear2.weight.data.fill_(0.0001)
+            self.linear2.bias.data.fill_(0.0001)
+
+        self.to(devices.device)
+
+    def forward(self, x):
+        return x + self.linear2(self.linear1(x))
+
+
+class Hypernetwork:
+    filename = None
+    name = None
+
+    def __init__(self, name=None):
+        self.filename = None
+        self.name = name
+        self.layers = {}
+        self.step = 0
+        self.sd_checkpoint = None
+        self.sd_checkpoint_name = None
+
+        for size in [320, 640, 768, 1280]:
+            self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
+
+    def weights(self):
+        res = []
+
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.train()
+                res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
+
+        return res
+
+    def save(self, filename):
+        state_dict = {}
+
+        for k, v in self.layers.items():
+            state_dict[k] = (v[0].state_dict(), v[1].state_dict())
+
+        state_dict['step'] = self.step
+        state_dict['name'] = self.name
+        state_dict['sd_checkpoint'] = self.sd_checkpoint
+        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+
+        torch.save(state_dict, filename)
+
+    def load(self, filename):
+        self.filename = filename
+        if self.name is None:
+            self.name = os.path.splitext(os.path.basename(filename))[0]
+
+        state_dict = torch.load(filename, map_location='cpu')
+
+        for size, sd in state_dict.items():
+            if isinstance(size, int):
+                self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+
+        self.name = state_dict.get('name', self.name)
+        self.step = state_dict.get('step', 0)
+        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
+        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+
+
+def load_hypernetworks(path):
+    res = {}
+
+    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+        try:
+            hn = Hypernetwork()
+            hn.load(filename)
+            res[hn.name] = hn
+        except Exception:
+            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+
+    return res
+
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+
+    hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        hypernetwork_k, hypernetwork_v = hypernetwork_layers
+
+        self.hypernetwork_k = hypernetwork_k
+        self.hypernetwork_v = hypernetwork_v
+
+        context_k = hypernetwork_k(context)
+        context_v = hypernetwork_v(context)
+    else:
+        context_k = context
+        context_v = context
+
+    k = self.to_k(context_k)
+    v = self.to_v(context_v)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    attn = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', attn, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)
+
+
+def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
+    assert hypernetwork_name, 'hypernetwork not selected'
+
+    shared.hypernetwork = shared.hypernetworks[hypernetwork_name]
+
+    shared.state.textinfo = "Initializing hypernetwork training..."
+    shared.state.job_count = steps
+
+    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+
+    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+
+    if save_hypernetwork_every > 0:
+        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
+        os.makedirs(hypernetwork_dir, exist_ok=True)
+    else:
+        hypernetwork_dir = None
+
+    if create_image_every > 0:
+        images_dir = os.path.join(log_directory, "images")
+        os.makedirs(images_dir, exist_ok=True)
+    else:
+        images_dir = None
+
+    cond_model = shared.sd_model.cond_stage_model
+
+    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+    with torch.autocast("cuda"):
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+
+    hypernetwork = shared.hypernetworks[hypernetwork_name]
+    weights = hypernetwork.weights()
+    for weight in weights:
+        weight.requires_grad = True
+
+    optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+
+    losses = torch.zeros((32,))
+
+    last_saved_file = "<none>"
+    last_saved_image = "<none>"
+
+    initial_step = hypernetwork.step or 0
+    if initial_step > steps:
+        return hypernetwork, filename
+
+    pbar = tqdm.tqdm(enumerate(ds), total=steps - initial_step)
+    for i, (x, text) in pbar:
+        hypernetwork.step = i + initial_step
+
+        if hypernetwork.step > steps:
+            break
+
+        if shared.state.interrupted:
+            break
+
+        with torch.autocast("cuda"):
+            c = cond_model([text])
+
+            x = x.to(devices.device)
+            loss = shared.sd_model(x.unsqueeze(0), c)[0]
+            del x
+
+            losses[hypernetwork.step % losses.shape[0]] = loss.item()
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        pbar.set_description(f"loss: {losses.mean():.7f}")
+
+        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
+            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
+            hypernetwork.save(last_saved_file)
+
+        if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
+            last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+
+            preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
+            p = processing.StableDiffusionProcessingTxt2Img(
+                sd_model=shared.sd_model,
+                prompt=preview_text,
+                steps=20,
+                do_not_save_grid=True,
+                do_not_save_samples=True,
+            )
+
+            processed = processing.process_images(p)
+            image = processed.images[0]
+
+            shared.state.current_image = image
+            image.save(last_saved_image)
+
+            last_saved_image += f", prompt: {preview_text}"
+
+        shared.state.job_no = hypernetwork.step
+
+        shared.state.textinfo = f"""
+<p>
+Loss: {losses.mean():.7f}<br/>
+Step: {hypernetwork.step}<br/>
+Last prompt: {html.escape(text)}<br/>
+Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
+Last saved image: {html.escape(last_saved_image)}<br/>
+</p>
+"""
+
+    checkpoint = sd_models.select_checkpoint()
+
+    hypernetwork.sd_checkpoint = checkpoint.hash
+    hypernetwork.sd_checkpoint_name = checkpoint.model_name
+    hypernetwork.save(filename)
+
+    return hypernetwork, filename
+
+
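Note: the core mechanism in hypernetwork.py above is that each cross-attention width (320/640/768/1280) gets a pair of small residual MLPs applied to the conditioning context before the key and value projections. A minimal standalone sketch of that transformation (the batch/token/channel shapes are illustrative assumptions, not taken from the diff):

    import torch

    class HypernetworkModule(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.linear1 = torch.nn.Linear(dim, dim * 2)
            self.linear2 = torch.nn.Linear(dim * 2, dim)

        def forward(self, x):
            # Residual form: with the near-zero initialization used above, this
            # starts out close to the identity, so an untrained hypernetwork
            # barely perturbs the base model.
            return x + self.linear2(self.linear1(x))

    context = torch.randn(1, 77, 768)   # e.g. CLIP text conditioning
    hn_k, hn_v = HypernetworkModule(768), HypernetworkModule(768)
    context_k = hn_k(context)           # fed to to_k in place of context
    context_v = hn_v(context)           # fed to to_v in place of context
    print(context_k.shape, context_v.shape)  # both torch.Size([1, 77, 768])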
diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py
new file mode 100644
index 00000000..525f978c
--- /dev/null
+++ b/modules/hypernetwork/ui.py
@@ -0,0 +1,42 @@
+import html
+import os
+
+import gradio as gr
+
+import modules.hypernetwork.hypernetwork
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
+from modules import sd_hijack, shared
+
+
+def create_hypernetwork(name):
+    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+    assert not os.path.exists(fn), f"file {fn} already exists"
+
+    hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
+    hypernetwork.save(fn)
+
+    shared.reload_hypernetworks()
+    shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
+
+    return gr.Dropdown.update(choices=sorted(shared.hypernetworks.keys())), f"Created: {fn}", ""
+
+
+def train_hypernetwork(*args):
+
+    initial_hypernetwork = shared.hypernetwork
+
+    try:
+        sd_hijack.undo_optimizations()
+
+        hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args)
+
+        res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
+Hypernetwork saved to {html.escape(filename)}
+"""
+        return res, ""
+    finally:
+        shared.hypernetwork = initial_hypernetwork
+        sd_hijack.apply_optimizations()
+
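Note: train_hypernetwork above swaps in the selected hypernetwork and undoes the inference-time attention optimizations, then restores both in a finally block so an exception or user interrupt cannot leave the UI stuck in a training state. A reduced sketch of the pattern (the dict and function names are placeholders, not the webui API):

    def run_with_restore(state, train_fn):
        previous = state.get("active")
        try:
            return train_fn()
        finally:
            # Runs on success, exception, and user interrupt alike.
            state["active"] = previous

    state = {"active": "old-hypernetwork"}
    print(run_with_restore(state, lambda: "trained"))
    assert state["active"] == "old-hypernetwork"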
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index aa4d2cbc..f873049a 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -8,7 +8,7 @@ from torch import einsum
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared
 from modules.shared import opts, device, cmd_opts
 
 import ldm.modules.attention
@@ -37,6 +37,8 @@ def apply_optimizations():
 
 
 def undo_optimizations():
+    from modules.hypernetwork import hypernetwork
+
     ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
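Note: two things happen in the sd_hijack.py hunks: the hypernetwork import moves from module level into undo_optimizations(), presumably to break a circular import (modules.hypernetwork itself imports from modules), and undoing the optimizations now restores the hypernetwork-aware forward rather than the stock one. The hijack itself is plain monkey-patching of a class attribute; a toy sketch of that swap pattern (all names here are illustrative):

    class CrossAttention:
        def forward(self, x):
            return f"stock({x})"

    def optimized_forward(self, x):
        return f"optimized({x})"

    original_forward = CrossAttention.forward

    def apply_optimizations():
        CrossAttention.forward = optimized_forward

    def undo_optimizations():
        CrossAttention.forward = original_forward  # in the webui: the hypernetwork-aware variant

    apply_optimizations()
    print(CrossAttention().forward("x"))  # optimized(x)
    undo_optimizations()
    print(CrossAttention().forward("x"))  # stock(x)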
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index f19ac5e0..c57de1f9 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -22,7 +22,6 @@ def preprocess(*args):
 
 
 def train_embedding(*args):
-
     try:
         sd_hijack.undo_optimizations()
 
diff --git a/modules/ui.py b/modules/ui.py
index ca3151c4..10b1ee3a 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -39,6 +39,7 @@ import modules.generation_parameters_copypaste
 from modules import prompt_parser
 from modules.images import save_image
 import modules.textual_inversion.ui
+import modules.hypernetwork.ui
 
 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
 mimetypes.init()
@@ -1025,6 +1026,18 @@ def create_ui(wrap_gradio_gpu_call):
                             create_embedding = gr.Button(value="Create", variant='primary')
 
                 with gr.Group():
+                    gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
+
+                    new_hypernetwork_name = gr.Textbox(label="Name")
+
+                    with gr.Row():
+                        with gr.Column(scale=3):
+                            gr.HTML(value="")
+
+                        with gr.Column():
+                            create_hypernetwork = gr.Button(value="Create", variant='primary')
+
+                with gr.Group():
                     gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
 
                     process_src = gr.Textbox(label='Source directory')
@@ -1047,6 +1060,7 @@ def create_ui(wrap_gradio_gpu_call):
                 with gr.Group():
                     gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
                     train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
+                    train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
                     learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
                     dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
                     log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -1057,15 +1071,12 @@ def create_ui(wrap_gradio_gpu_call):
                     num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
                     create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
                     save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
+                    preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
 
                     with gr.Row():
-                        with gr.Column(scale=2):
-                            gr.HTML(value="")
-
-                        with gr.Column():
-                            with gr.Row():
-                                interrupt_training = gr.Button(value="Interrupt")
-                                train_embedding = gr.Button(value="Train", variant='primary')
+                        interrupt_training = gr.Button(value="Interrupt")
+                        train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
+                        train_embedding = gr.Button(value="Train Embedding", variant='primary')
 
             with gr.Column():
                 progressbar = gr.HTML(elem_id="ti_progressbar")
@@ -1091,6 +1102,18 @@ def create_ui(wrap_gradio_gpu_call):
             ]
         )
 
+        create_hypernetwork.click(
+            fn=modules.hypernetwork.ui.create_hypernetwork,
+            inputs=[
+                new_hypernetwork_name,
+            ],
+            outputs=[
+                train_hypernetwork_name,
+                ti_output,
+                ti_outcome,
+            ]
+        )
+
         run_preprocess.click(
             fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
@@ -1131,12 +1154,33 @@ def create_ui(wrap_gradio_gpu_call):
             ]
         )
 
+        train_hypernetwork.click(
+            fn=wrap_gradio_gpu_call(modules.hypernetwork.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+            _js="start_training_textual_inversion",
+            inputs=[
+                train_hypernetwork_name,
+                learn_rate,
+                dataset_directory,
+                log_directory,
+                steps,
+                create_image_every,
+                save_embedding_every,
+                template_file,
+                preview_image_prompt,
+            ],
+            outputs=[
+                ti_output,
+                ti_outcome,
+            ]
+        )
+
         interrupt_training.click(
             fn=lambda: shared.state.interrupt(),
             inputs=[],
             outputs=[],
         )
+
 
     def create_setting_component(key):
         def fun():
             return opts.data[key] if key in opts.data else opts.data_labels[key].default
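Note: the ui.py hunks follow the standard Gradio event pattern: build components inside layout contexts, then wire a button to a handler with .click(fn=..., inputs=..., outputs=...). A minimal self-contained sketch of the same wiring (component and function names are illustrative, not the webui's):

    import gradio as gr

    def create_thing(name):
        return f"Created: {name}"

    with gr.Blocks() as demo:
        name = gr.Textbox(label="Name")
        create = gr.Button(value="Create", variant="primary")
        output = gr.HTML()
        create.click(fn=create_thing, inputs=[name], outputs=[output])

    # demo.launch()  # uncomment to serve the UI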
diff --git a/textual_inversion_templates/hypernetwork.txt b/textual_inversion_templates/hypernetwork.txt
new file mode 100644
index 00000000..91e06890
--- /dev/null
+++ b/textual_inversion_templates/hypernetwork.txt
@@ -0,0 +1,27 @@
+a photo of a [filewords]
+a rendering of a [filewords]
+a cropped photo of the [filewords]
+the photo of a [filewords]
+a photo of a clean [filewords]
+a photo of a dirty [filewords]
+a dark photo of the [filewords]
+a photo of my [filewords]
+a photo of the cool [filewords]
+a close-up photo of a [filewords]
+a bright photo of the [filewords]
+a cropped photo of a [filewords]
+a photo of the [filewords]
+a good photo of the [filewords]
+a photo of one [filewords]
+a close-up photo of the [filewords]
+a rendition of the [filewords]
+a photo of the clean [filewords]
+a rendition of a [filewords]
+a photo of a nice [filewords]
+a good photo of a [filewords]
+a photo of the nice [filewords]
+a photo of the small [filewords]
+a photo of the weird [filewords]
+a photo of the large [filewords]
+a photo of a cool [filewords]
+a photo of a small [filewords]
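Note: each line of this template file is a prompt pattern; during dataset preparation (modules/textual_inversion/dataset.py, selected via the template_file argument threaded through train_hypernetwork above) the [filewords] tag is filled in with per-image caption text. A sketch of the substitution, assuming a simple random choice per sample (the helper is illustrative):

    import random

    templates = [
        "a photo of a [filewords]",
        "a rendering of a [filewords]",
    ]

    def fill_template(filewords: str) -> str:
        return random.choice(templates).replace("[filewords]", filewords)

    print(fill_template("red vintage car"))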
diff --git a/textual_inversion_templates/none.txt b/textual_inversion_templates/none.txt
new file mode 100644
index 00000000..f77af461
--- /dev/null
+++ b/textual_inversion_templates/none.txt
@@ -0,0 +1 @@
+picture
diff --git a/webui.py b/webui.py
index 270584f7..7c200551 100644
--- a/webui.py
+++ b/webui.py
@@ -77,6 +77,15 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
 
     return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
 
 
+def set_hypernetwork():
+    shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
+
+
+shared.reload_hypernetworks()
+shared.opts.onchange("sd_hypernetwork", set_hypernetwork)
+set_hypernetwork()
+
+
 modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
 
 shared.sd_model = modules.sd_models.load_model()
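Note: the webui.py hunk registers a settings callback so changing the sd_hypernetwork option swaps the active hypernetwork at runtime, and resolves it once at startup. A minimal sketch of such an onchange registry (an illustration, not the actual modules.shared implementation):

    class Options:
        def __init__(self):
            self.data = {"sd_hypernetwork": "None"}
            self._callbacks = {}

        def onchange(self, key, func):
            self._callbacks[key] = func

        def set(self, key, value):
            self.data[key] = value
            callback = self._callbacks.get(key)
            if callback is not None:
                callback()  # e.g. set_hypernetwork() re-resolves shared.hypernetwork

    opts = Options()
    opts.onchange("sd_hypernetwork", lambda: print("now:", opts.data["sd_hypernetwork"]))
    opts.set("sd_hypernetwork", "my_hypernetwork")  # prints: now: my_hypernetwork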