aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--javascript/progressbar.js1
-rw-r--r--javascript/textualInversion.js8
-rw-r--r--modules/devices.py3
-rw-r--r--modules/processing.py13
-rw-r--r--modules/sd_hijack.py324
-rw-r--r--modules/sd_hijack_optimizations.py164
-rw-r--r--modules/sd_models.py4
-rw-r--r--modules/shared.py3
-rw-r--r--modules/textual_inversion/dataset.py76
-rw-r--r--modules/textual_inversion/textual_inversion.py258
-rw-r--r--modules/textual_inversion/ui.py32
-rw-r--r--modules/ui.py139
-rw-r--r--style.css10
-rw-r--r--textual_inversion_templates/style.txt19
-rw-r--r--textual_inversion_templates/style_filewords.txt19
-rw-r--r--textual_inversion_templates/subject.txt27
-rw-r--r--textual_inversion_templates/subject_filewords.txt27
-rw-r--r--webui.py15
19 files changed, 828 insertions, 315 deletions
diff --git a/.gitignore b/.gitignore
index 3532dab3..7afc9395 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,4 @@ __pycache__
/.idea
notification.mp3
/SwinIR
+/textual_inversion
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 21f25b38..1e297abb 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
+ check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
})
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js
new file mode 100644
index 00000000..8061be08
--- /dev/null
+++ b/javascript/textualInversion.js
@@ -0,0 +1,8 @@
+
+
+function start_training_textual_inversion(){
+ requestProgress('ti')
+ gradioApp().querySelector('#ti_error').innerHTML=''
+
+ return args_to_array(arguments)
+}
diff --git a/modules/devices.py b/modules/devices.py
index 07bb2339..ff82f2f6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -32,10 +32,9 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-
device = get_optimal_device()
device_codeformer = cpu if has_mps else device
-
+dtype = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
diff --git a/modules/processing.py b/modules/processing.py
index 7eeb5191..8223423a 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -56,7 +56,7 @@ class StableDiffusionProcessing:
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
- self.styles: str = styles
+ self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
@@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
- "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
+ "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
}
generation_params.update(p.extra_generation_params)
@@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
fix_seed(p)
- os.makedirs(p.outpath_samples, exist_ok=True)
- os.makedirs(p.outpath_grids, exist_ok=True)
+ if p.outpath_samples is not None:
+ os.makedirs(p.outpath_samples, exist_ok=True)
+
+ if p.outpath_grids is not None:
+ os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
@@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir):
- model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
+ model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = []
output_images = []
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index fa7eaeb8..fd57e5c5 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -6,244 +6,41 @@ import torch
import numpy as np
from torch import einsum
-from modules import prompt_parser
+import modules.textual_inversion.textual_inversion
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
-from ldm.util import default
-from einops import rearrange
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
+diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
-# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
- h = self.heads
- q = self.to_q(x)
- context = default(context, x)
- k = self.to_k(context)
- v = self.to_v(context)
- del context, x
+def apply_optimizations():
+ if cmd_opts.opt_split_attention_v1:
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
- for i in range(0, q.shape[0], 2):
- end = i + 2
- s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
- s1 *= self.scale
+def undo_optimizations():
+ ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+ ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
- s2 = s1.softmax(dim=-1)
- del s1
-
- r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
- del s2
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-
-# taken from https://github.com/Doggettx/stable-diffusion
-def split_cross_attention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q_in = self.to_q(x)
- context = default(context, x)
- k_in = self.to_k(context) * self.scale
- v_in = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
-
- gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
- modifier = 3 if q.element_size() == 2 else 2.5
- mem_required = tensor_size * modifier
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
- # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
- # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-
- if steps > 64:
- max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
- raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
- f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
- s2 = s1.softmax(dim=-1, dtype=q.dtype)
- del s1
-
- r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
- del s2
-
- del q, k, v
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-def nonlinearity_hijack(x):
- # swish
- t = torch.sigmoid(x)
- x *= t
- del t
-
- return x
-
-def cross_attention_attnblock_forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q1 = self.q(h_)
- k1 = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q1.shape
-
- q2 = q1.reshape(b, c, h*w)
- del q1
-
- q = q2.permute(0, 2, 1) # b,hw,c
- del q2
-
- k = k1.reshape(b, c, h*w) # b,c,hw
- del k1
-
- h_ = torch.zeros_like(k, device=q.device)
-
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
-
- tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
- mem_required = tensor_size * 2.5
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
-
- w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w2 = w1 * (int(c)**(-0.5))
- del w1
- w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
- del w2
-
- # attend to values
- v1 = v.reshape(b, c, h*w)
- w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
- del w3
-
- h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- del v1, w4
-
- h2 = h_.reshape(b, c, h, w)
- del h_
-
- h3 = self.proj_out(h2)
- del h2
-
- h3 += x
-
- return h3
class StableDiffusionModelHijack:
- ids_lookup = {}
- word_embeddings = {}
- word_embeddings_checksums = {}
fixes = None
comments = []
- dir_mtime = None
layers = None
circular_enabled = False
clip = None
- def load_textual_inversion_embeddings(self, dirname, model):
- mt = os.path.getmtime(dirname)
- if self.dir_mtime is not None and mt <= self.dir_mtime:
- return
-
- self.dir_mtime = mt
- self.ids_lookup.clear()
- self.word_embeddings.clear()
-
- tokenizer = model.cond_stage_model.tokenizer
-
- def const_hash(a):
- r = 0
- for v in a:
- r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
- return r
-
- def process_file(path, filename):
- name = os.path.splitext(filename)[0]
-
- data = torch.load(path, map_location="cpu")
-
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
-
- self.word_embeddings[name] = emb.detach().to(device)
- self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
-
- ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
-
- first_id = ids[0]
- if first_id not in self.ids_lookup:
- self.ids_lookup[first_id] = []
- self.ids_lookup[first_id].append((ids, name))
-
- for fn in os.listdir(dirname):
- try:
- fullfn = os.path.join(dirname, fn)
-
- if os.stat(fullfn).st_size == 0:
- continue
-
- process_file(fullfn, fn)
- except Exception:
- print(f"Error loading emedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- continue
-
- print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@@ -253,12 +50,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model
- if cmd_opts.opt_split_attention_v1:
- ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
- ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
- ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+ apply_optimizations()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
@@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
- self.hijack = hijack
+ self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
@@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
-
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- possible_matches = self.hijack.ids_lookup.get(token, None)
+ embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
- if possible_matches is None:
+ if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
+ i += 1
else:
- found = False
- for ids, word in possible_matches:
- if tokens[i:i + len(ids)] == ids:
- emb_len = int(self.hijack.word_embeddings[word].shape[0])
- fixes.append((len(remade_tokens), word))
- remade_tokens += [0] * emb_len
- multipliers += [weight] * emb_len
- i += len(ids) - 1
- found = True
- used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
- break
-
- if not found:
- remade_tokens.append(token)
- multipliers.append(weight)
- i += 1
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [weight] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += emb_len
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- possible_matches = self.hijack.ids_lookup.get(token, None)
+ embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
- elif possible_matches is None:
+ i += 1
+ elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
+ i += 1
else:
- found = False
- for ids, word in possible_matches:
- if tokens[i:i+len(ids)] == ids:
- emb_len = int(self.hijack.word_embeddings[word].shape[0])
- fixes.append((len(remade_tokens), word))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- i += len(ids) - 1
- found = True
- used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
- break
-
- if not found:
- remade_tokens.append(token)
- multipliers.append(mult)
-
- i += 1
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += emb_len
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
@@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
@@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
inputs_embeds = self.wrapped(input_ids)
- if batch_fixes is not None:
- for fixes, tensor in zip(batch_fixes, inputs_embeds):
- for offset, word in fixes:
- emb = self.embeddings.word_embeddings[word]
- emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
- tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
+ if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
+ return inputs_embeds
+
+ vecs = []
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
+ for offset, embedding in fixes:
+ emb = embedding.vec
+ emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+
+ vecs.append(tensor)
- return inputs_embeds
+ return torch.stack(vecs)
def add_circular_option_to_conv_2d():
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
new file mode 100644
index 00000000..9c079e57
--- /dev/null
+++ b/modules/sd_hijack_optimizations.py
@@ -0,0 +1,164 @@
+import math
+import torch
+from torch import einsum
+
+from ldm.util import default
+from einops import rearrange
+
+
+# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
+def split_cross_attention_forward_v1(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ for i in range(0, q.shape[0], 2):
+ end = i + 2
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+ s1 *= self.scale
+
+ s2 = s1.softmax(dim=-1)
+ del s1
+
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+ del s2
+
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+ del r1
+
+ return self.to_out(r2)
+
+
+# taken from https://github.com/Doggettx/stable-diffusion
+def split_cross_attention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q_in = self.to_q(x)
+ context = default(context, x)
+ k_in = self.to_k(context) * self.scale
+ v_in = self.to_v(context)
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+
+ gb = 1024 ** 3
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+ modifier = 3 if q.element_size() == 2 else 2.5
+ mem_required = tensor_size * modifier
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+ if steps > 64:
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+ f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
+ del s1
+
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+ del s2
+
+ del q, k, v
+
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+ del r1
+
+ return self.to_out(r2)
+
+def nonlinearity_hijack(x):
+ # swish
+ t = torch.sigmoid(x)
+ x *= t
+ del t
+
+ return x
+
+def cross_attention_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q1 = self.q(h_)
+ k1 = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q1.shape
+
+ q2 = q1.reshape(b, c, h*w)
+ del q1
+
+ q = q2.permute(0, 2, 1) # b,hw,c
+ del q2
+
+ k = k1.reshape(b, c, h*w) # b,c,hw
+ del k1
+
+ h_ = torch.zeros_like(k, device=q.device)
+
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+ mem_required = tensor_size * 2.5
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+
+ w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w2 = w1 * (int(c)**(-0.5))
+ del w1
+ w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
+ del w2
+
+ # attend to values
+ v1 = v.reshape(b, c, h*w)
+ w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ del w3
+
+ h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ del v1, w4
+
+ h2 = h_.reshape(b, c, h, w)
+ del h_
+
+ h3 = self.proj_out(h2)
+ del h2
+
+ h3 += x
+
+ return h3
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 2539f14c..5b3dbdc7 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -8,7 +8,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
-from modules import shared, modelloader
+from modules import shared, modelloader, devices
from modules.paths import models_path
model_dir = "Stable-diffusion"
@@ -134,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
if not shared.cmd_opts.no_half:
model.half()
+ devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+
model.sd_model_hash = sd_model_hash
model.sd_model_checkpint = checkpoint_file
diff --git a/modules/shared.py b/modules/shared.py
index ac968b2d..ac0bc480 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -78,6 +78,7 @@ class State:
current_latent = None
current_image = None
current_image_sampling_step = 0
+ textinfo = None
def interrupt(self):
self.interrupted = True
@@ -88,7 +89,7 @@ class State:
self.current_image_sampling_step = 0
def get_job_timestamp(self):
- return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
state = State()
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
new file mode 100644
index 00000000..7e134a08
--- /dev/null
+++ b/modules/textual_inversion/dataset.py
@@ -0,0 +1,76 @@
+import os
+import numpy as np
+import PIL
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+import random
+import tqdm
+
+
+class PersonalizedBase(Dataset):
+ def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+
+ self.placeholder_token = placeholder_token
+
+ self.size = size
+ self.width = width
+ self.height = height
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+ self.dataset = []
+
+ with open(template_file, "r") as file:
+ lines = [x.strip() for x in file.readlines()]
+
+ self.lines = lines
+
+ assert data_root, 'dataset directory not specified'
+
+ self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+ print("Preparing dataset...")
+ for path in tqdm.tqdm(self.image_paths):
+ image = Image.open(path)
+ image = image.convert('RGB')
+ image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+
+ filename = os.path.basename(path)
+ filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
+ filename_tokens = [token for token in filename_tokens if token.isalpha()]
+
+ npimage = np.array(image).astype(np.uint8)
+ npimage = (npimage / 127.5 - 1.0).astype(np.float32)
+
+ torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
+ torchdata = torch.moveaxis(torchdata, 2, 0)
+
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
+
+ self.dataset.append((init_latent, filename_tokens))
+
+ self.length = len(self.dataset) * repeats
+
+ self.initial_indexes = np.arange(self.length) % len(self.dataset)
+ self.indexes = None
+ self.shuffle()
+
+ def shuffle(self):
+ self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, i):
+ if i % len(self.dataset) == 0:
+ self.shuffle()
+
+ index = self.indexes[i % len(self.indexes)]
+ x, filename_tokens = self.dataset[index]
+
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ text = text.replace("[filewords]", ' '.join(filename_tokens))
+
+ return x, text
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
new file mode 100644
index 00000000..c0baaace
--- /dev/null
+++ b/modules/textual_inversion/textual_inversion.py
@@ -0,0 +1,258 @@
+import os
+import sys
+import traceback
+
+import torch
+import tqdm
+import html
+import datetime
+
+from modules import shared, devices, sd_hijack, processing
+import modules.textual_inversion.dataset
+
+
+class Embedding:
+ def __init__(self, vec, name, step=None):
+ self.vec = vec
+ self.name = name
+ self.step = step
+ self.cached_checksum = None
+
+ def save(self, filename):
+ embedding_data = {
+ "string_to_token": {"*": 265},
+ "string_to_param": {"*": self.vec},
+ "name": self.name,
+ "step": self.step,
+ }
+
+ torch.save(embedding_data, filename)
+
+ def checksum(self):
+ if self.cached_checksum is not None:
+ return self.cached_checksum
+
+ def const_hash(a):
+ r = 0
+ for v in a:
+ r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+ return r
+
+ self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
+ return self.cached_checksum
+
+class EmbeddingDatabase:
+ def __init__(self, embeddings_dir):
+ self.ids_lookup = {}
+ self.word_embeddings = {}
+ self.dir_mtime = None
+ self.embeddings_dir = embeddings_dir
+
+ def register_embedding(self, embedding, model):
+
+ self.word_embeddings[embedding.name] = embedding
+
+ ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+
+ first_id = ids[0]
+ if first_id not in self.ids_lookup:
+ self.ids_lookup[first_id] = []
+ self.ids_lookup[first_id].append((ids, embedding))
+
+ return embedding
+
+ def load_textual_inversion_embeddings(self):
+ mt = os.path.getmtime(self.embeddings_dir)
+ if self.dir_mtime is not None and mt <= self.dir_mtime:
+ return
+
+ self.dir_mtime = mt
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+
+ def process_file(path, filename):
+ name = os.path.splitext(filename)[0]
+
+ data = torch.load(path, map_location="cpu")
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ self.register_embedding(embedding, shared.sd_model)
+
+ for fn in os.listdir(self.embeddings_dir):
+ try:
+ fullfn = os.path.join(self.embeddings_dir, fn)
+
+ if os.stat(fullfn).st_size == 0:
+ continue
+
+ process_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading emedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ continue
+
+ print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+
+ def find_embedding_at_position(self, tokens, offset):
+ token = tokens[offset]
+ possible_matches = self.ids_lookup.get(token, None)
+
+ if possible_matches is None:
+ return None
+
+ for ids, embedding in possible_matches:
+ if tokens[offset:offset + len(ids)] == ids:
+ return embedding
+
+ return None
+
+
+
+def create_embedding(name, num_vectors_per_token):
+ init_text = '*'
+
+ cond_model = shared.sd_model.cond_stage_model
+ embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
+
+ ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+ vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
+
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+
+ fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ embedding = Embedding(vec, name)
+ embedding.step = 0
+ embedding.save(fn)
+
+ return fn
+
+
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+ assert embedding_name, 'embedding not selected'
+
+ shared.state.textinfo = "Initializing textual inversion training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
+
+ if save_embedding_every > 0:
+ embedding_dir = os.path.join(log_directory, "embeddings")
+ os.makedirs(embedding_dir, exist_ok=True)
+ else:
+ embedding_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ cond_model = shared.sd_model.cond_stage_model
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+
+ hijack = sd_hijack.model_hijack
+
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+ embedding.vec.requires_grad = True
+
+ optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+
+ losses = torch.zeros((32,))
+
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+
+ ititial_step = embedding.step or 0
+ if ititial_step > steps:
+ return embedding, filename
+
+ pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ for i, (x, text) in pbar:
+ embedding.step = i + ititial_step
+
+ if embedding.step > steps:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([text])
+ loss = shared.sd_model(x.unsqueeze(0), c)[0]
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ pbar.set_description(f"loss: {losses.mean():.7f}")
+
+ if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+ embedding.save(last_saved_file)
+
+ if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
+ last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ prompt=text,
+ steps=20,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ processed = processing.process_images(p)
+ image = processed.images[0]
+
+ shared.state.current_image = image
+ image.save(last_saved_image)
+
+ last_saved_image += f", prompt: {text}"
+
+ shared.state.job_no = embedding.step
+
+ shared.state.textinfo = f"""
+<p>
+Loss: {losses.mean():.7f}<br/>
+Step: {embedding.step}<br/>
+Last prompt: {html.escape(text)}<br/>
+Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved image: {html.escape(last_saved_image)}<br/>
+</p>
+"""
+
+ embedding.cached_checksum = None
+ embedding.save(filename)
+
+ return embedding, filename
+
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
new file mode 100644
index 00000000..ce3677a9
--- /dev/null
+++ b/modules/textual_inversion/ui.py
@@ -0,0 +1,32 @@
+import html
+
+import gradio as gr
+
+import modules.textual_inversion.textual_inversion as ti
+from modules import sd_hijack, shared
+
+
+def create_embedding(name, nvpt):
+ filename = ti.create_embedding(name, nvpt)
+
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+ return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
+
+
+def train_embedding(*args):
+
+ try:
+ sd_hijack.undo_optimizations()
+
+ embedding, filename = ti.train_embedding(*args)
+
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps.
+Embedding saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception:
+ raise
+ finally:
+ sd_hijack.apply_optimizations()
diff --git a/modules/ui.py b/modules/ui.py
index 15572bb0..57aef6ff 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -21,6 +21,7 @@ import gradio as gr
import gradio.utils
import gradio.routes
+from modules import sd_hijack
from modules.paths import script_path
from modules.shared import opts, cmd_opts
import modules.shared as shared
@@ -32,6 +33,7 @@ import modules.gfpgan_model
import modules.codeformer_model
import modules.styles
import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
@@ -142,8 +144,8 @@ def save_files(js_data, images, index):
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
-def wrap_gradio_call(func):
- def f(*args, **kwargs):
+def wrap_gradio_call(func, extra_outputs=None):
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
if run_memmon:
shared.mem_mon.monitor()
@@ -159,7 +161,10 @@ def wrap_gradio_call(func):
shared.state.job = ""
shared.state.job_count = 0
- res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
+ if extra_outputs_array is None:
+ extra_outputs_array = [None, '']
+
+ res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t
@@ -179,6 +184,7 @@ def wrap_gradio_call(func):
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
shared.state.interrupted = False
+ shared.state.job_count = 0
return tuple(res)
@@ -187,7 +193,7 @@ def wrap_gradio_call(func):
def check_progress_call(id_part):
if shared.state.job_count == 0:
- return "", gr_show(False), gr_show(False)
+ return "", gr_show(False), gr_show(False), gr_show(False)
progress = 0
@@ -219,13 +225,19 @@ def check_progress_call(id_part):
else:
preview_visibility = gr_show(True)
- return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
+ if shared.state.textinfo is not None:
+ textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
+ else:
+ textinfo_result = gr_show(False)
+
+ return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
def check_progress_call_initial(id_part):
shared.state.job_count = -1
shared.state.current_latent = None
shared.state.current_image = None
+ shared.state.textinfo = None
return check_progress_call(id_part)
@@ -399,13 +411,16 @@ def create_toprow(is_img2img):
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
-def setup_progressbar(progressbar, preview, id_part):
+def setup_progressbar(progressbar, preview, id_part, textinfo=None):
+ if textinfo is None:
+ textinfo = gr.HTML(visible=False)
+
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
check_progress.click(
fn=lambda: check_progress_call(id_part),
show_progress=False,
inputs=[],
- outputs=[progressbar, preview, preview],
+ outputs=[progressbar, preview, preview, textinfo],
)
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
@@ -413,11 +428,14 @@ def setup_progressbar(progressbar, preview, id_part):
fn=lambda: check_progress_call_initial(id_part),
show_progress=False,
inputs=[],
- outputs=[progressbar, preview, preview],
+ outputs=[progressbar, preview, preview, textinfo],
)
-def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
+def create_ui(wrap_gradio_gpu_call):
+ import modules.img2img
+ import modules.txt2img
+
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
@@ -483,7 +501,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
- fn=txt2img,
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit",
inputs=[
txt2img_prompt,
@@ -675,7 +693,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
)
img2img_args = dict(
- fn=img2img,
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img),
_js="submit_img2img",
inputs=[
dummy_component,
@@ -828,7 +846,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click(
- fn=run_extras,
+ fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index",
inputs=[
dummy_component,
@@ -878,7 +896,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
pnginfo_send_to_img2img = gr.Button('Send to img2img')
image.change(
- fn=wrap_gradio_call(run_pnginfo),
+ fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)
@@ -887,7 +905,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
-
+
with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
@@ -896,10 +914,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
-
+
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+ with gr.Blocks() as textual_inversion_interface:
+ with gr.Row().style(equal_height=False):
+ with gr.Column():
+ with gr.Group():
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
+
+ new_embedding_name = gr.Textbox(label="Name")
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_embedding = gr.Button(value="Create", variant='primary')
+
+ with gr.Group():
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
+ train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
+ learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
+ template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
+ steps = gr.Number(label='Max steps', value=100000, precision=0)
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
+
+ with gr.Row():
+ with gr.Column(scale=2):
+ gr.HTML(value="")
+
+ with gr.Column():
+ with gr.Row():
+ interrupt_training = gr.Button(value="Interrupt")
+ train_embedding = gr.Button(value="Train", variant='primary')
+
+ with gr.Column():
+ progressbar = gr.HTML(elem_id="ti_progressbar")
+ ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
+
+ ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
+ ti_preview = gr.Image(elem_id='ti_preview', visible=False)
+ ti_progress = gr.HTML(elem_id="ti_progress", value="")
+ ti_outcome = gr.HTML(elem_id="ti_error", value="")
+ setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
+
+ create_embedding.click(
+ fn=modules.textual_inversion.ui.create_embedding,
+ inputs=[
+ new_embedding_name,
+ nvpt,
+ ],
+ outputs=[
+ train_embedding_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ train_embedding.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ train_embedding_name,
+ learn_rate,
+ dataset_directory,
+ log_directory,
+ steps,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ interrupt_training.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1011,6 +1115,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+ (textual_inversion_interface, "Textual inversion", "ti"),
(settings_interface, "Settings", "settings"),
]
@@ -1044,11 +1149,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
def modelmerger(*args):
try:
- results = run_modelmerger(*args)
+ results = modules.extras.run_modelmerger(*args)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- modules.sd_models.list_models() #To remove the potentially missing models from the list
+ modules.sd_models.list_models() # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results
diff --git a/style.css b/style.css
index 79d6bb0d..39586bf1 100644
--- a/style.css
+++ b/style.css
@@ -157,7 +157,7 @@ button{
max-width: 10em;
}
-#txt2img_preview, #img2img_preview{
+#txt2img_preview, #img2img_preview, #ti_preview{
position: absolute;
width: 320px;
left: 0;
@@ -172,18 +172,18 @@ button{
}
@media screen and (min-width: 768px) {
- #txt2img_preview, #img2img_preview {
+ #txt2img_preview, #img2img_preview, #ti_preview {
position: absolute;
}
}
@media screen and (max-width: 767px) {
- #txt2img_preview, #img2img_preview {
+ #txt2img_preview, #img2img_preview, #ti_preview {
position: relative;
}
}
-#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
+#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
display: none;
}
@@ -247,7 +247,7 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{
}
-#txt2img_progressbar, #img2img_progressbar{
+#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute;
z-index: 1000;
right: 0;
diff --git a/textual_inversion_templates/style.txt b/textual_inversion_templates/style.txt
new file mode 100644
index 00000000..15af2d6b
--- /dev/null
+++ b/textual_inversion_templates/style.txt
@@ -0,0 +1,19 @@
+a painting, art by [name]
+a rendering, art by [name]
+a cropped painting, art by [name]
+the painting, art by [name]
+a clean painting, art by [name]
+a dirty painting, art by [name]
+a dark painting, art by [name]
+a picture, art by [name]
+a cool painting, art by [name]
+a close-up painting, art by [name]
+a bright painting, art by [name]
+a cropped painting, art by [name]
+a good painting, art by [name]
+a close-up painting, art by [name]
+a rendition, art by [name]
+a nice painting, art by [name]
+a small painting, art by [name]
+a weird painting, art by [name]
+a large painting, art by [name]
diff --git a/textual_inversion_templates/style_filewords.txt b/textual_inversion_templates/style_filewords.txt
new file mode 100644
index 00000000..b3a8159a
--- /dev/null
+++ b/textual_inversion_templates/style_filewords.txt
@@ -0,0 +1,19 @@
+a painting of [filewords], art by [name]
+a rendering of [filewords], art by [name]
+a cropped painting of [filewords], art by [name]
+the painting of [filewords], art by [name]
+a clean painting of [filewords], art by [name]
+a dirty painting of [filewords], art by [name]
+a dark painting of [filewords], art by [name]
+a picture of [filewords], art by [name]
+a cool painting of [filewords], art by [name]
+a close-up painting of [filewords], art by [name]
+a bright painting of [filewords], art by [name]
+a cropped painting of [filewords], art by [name]
+a good painting of [filewords], art by [name]
+a close-up painting of [filewords], art by [name]
+a rendition of [filewords], art by [name]
+a nice painting of [filewords], art by [name]
+a small painting of [filewords], art by [name]
+a weird painting of [filewords], art by [name]
+a large painting of [filewords], art by [name]
diff --git a/textual_inversion_templates/subject.txt b/textual_inversion_templates/subject.txt
new file mode 100644
index 00000000..79f36aa0
--- /dev/null
+++ b/textual_inversion_templates/subject.txt
@@ -0,0 +1,27 @@
+a photo of a [name]
+a rendering of a [name]
+a cropped photo of the [name]
+the photo of a [name]
+a photo of a clean [name]
+a photo of a dirty [name]
+a dark photo of the [name]
+a photo of my [name]
+a photo of the cool [name]
+a close-up photo of a [name]
+a bright photo of the [name]
+a cropped photo of a [name]
+a photo of the [name]
+a good photo of the [name]
+a photo of one [name]
+a close-up photo of the [name]
+a rendition of the [name]
+a photo of the clean [name]
+a rendition of a [name]
+a photo of a nice [name]
+a good photo of a [name]
+a photo of the nice [name]
+a photo of the small [name]
+a photo of the weird [name]
+a photo of the large [name]
+a photo of a cool [name]
+a photo of a small [name]
diff --git a/textual_inversion_templates/subject_filewords.txt b/textual_inversion_templates/subject_filewords.txt
new file mode 100644
index 00000000..008652a6
--- /dev/null
+++ b/textual_inversion_templates/subject_filewords.txt
@@ -0,0 +1,27 @@
+a photo of a [name], [filewords]
+a rendering of a [name], [filewords]
+a cropped photo of the [name], [filewords]
+the photo of a [name], [filewords]
+a photo of a clean [name], [filewords]
+a photo of a dirty [name], [filewords]
+a dark photo of the [name], [filewords]
+a photo of my [name], [filewords]
+a photo of the cool [name], [filewords]
+a close-up photo of a [name], [filewords]
+a bright photo of the [name], [filewords]
+a cropped photo of a [name], [filewords]
+a photo of the [name], [filewords]
+a good photo of the [name], [filewords]
+a photo of one [name], [filewords]
+a close-up photo of the [name], [filewords]
+a rendition of the [name], [filewords]
+a photo of the clean [name], [filewords]
+a rendition of a [name], [filewords]
+a photo of a nice [name], [filewords]
+a good photo of a [name], [filewords]
+a photo of the nice [name], [filewords]
+a photo of the small [name], [filewords]
+a photo of the weird [name], [filewords]
+a photo of the large [name], [filewords]
+a photo of a cool [name], [filewords]
+a photo of a small [name], [filewords]
diff --git a/webui.py b/webui.py
index b8cccd54..19fdcdd4 100644
--- a/webui.py
+++ b/webui.py
@@ -12,7 +12,6 @@ import modules.bsrgan_model as bsrgan
import modules.extras
import modules.face_restoration
import modules.gfpgan_model as gfpgan
-import modules.img2img
import modules.ldsr_model as ldsr
import modules.lowvram
import modules.realesrgan_model as realesrgan
@@ -21,7 +20,6 @@ import modules.sd_hijack
import modules.sd_models
import modules.shared as shared
import modules.swinir_model as swinir
-import modules.txt2img
import modules.ui
from modules import modelloader
from modules.paths import script_path
@@ -46,7 +44,7 @@ def wrap_queued_call(func):
return f
-def wrap_gradio_gpu_call(func):
+def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
devices.torch_gc()
@@ -58,6 +56,7 @@ def wrap_gradio_gpu_call(func):
shared.state.current_image = None
shared.state.current_image_sampling_step = 0
shared.state.interrupted = False
+ shared.state.textinfo = None
with queue_lock:
res = func(*args, **kwargs)
@@ -69,7 +68,7 @@ def wrap_gradio_gpu_call(func):
return res
- return modules.ui.wrap_gradio_call(f)
+ return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
@@ -86,13 +85,7 @@ def webui():
signal.signal(signal.SIGINT, sigint_handler)
- demo = modules.ui.create_ui(
- txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
- img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
- run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
- run_pnginfo=modules.extras.run_pnginfo,
- run_modelmerger=modules.extras.run_modelmerger
- )
+ demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
demo.launch(
share=cmd_opts.share,