aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/hypernetwork.py98
-rw-r--r--modules/hypernetworks/hypernetwork.py305
-rw-r--r--modules/hypernetworks/ui.py47
-rw-r--r--modules/ngrok.py15
-rw-r--r--modules/safe.py17
-rw-r--r--modules/sd_hijack.py34
-rw-r--r--modules/sd_hijack_optimizations.py140
-rw-r--r--modules/sd_samplers.py24
-rw-r--r--modules/shared.py25
-rw-r--r--modules/textual_inversion/dataset.py36
-rw-r--r--modules/textual_inversion/learn_schedule.py34
-rw-r--r--modules/textual_inversion/preprocess.py5
-rw-r--r--modules/textual_inversion/textual_inversion.py41
-rw-r--r--modules/textual_inversion/ui.py2
-rw-r--r--modules/ui.py108
15 files changed, 736 insertions, 195 deletions
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
deleted file mode 100644
index 498bc9d8..00000000
--- a/modules/hypernetwork.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import glob
-import os
-import sys
-import traceback
-
-import torch
-
-from ldm.util import default
-from modules import devices, shared
-import torch
-from torch import einsum
-from einops import rearrange, repeat
-
-
-class HypernetworkModule(torch.nn.Module):
- def __init__(self, dim, state_dict):
- super().__init__()
-
- self.linear1 = torch.nn.Linear(dim, dim * 2)
- self.linear2 = torch.nn.Linear(dim * 2, dim)
-
- self.load_state_dict(state_dict, strict=True)
- self.to(devices.device)
-
- def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
-
-
-class Hypernetwork:
- filename = None
- name = None
-
- def __init__(self, filename):
- self.filename = filename
- self.name = os.path.splitext(os.path.basename(filename))[0]
- self.layers = {}
-
- state_dict = torch.load(filename, map_location='cpu')
- for size, sd in state_dict.items():
- self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
-
-
-def list_hypernetworks(path):
- res = {}
- for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
- name = os.path.splitext(os.path.basename(filename))[0]
- res[name] = filename
- return res
-
-
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- if path is not None:
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork(path)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print(f"Unloading hypernetwork")
-
- shared.loaded_hypernetwork = None
-
-
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k = self.to_k(hypernetwork_layers[0](context))
- v = self.to_v(hypernetwork_layers[1](context))
- else:
- k = self.to_k(context)
- v = self.to_v(context)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
- if mask is not None:
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
-
- # attention, what we cannot get enough of
- attn = sim.softmax(dim=-1)
-
- out = einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return self.to_out(out)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
new file mode 100644
index 00000000..470659df
--- /dev/null
+++ b/modules/hypernetworks/hypernetwork.py
@@ -0,0 +1,305 @@
+import datetime
+import glob
+import html
+import os
+import sys
+import traceback
+import tqdm
+
+import torch
+
+from ldm.util import default
+from modules import devices, shared, processing, sd_models
+import torch
+from torch import einsum
+from einops import rearrange, repeat
+import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
+
+
+class HypernetworkModule(torch.nn.Module):
+ def __init__(self, dim, state_dict=None):
+ super().__init__()
+
+ self.linear1 = torch.nn.Linear(dim, dim * 2)
+ self.linear2 = torch.nn.Linear(dim * 2, dim)
+
+ if state_dict is not None:
+ self.load_state_dict(state_dict, strict=True)
+ else:
+
+ self.linear1.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear1.bias.data.zero_()
+ self.linear2.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear2.bias.data.zero_()
+
+ self.to(devices.device)
+
+ def forward(self, x):
+ return x + (self.linear2(self.linear1(x)))
+
+
+class Hypernetwork:
+ filename = None
+ name = None
+
+ def __init__(self, name=None, enable_sizes=None):
+ self.filename = None
+ self.name = name
+ self.layers = {}
+ self.step = 0
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
+
+ for size in enable_sizes or []:
+ self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
+
+ def weights(self):
+ res = []
+
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.train()
+ res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
+
+ return res
+
+ def save(self, filename):
+ state_dict = {}
+
+ for k, v in self.layers.items():
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
+
+ state_dict['step'] = self.step
+ state_dict['name'] = self.name
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+
+ torch.save(state_dict, filename)
+
+ def load(self, filename):
+ self.filename = filename
+ if self.name is None:
+ self.name = os.path.splitext(os.path.basename(filename))[0]
+
+ state_dict = torch.load(filename, map_location='cpu')
+
+ for size, sd in state_dict.items():
+ if type(size) == int:
+ self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+
+ self.name = state_dict.get('name', self.name)
+ self.step = state_dict.get('step', 0)
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+
+
+def list_hypernetworks(path):
+ res = {}
+ for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+ name = os.path.splitext(os.path.basename(filename))[0]
+ res[name] = filename
+ return res
+
+
+def load_hypernetwork(filename):
+ path = shared.hypernetworks.get(filename, None)
+ if path is not None:
+ print(f"Loading hypernetwork {filename}")
+ try:
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ else:
+ if shared.loaded_hypernetwork is not None:
+ print(f"Unloading hypernetwork")
+
+ shared.loaded_hypernetwork = None
+
+
+def apply_hypernetwork(hypernetwork, context, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is None:
+ return context, context
+
+ if layer is not None:
+ layer.hyper_k = hypernetwork_layers[0]
+ layer.hyper_v = hypernetwork_layers[1]
+
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
+
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
+ assert hypernetwork_name, 'embedding not selected'
+
+ path = shared.hypernetworks.get(hypernetwork_name, None)
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
+ shared.state.textinfo = "Initializing hypernetwork training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+ unload = shared.opts.unload_models_when_training
+
+ if save_hypernetwork_every > 0:
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
+ os.makedirs(hypernetwork_dir, exist_ok=True)
+ else:
+ hypernetwork_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ hypernetwork = shared.loaded_hypernetwork
+ weights = hypernetwork.weights()
+ for weight in weights:
+ weight.requires_grad = True
+
+ losses = torch.zeros((32,))
+
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+
+ ititial_step = hypernetwork.step or 0
+ if ititial_step > steps:
+ return hypernetwork, filename
+
+ schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(schedules)
+ print(f'Training at rate of {learn_rate} until step {end_step}')
+
+ optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+
+ pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
+ for i, (x, text, cond) in pbar:
+ hypernetwork.step = i + ititial_step
+
+ if hypernetwork.step > end_step:
+ try:
+ (learn_rate, end_step) = next(schedules)
+ except Exception:
+ break
+ tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+ for pg in optimizer.param_groups:
+ pg['lr'] = learn_rate
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ cond = cond.to(devices.device)
+ x = x.to(devices.device)
+ loss = shared.sd_model(x.unsqueeze(0), cond)[0]
+ del x
+ del cond
+
+ losses[hypernetwork.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ pbar.set_description(f"loss: {losses.mean():.7f}")
+
+ if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
+ hypernetwork.save(last_saved_file)
+
+ if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
+ last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+
+ preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ prompt=preview_text,
+ steps=20,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ processed = processing.process_images(p)
+ image = processed.images[0]
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ shared.state.current_image = image
+ image.save(last_saved_image)
+
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = hypernetwork.step
+
+ shared.state.textinfo = f"""
+<p>
+Loss: {losses.mean():.7f}<br/>
+Step: {hypernetwork.step}<br/>
+Last prompt: {html.escape(text)}<br/>
+Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved image: {html.escape(last_saved_image)}<br/>
+</p>
+"""
+
+ checkpoint = sd_models.select_checkpoint()
+
+ hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ hypernetwork.save(filename)
+
+ return hypernetwork, filename
+
+
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
new file mode 100644
index 00000000..dfa599af
--- /dev/null
+++ b/modules/hypernetworks/ui.py
@@ -0,0 +1,47 @@
+import html
+import os
+
+import gradio as gr
+
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
+from modules import sd_hijack, shared, devices
+from modules.hypernetworks import hypernetwork
+
+
+def create_hypernetwork(name, enable_sizes):
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
+ hypernet.save(fn)
+
+ shared.reload_hypernetworks()
+
+ return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+
+
+def train_hypernetwork(*args):
+
+ initial_hypernetwork = shared.loaded_hypernetwork
+
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+
+ try:
+ sd_hijack.undo_optimizations()
+
+ hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
+
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
+Hypernetwork saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception:
+ raise
+ finally:
+ shared.loaded_hypernetwork = initial_hypernetwork
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ sd_hijack.apply_optimizations()
+
diff --git a/modules/ngrok.py b/modules/ngrok.py
new file mode 100644
index 00000000..7d03a6df
--- /dev/null
+++ b/modules/ngrok.py
@@ -0,0 +1,15 @@
+from pyngrok import ngrok, conf, exception
+
+
+def connect(token, port):
+ if token == None:
+ token = 'None'
+ conf.get_default().auth_token = token
+ try:
+ public_url = ngrok.connect(port).public_url
+ except exception.PyngrokNgrokError:
+ print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
+ f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
+ else:
+ print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
+ 'You can use this link after the launch is complete.')
diff --git a/modules/safe.py b/modules/safe.py
index 05917463..20be16a5 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -10,6 +10,7 @@ import torch
import numpy
import _codecs
import zipfile
+import re
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
@@ -54,11 +55,27 @@ class RestrictedUnpickler(pickle.Unpickler):
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
+allowed_zip_names = ["archive/data.pkl", "archive/version"]
+allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
+
+
+def check_zip_filenames(filename, names):
+ for name in names:
+ if name in allowed_zip_names:
+ continue
+ if allowed_zip_names_re.match(name):
+ continue
+
+ raise Exception(f"bad file inside {filename}: {name}")
+
+
def check_pt(filename):
try:
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
+ check_zip_filenames(filename, z.namelist())
+
with z.open('archive/data.pkl') as file:
unpickler = RestrictedUnpickler(file)
unpickler.load()
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 827bf304..ac70f876 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -8,8 +8,9 @@ from torch import einsum
from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
+from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -30,13 +31,23 @@ def apply_optimizations():
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
+ if not invokeAI_mps_available and shared.device.type == 'mps':
+ print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
+ print("Applying v1 cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ else:
+ print("Applying cross attention optimization (InvokeAI).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- print("Applying cross attention optimization.")
+ print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
def undo_optimizations():
+ from modules.hypernetworks import hypernetwork
+
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
@@ -107,6 +118,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.tokenizer = wrapped.tokenizer
self.token_mults = {}
+ self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
@@ -136,6 +149,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
fixes = []
remade_tokens = []
multipliers = []
+ last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
@@ -144,6 +158,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ if token == self.comma_token:
+ last_comma = len(remade_tokens)
+ elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
+ last_comma += 1
+ reloc_tokens = remade_tokens[last_comma:]
+ reloc_mults = multipliers[last_comma:]
+
+ remade_tokens = remade_tokens[:last_comma]
+ length = len(remade_tokens)
+
+ rem = int(math.ceil(length / 75)) * 75 - length
+ remade_tokens += [id_end] * rem + reloc_tokens
+ multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
@@ -284,7 +312,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
-
+
self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 18408e62..79405525 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,6 +1,7 @@
import math
import sys
import traceback
+import importlib
import torch
from torch import einsum
@@ -9,6 +10,8 @@ from ldm.util import default
from einops import rearrange
from modules import shared
+from modules.hypernetworks import hypernetwork
+
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
@@ -26,16 +29,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
- del context, x
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+ del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
@@ -59,22 +56,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
return self.to_out(r2)
-# taken from https://github.com/Doggettx/stable-diffusion
+# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
k_in *= self.scale
@@ -126,18 +117,111 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
+
+def check_for_psutil():
+ try:
+ spec = importlib.util.find_spec('psutil')
+ return spec is not None
+ except ModuleNotFoundError:
+ return False
+
+invokeAI_mps_available = check_for_psutil()
+
+# -- Taken from https://github.com/invoke-ai/InvokeAI --
+if invokeAI_mps_available:
+ import psutil
+ mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
+def einsum_op_compvis(q, k, v):
+ s = einsum('b i d, b j d -> b i j', q, k)
+ s = s.softmax(dim=-1, dtype=s.dtype)
+ return einsum('b i j, b j d -> b i d', s, v)
+
+def einsum_op_slice_0(q, k, v, slice_size):
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[0], slice_size):
+ end = i + slice_size
+ r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+ return r
+
+def einsum_op_slice_1(q, k, v, slice_size):
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
+ return r
+
+def einsum_op_mps_v1(q, k, v):
+ if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+ return einsum_op_compvis(q, k, v)
+ else:
+ slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ return einsum_op_slice_1(q, k, v, slice_size)
+
+def einsum_op_mps_v2(q, k, v):
+ if mem_total_gb > 8 and q.shape[1] <= 4096:
+ return einsum_op_compvis(q, k, v)
+ else:
+ return einsum_op_slice_0(q, k, v, 1)
+
+def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
+ size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+ if size_mb <= max_tensor_mb:
+ return einsum_op_compvis(q, k, v)
+ div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+ if div <= q.shape[0]:
+ return einsum_op_slice_0(q, k, v, q.shape[0] // div)
+ return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+def einsum_op_cuda(q, k, v):
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+ # Divide factor of safety as there's copying and fragmentation
+ return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
+def einsum_op(q, k, v):
+ if q.device.type == 'cuda':
+ return einsum_op_cuda(q, k, v)
+
+ if q.device.type == 'mps':
+ if mem_total_gb >= 32:
+ return einsum_op_mps_v1(q, k, v)
+ return einsum_op_mps_v2(q, k, v)
+
+ # Smaller slices are faster due to L2/L3/SLC caches.
+ # Tested on i7 with 8MB L3 cache.
+ return einsum_op_tensor_mem(q, k, v, 32)
+
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k = self.to_k(context_k) * self.scale
+ v = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ r = einsum_op(q, k, v)
+ return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
+
+# -- End of code from https://github.com/invoke-ai/InvokeAI --
+
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index d168b938..20309e06 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -57,7 +57,7 @@ def set_samplers():
global samplers, samplers_for_img2img
hidden = set(opts.hide_samplers)
- hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
+ hidden_img2img = set(opts.hide_samplers + ['PLMS'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
@@ -365,16 +365,26 @@ class KDiffusionSampler:
else:
sigmas = self.model_wrap.get_sigmas(steps)
- noise = noise * sigmas[steps - t_enc - 1]
- xi = x + noise
-
- extra_params_kwargs = self.initialize(p)
-
sigma_sched = sigmas[steps - t_enc - 1:]
+ xi = x + noise * sigma_sched[0]
+
+ extra_params_kwargs = self.initialize(p)
+ if 'sigma_min' in inspect.signature(self.func).parameters:
+ ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
+ extra_params_kwargs['sigma_min'] = sigma_sched[-2]
+ if 'sigma_max' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_max'] = sigma_sched[0]
+ if 'n' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['n'] = len(sigma_sched) - 1
+ if 'sigma_sched' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_sched'] = sigma_sched
+ if 'sigmas' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x
- return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+ return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps
diff --git a/modules/shared.py b/modules/shared.py
index 99a0264c..817203f8 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,7 +13,8 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, hypernetwork
+from modules import sd_samplers
+from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -29,6 +30,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
@@ -36,6 +38,7 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
@@ -47,9 +50,10 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
@@ -82,10 +86,17 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file
-hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks'))
+hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
+def reload_hypernetworks():
+ global hypernetworks
+
+ hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+ hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
+
class State:
skipped = False
interrupted = False
@@ -217,6 +228,10 @@ options_templates.update(options_section(('system', "System"), {
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
}))
+options_templates.update(options_section(('training', "Training"), {
+ "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"),
+}))
+
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
@@ -227,6 +242,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
+ "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
@@ -239,6 +255,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
}))
options_templates.update(options_section(('ui', "User interface"), {
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index bcf772d2..f61f40d3 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -8,14 +8,14 @@ from torchvision import transforms
import random
import tqdm
-from modules import devices
+from modules import devices, shared
import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
self.placeholder_token = placeholder_token
@@ -32,12 +32,15 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
+ cond_model = shared.sd_model.cond_stage_model
+
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
- image = Image.open(path)
- image = image.convert('RGB')
- image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+ try:
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ except Exception:
+ continue
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0]
@@ -52,7 +55,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
- self.dataset.append((init_latent, filename_tokens))
+ if include_cond:
+ text = self.create_text(filename_tokens)
+ cond = cond_model([text]).to(devices.cpu)
+ else:
+ cond = None
+
+ self.dataset.append((init_latent, filename_tokens, cond))
self.length = len(self.dataset) * repeats
@@ -63,6 +72,12 @@ class PersonalizedBase(Dataset):
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ def create_text(self, filename_tokens):
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ text = text.replace("[filewords]", ' '.join(filename_tokens))
+ return text
+
def __len__(self):
return self.length
@@ -71,10 +86,7 @@ class PersonalizedBase(Dataset):
self.shuffle()
index = self.indexes[i % len(self.indexes)]
- x, filename_tokens = self.dataset[index]
-
- text = random.choice(self.lines)
- text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
+ x, filename_tokens, cond = self.dataset[index]
- return x, text
+ text = self.create_text(filename_tokens)
+ return x, text, cond
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
new file mode 100644
index 00000000..db720271
--- /dev/null
+++ b/modules/textual_inversion/learn_schedule.py
@@ -0,0 +1,34 @@
+
+class LearnSchedule:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ for i, pair in enumerate(pairs):
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index c0af729b..a96388d6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -60,7 +60,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
- img = Image.open(filename).convert("RGB")
+ try:
+ img = Image.open(filename).convert("RGB")
+ except Exception:
+ continue
if shared.state.interrupted:
break
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5965c5a0..7717837d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -10,6 +10,7 @@ import datetime
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
class Embedding:
@@ -156,7 +157,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -189,8 +190,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
-
losses = torch.zeros((32,))
last_saved_file = "<none>"
@@ -200,15 +199,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
- tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
- epoch_len = (tr_img_len * num_repeats) + tr_img_len
+ schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(schedules)
+ print(f'Training at rate of {learn_rate} until step {end_step}')
+
+ optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, (x, text) in pbar:
+ for i, (x, text, _) in pbar:
embedding.step = i + ititial_step
- if embedding.step > steps:
- break
+ if embedding.step > end_step:
+ try:
+ (learn_rate, end_step) = next(schedules)
+ except:
+ break
+ tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+ for pg in optimizer.param_groups:
+ pg['lr'] = learn_rate
if shared.state.interrupted:
break
@@ -226,10 +234,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
loss.backward()
optimizer.step()
- epoch_num = embedding.step // epoch_len
- epoch_step = embedding.step - (epoch_num * epoch_len) + 1
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step - (epoch_num * len(ds)) + 1
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
@@ -238,12 +246,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+ preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=text,
+ prompt=preview_text,
steps=20,
- height=training_height,
- width=training_width,
+ height=training_height,
+ width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
@@ -254,7 +264,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.current_image = image
image.save(last_saved_image)
- last_saved_image += f", prompt: {text}"
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
@@ -276,4 +286,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
embedding.save(filename)
return embedding, filename
-
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index f19ac5e0..36881e7a 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -23,6 +23,8 @@ def preprocess(*args):
def train_embedding(*args):
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+
try:
sd_hijack.undo_optimizations()
diff --git a/modules/ui.py b/modules/ui.py
index 2ad7d864..2891fc8c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -39,6 +39,7 @@ import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
+import modules.hypernetworks.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -50,6 +51,11 @@ if not cmd_opts.share and not cmd_opts.listen:
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
+if cmd_opts.ngrok != None:
+ import modules.ngrok as ngrok
+ print('ngrok authtoken detected, trying to connect...')
+ ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860)
+
def gr_show(visible=True):
return {"visible": visible, "__type__": "update"}
@@ -311,7 +317,7 @@ def interrogate(image):
def interrogate_deepbooru(image):
- prompt = get_deepbooru_tags(image)
+ prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold)
return gr_show(True) if prompt is None else prompt
@@ -428,7 +434,10 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=8):
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+ with gr.Row():
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+ with gr.Column(scale=1, elem_id="roll_col"):
+ sh = gr.Button(elem_id="sh", visible=True)
with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
@@ -549,15 +558,15 @@ def create_ui(wrap_gradio_gpu_call):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
+ with gr.Row():
+ do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -737,15 +746,15 @@ def create_ui(wrap_gradio_gpu_call):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
+ with gr.Row():
+ do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -1022,7 +1031,20 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="")
with gr.Column():
- create_embedding = gr.Button(value="Create", variant='primary')
+ create_embedding = gr.Button(value="Create embedding", variant='primary')
+
+ with gr.Group():
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
+
+ new_hypernetwork_name = gr.Textbox(label="Name")
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
@@ -1051,7 +1073,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
- learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
+ learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@@ -1061,15 +1084,12 @@ def create_ui(wrap_gradio_gpu_call):
num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
+ preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
with gr.Row():
- with gr.Column(scale=2):
- gr.HTML(value="")
-
- with gr.Column():
- with gr.Row():
- interrupt_training = gr.Button(value="Interrupt")
- train_embedding = gr.Button(value="Train", variant='primary')
+ interrupt_training = gr.Button(value="Interrupt")
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
+ train_embedding = gr.Button(value="Train Embedding", variant='primary')
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
@@ -1095,6 +1115,19 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ create_hypernetwork.click(
+ fn=modules.hypernetworks.ui.create_hypernetwork,
+ inputs=[
+ new_hypernetwork_name,
+ new_hypernetwork_sizes,
+ ],
+ outputs=[
+ train_hypernetwork_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
run_preprocess.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
@@ -1129,6 +1162,27 @@ def create_ui(wrap_gradio_gpu_call):
create_image_every,
save_embedding_every,
template_file,
+ preview_image_prompt,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ train_hypernetwork.click(
+ fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ train_hypernetwork_name,
+ learn_rate,
+ dataset_directory,
+ log_directory,
+ steps,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ preview_image_prompt,
],
outputs=[
ti_output,
@@ -1142,6 +1196,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[],
)
+
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1295,6 +1350,7 @@ Requested path was: {f}
shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
+
restart_gradio.click(
fn=request_restart,
inputs=[],
@@ -1336,7 +1392,7 @@ Requested path was: {f}
with gr.Tabs() as tabs:
for interface, label, ifid in interfaces:
- with gr.TabItem(label, id=ifid):
+ with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()
if os.path.exists(os.path.join(script_path, "notification.mp3")):