aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--README.md52
-rw-r--r--javascript/progressbar.js1
-rw-r--r--javascript/textualInversion.js8
-rw-r--r--javascript/ui.js10
-rw-r--r--modules/bsrgan_model.py2
-rw-r--r--modules/bsrgan_model_arch.py1
-rw-r--r--modules/devices.py3
-rw-r--r--modules/extras.py4
-rw-r--r--modules/ldsr_model.py15
-rw-r--r--modules/ldsr_model_arch.py4
-rw-r--r--modules/modelloader.py5
-rw-r--r--modules/processing.py17
-rw-r--r--modules/sd_hijack.py324
-rw-r--r--modules/sd_hijack_optimizations.py164
-rw-r--r--modules/sd_models.py5
-rw-r--r--modules/sd_samplers.py10
-rw-r--r--modules/shared.py5
-rw-r--r--modules/swinir_model.py27
-rw-r--r--modules/textual_inversion/dataset.py76
-rw-r--r--modules/textual_inversion/textual_inversion.py258
-rw-r--r--modules/textual_inversion/ui.py32
-rw-r--r--modules/ui.py141
-rw-r--r--scripts/outpainting_mk_2.py40
-rw-r--r--style.css12
-rw-r--r--textual_inversion_templates/style.txt19
-rw-r--r--textual_inversion_templates/style_filewords.txt19
-rw-r--r--textual_inversion_templates/subject.txt27
-rw-r--r--textual_inversion_templates/subject_filewords.txt27
-rw-r--r--webui.py15
30 files changed, 923 insertions, 401 deletions
diff --git a/.gitignore b/.gitignore
index 3532dab3..7afc9395 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,4 @@ __pycache__
/.idea
notification.mp3
/SwinIR
+/textual_inversion
diff --git a/README.md b/README.md
index 219288ac..5ded94f9 100644
--- a/README.md
+++ b/README.md
@@ -11,44 +11,56 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- One click install and run script (but you still must install python and git)
- Outpainting
- Inpainting
-- Prompt matrix
+- Prompt
- Stable Diffusion upscale
-- Attention
-- Loopback
-- X/Y plot
+- Attention, specify parts of text that the model should pay more attention to
+ - a man in a ((txuedo)) - will pay more attentinoto tuxedo
+ - a man in a (txuedo:1.21) - alternative syntax
+- Loopback, run img2img procvessing multiple times
+- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion
+ - have as many embeddings as you want and use any names you like for them
+ - use multiple embeddings with different numbers of vectors per token
+ - works with half precision floating point numbers
- Extras tab with:
- GFPGAN, neural network that fixes faces
- CodeFormer, face restoration tool as an alternative to GFPGAN
- RealESRGAN, neural network upscaler
- - ESRGAN, neural network with a lot of third party models
+ - ESRGAN, neural network upscaler with a lot of third party models
- SwinIR, neural network upscaler
- LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options
- Sampling method selection
- Interrupt processing at any time
-- 4GB video card support
-- Correct seeds for batches
+- 4GB video card support (also reports of 2GB working)
+- Correct seeds for batches
- Prompt length validation
-- Generation parameters added as text to PNG
-- Tab to view an existing picture's generation parameters
+ - get length of prompt in tokensas you type
+ - get a warning after geenration if some text was truncated
+- Generation parameters
+ - parameters you used to generate images are saved with that image
+ - in PNG chunks for PNG, in EXIF for JPEG
+ - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
+ - can be disabled in settings
- Settings page
-- Running custom code from UI
+- Running arbitrary python code from UI (must run with commandline flag to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config
- Random artist button
-- Tiling support: UI checkbox to create images that can be tiled like textures
+- Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview
-- Negative prompt
-- Styles
-- Variations
-- Seed resizing
-- CLIP interrogator
-- Prompt Editing
-- Batch Processing
+- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
+- Styles, a way to save part of prompt and easily apply them via dropdown later
+- Variations, a way to generate same image but with tiny differences
+- Seed resizing, a way to generate same image but at slightly different resolution
+- CLIP interrogator, a button that tries to guess prompt from an image
+- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
+- Batch Processing, process a group of files using img2img
- Img2img Alternative
-- Highres Fix
-- LDSR Upscaling
+- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
+- Reloading checkpoints on the fly
+- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
+- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 21f25b38..1e297abb 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
+ check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
})
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js
new file mode 100644
index 00000000..8061be08
--- /dev/null
+++ b/javascript/textualInversion.js
@@ -0,0 +1,8 @@
+
+
+function start_training_textual_inversion(){
+ requestProgress('ti')
+ gradioApp().querySelector('#ti_error').innerHTML=''
+
+ return args_to_array(arguments)
+}
diff --git a/javascript/ui.js b/javascript/ui.js
index 562d2552..bfe02410 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -186,10 +186,12 @@ onUiUpdate(function(){
if (!txt2img_textarea) {
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
+ txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
}
if (!img2img_textarea) {
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
+ img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
}
})
@@ -197,6 +199,14 @@ let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
let token_timeout;
+function submit_prompt(event, generate_button_id) {
+ if (event.altKey && event.keyCode === 13) {
+ event.preventDefault();
+ gradioApp().getElementById(generate_button_id).click();
+ return;
+ }
+}
+
function update_token_counter(button_id) {
if (token_timeout)
clearTimeout(token_timeout);
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
index 47346f31..e62c6657 100644
--- a/modules/bsrgan_model.py
+++ b/modules/bsrgan_model.py
@@ -69,7 +69,7 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
if not os.path.exists(filename) or filename is None:
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
return None
- model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network
+ model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for k, v in model.named_parameters():
diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py
index d72647db..cb4d1c13 100644
--- a/modules/bsrgan_model_arch.py
+++ b/modules/bsrgan_model_arch.py
@@ -76,7 +76,6 @@ class RRDBNet(nn.Module):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.sf = sf
- print([in_nc, out_nc, nf, nb, gc, sf])
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
diff --git a/modules/devices.py b/modules/devices.py
index 07bb2339..ff82f2f6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -32,10 +32,9 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-
device = get_optimal_device()
device_codeformer = cpu if has_mps else device
-
+dtype = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
diff --git a/modules/extras.py b/modules/extras.py
index 1bff5874..6a0d5cb0 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -191,9 +191,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
if save_as_half:
theta_0[key] = theta_0[key].half()
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
+
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt')
- output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)
+ output_modelname = os.path.join(ckpt_dir, filename)
print(f"Saving to {output_modelname}...")
torch.save(primary_model, output_modelname)
diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py
index 877e7e73..1c1070fc 100644
--- a/modules/ldsr_model.py
+++ b/modules/ldsr_model.py
@@ -22,8 +22,20 @@ class UpscalerLDSR(Upscaler):
self.scalers = [scaler_data]
def load_model(self, path: str):
+ # Remove incorrect project.yaml file if too big
+ yaml_path = os.path.join(self.model_path, "project.yaml")
+ old_model_path = os.path.join(self.model_path, "model.pth")
+ new_model_path = os.path.join(self.model_path, "model.ckpt")
+ if os.path.exists(yaml_path):
+ statinfo = os.stat(yaml_path)
+ if statinfo.st_size >= 10485760:
+ print("Removing invalid LDSR YAML file.")
+ os.remove(yaml_path)
+ if os.path.exists(old_model_path):
+ print("Renaming model from model.pth to model.ckpt")
+ os.rename(old_model_path, new_model_path)
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
- file_name="model.pth", progress=True)
+ file_name="model.ckpt", progress=True)
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
file_name="project.yaml", progress=True)
@@ -41,5 +53,4 @@ class UpscalerLDSR(Upscaler):
print("NO LDSR!")
return img
ddim_steps = shared.opts.ldsr_steps
- pre_scale = shared.opts.ldsr_pre_down
return ldsr.super_resolution(img, ddim_steps, self.scale)
diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py
index 7faac6e1..14db5076 100644
--- a/modules/ldsr_model_arch.py
+++ b/modules/ldsr_model_arch.py
@@ -98,9 +98,7 @@ class LDSR:
im_og = image
width_og, height_og = im_og.size
# If we can adjust the max upscale size, then the 4 below should be our variable
- print("Foo")
down_sample_rate = target_scale / 4
- print(f"Downsample rate is {down_sample_rate}")
wd = width_og * down_sample_rate
hd = height_og * down_sample_rate
width_downsampled_pre = int(wd)
@@ -111,7 +109,7 @@ class LDSR:
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
- print(f"Down sample rate is 1 from {target_scale} / 4")
+ print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
logs = self.run(model["model"], im_og, diffusion_steps, eta)
sample = logs["sample"]
diff --git a/modules/modelloader.py b/modules/modelloader.py
index b1721671..b0f2f33d 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,3 +1,4 @@
+import glob
import os
import shutil
import importlib
@@ -40,8 +41,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
for place in places:
if os.path.exists(place):
- for file in os.listdir(place):
- full_path = os.path.join(place, file)
+ for file in glob.iglob(place + '**/**', recursive=True):
+ full_path = file
if os.path.isdir(full_path):
continue
if len(ext_filter) != 0:
diff --git a/modules/processing.py b/modules/processing.py
index 7eeb5191..0a4b6198 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -56,7 +56,7 @@ class StableDiffusionProcessing:
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
- self.styles: str = styles
+ self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
@@ -79,7 +79,7 @@ class StableDiffusionProcessing:
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = 0
-
+ self.sampler_noise_scheduler_override = None
self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn
self.s_tmin = opts.s_tmin
@@ -130,7 +130,7 @@ class Processed:
self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
-
+ self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
@@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
- "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
+ "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
}
generation_params.update(p.extra_generation_params)
@@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
fix_seed(p)
- os.makedirs(p.outpath_samples, exist_ok=True)
- os.makedirs(p.outpath_grids, exist_ok=True)
+ if p.outpath_samples is not None:
+ os.makedirs(p.outpath_samples, exist_ok=True)
+
+ if p.outpath_grids is not None:
+ os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
@@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir):
- model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
+ model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = []
output_images = []
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index fa7eaeb8..fd57e5c5 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -6,244 +6,41 @@ import torch
import numpy as np
from torch import einsum
-from modules import prompt_parser
+import modules.textual_inversion.textual_inversion
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
-from ldm.util import default
-from einops import rearrange
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
+diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
-# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
- h = self.heads
- q = self.to_q(x)
- context = default(context, x)
- k = self.to_k(context)
- v = self.to_v(context)
- del context, x
+def apply_optimizations():
+ if cmd_opts.opt_split_attention_v1:
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
- for i in range(0, q.shape[0], 2):
- end = i + 2
- s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
- s1 *= self.scale
+def undo_optimizations():
+ ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+ ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
- s2 = s1.softmax(dim=-1)
- del s1
-
- r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
- del s2
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-
-# taken from https://github.com/Doggettx/stable-diffusion
-def split_cross_attention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q_in = self.to_q(x)
- context = default(context, x)
- k_in = self.to_k(context) * self.scale
- v_in = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
-
- gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
- modifier = 3 if q.element_size() == 2 else 2.5
- mem_required = tensor_size * modifier
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
- # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
- # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-
- if steps > 64:
- max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
- raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
- f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
- s2 = s1.softmax(dim=-1, dtype=q.dtype)
- del s1
-
- r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
- del s2
-
- del q, k, v
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-def nonlinearity_hijack(x):
- # swish
- t = torch.sigmoid(x)
- x *= t
- del t
-
- return x
-
-def cross_attention_attnblock_forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q1 = self.q(h_)
- k1 = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q1.shape
-
- q2 = q1.reshape(b, c, h*w)
- del q1
-
- q = q2.permute(0, 2, 1) # b,hw,c
- del q2
-
- k = k1.reshape(b, c, h*w) # b,c,hw
- del k1
-
- h_ = torch.zeros_like(k, device=q.device)
-
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
-
- tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
- mem_required = tensor_size * 2.5
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
-
- w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w2 = w1 * (int(c)**(-0.5))
- del w1
- w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
- del w2
-
- # attend to values
- v1 = v.reshape(b, c, h*w)
- w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
- del w3
-
- h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- del v1, w4
-
- h2 = h_.reshape(b, c, h, w)
- del h_
-
- h3 = self.proj_out(h2)
- del h2
-
- h3 += x
-
- return h3
class StableDiffusionModelHijack:
- ids_lookup = {}
- word_embeddings = {}
- word_embeddings_checksums = {}
fixes = None
comments = []
- dir_mtime = None
layers = None
circular_enabled = False
clip = None
- def load_textual_inversion_embeddings(self, dirname, model):
- mt = os.path.getmtime(dirname)
- if self.dir_mtime is not None and mt <= self.dir_mtime:
- return
-
- self.dir_mtime = mt
- self.ids_lookup.clear()
- self.word_embeddings.clear()
-
- tokenizer = model.cond_stage_model.tokenizer
-
- def const_hash(a):
- r = 0
- for v in a:
- r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
- return r
-
- def process_file(path, filename):
- name = os.path.splitext(filename)[0]
-
- data = torch.load(path, map_location="cpu")
-
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
-
- self.word_embeddings[name] = emb.detach().to(device)
- self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
-
- ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
-
- first_id = ids[0]
- if first_id not in self.ids_lookup:
- self.ids_lookup[first_id] = []
- self.ids_lookup[first_id].append((ids, name))
-
- for fn in os.listdir(dirname):
- try:
- fullfn = os.path.join(dirname, fn)
-
- if os.stat(fullfn).st_size == 0:
- continue
-
- process_file(fullfn, fn)
- except Exception:
- print(f"Error loading emedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- continue
-
- print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@@ -253,12 +50,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model
- if cmd_opts.opt_split_attention_v1:
- ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
- ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
- ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+ apply_optimizations()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
@@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
- self.hijack = hijack
+ self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
@@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
-
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- possible_matches = self.hijack.ids_lookup.get(token, None)
+ embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
- if possible_matches is None:
+ if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
+ i += 1
else:
- found = False
- for ids, word in possible_matches:
- if tokens[i:i + len(ids)] == ids:
- emb_len = int(self.hijack.word_embeddings[word].shape[0])
- fixes.append((len(remade_tokens), word))
- remade_tokens += [0] * emb_len
- multipliers += [weight] * emb_len
- i += len(ids) - 1
- found = True
- used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
- break
-
- if not found:
- remade_tokens.append(token)
- multipliers.append(weight)
- i += 1
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [weight] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += emb_len
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- possible_matches = self.hijack.ids_lookup.get(token, None)
+ embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
- elif possible_matches is None:
+ i += 1
+ elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
+ i += 1
else:
- found = False
- for ids, word in possible_matches:
- if tokens[i:i+len(ids)] == ids:
- emb_len = int(self.hijack.word_embeddings[word].shape[0])
- fixes.append((len(remade_tokens), word))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- i += len(ids) - 1
- found = True
- used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
- break
-
- if not found:
- remade_tokens.append(token)
- multipliers.append(mult)
-
- i += 1
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += emb_len
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
@@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
@@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
inputs_embeds = self.wrapped(input_ids)
- if batch_fixes is not None:
- for fixes, tensor in zip(batch_fixes, inputs_embeds):
- for offset, word in fixes:
- emb = self.embeddings.word_embeddings[word]
- emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
- tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
+ if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
+ return inputs_embeds
+
+ vecs = []
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
+ for offset, embedding in fixes:
+ emb = embedding.vec
+ emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+
+ vecs.append(tensor)
- return inputs_embeds
+ return torch.stack(vecs)
def add_circular_option_to_conv_2d():
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
new file mode 100644
index 00000000..9c079e57
--- /dev/null
+++ b/modules/sd_hijack_optimizations.py
@@ -0,0 +1,164 @@
+import math
+import torch
+from torch import einsum
+
+from ldm.util import default
+from einops import rearrange
+
+
+# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
+def split_cross_attention_forward_v1(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ for i in range(0, q.shape[0], 2):
+ end = i + 2
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+ s1 *= self.scale
+
+ s2 = s1.softmax(dim=-1)
+ del s1
+
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+ del s2
+
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+ del r1
+
+ return self.to_out(r2)
+
+
+# taken from https://github.com/Doggettx/stable-diffusion
+def split_cross_attention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q_in = self.to_q(x)
+ context = default(context, x)
+ k_in = self.to_k(context) * self.scale
+ v_in = self.to_v(context)
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+
+ gb = 1024 ** 3
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+ modifier = 3 if q.element_size() == 2 else 2.5
+ mem_required = tensor_size * modifier
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+ if steps > 64:
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+ f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
+ del s1
+
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+ del s2
+
+ del q, k, v
+
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+ del r1
+
+ return self.to_out(r2)
+
+def nonlinearity_hijack(x):
+ # swish
+ t = torch.sigmoid(x)
+ x *= t
+ del t
+
+ return x
+
+def cross_attention_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q1 = self.q(h_)
+ k1 = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q1.shape
+
+ q2 = q1.reshape(b, c, h*w)
+ del q1
+
+ q = q2.permute(0, 2, 1) # b,hw,c
+ del q2
+
+ k = k1.reshape(b, c, h*w) # b,c,hw
+ del k1
+
+ h_ = torch.zeros_like(k, device=q.device)
+
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+ mem_required = tensor_size * 2.5
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+
+ w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w2 = w1 * (int(c)**(-0.5))
+ del w1
+ w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
+ del w2
+
+ # attend to values
+ v1 = v.reshape(b, c, h*w)
+ w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ del w3
+
+ h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ del v1, w4
+
+ h2 = h_.reshape(b, c, h, w)
+ del h_
+
+ h3 = self.proj_out(h2)
+ del h2
+
+ h3 += x
+
+ return h3
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ab014efb..5b3dbdc7 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -8,7 +8,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
-from modules import shared, modelloader
+from modules import shared, modelloader, devices
from modules.paths import models_path
model_dir = "Stable-diffusion"
@@ -69,6 +69,7 @@ def list_models():
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
+ shared.opts.sd_model_checkpoint = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
@@ -133,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
if not shared.cmd_opts.no_half:
model.half()
+ devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+
model.sd_model_hash = sd_model_hash
model.sd_model_checkpint = checkpoint_file
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index dff89c09..92522214 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -290,7 +290,10 @@ class KDiffusionSampler:
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
steps, t_enc = setup_img2img_steps(p, steps)
- sigmas = self.model_wrap.get_sigmas(steps)
+ if p.sampler_noise_scheduler_override:
+ sigmas = p.sampler_noise_scheduler_override(steps)
+ else:
+ sigmas = self.model_wrap.get_sigmas(steps)
noise = noise * sigmas[steps - t_enc - 1]
xi = x + noise
@@ -306,7 +309,10 @@ class KDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps
- sigmas = self.model_wrap.get_sigmas(steps)
+ if p.sampler_noise_scheduler_override:
+ sigmas = p.sampler_noise_scheduler_override(steps)
+ else:
+ sigmas = self.model_wrap.get_sigmas(steps)
x = x * sigmas[0]
extra_params_kwargs = self.initialize(p)
diff --git a/modules/shared.py b/modules/shared.py
index a48b995a..5a591dc9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -20,7 +20,7 @@ default_sd_model_file = sd_model_file
model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
-parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
+parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
@@ -79,6 +79,7 @@ class State:
current_latent = None
current_image = None
current_image_sampling_step = 0
+ textinfo = None
def interrupt(self):
self.interrupted = True
@@ -89,7 +90,7 @@ class State:
self.current_image_sampling_step = 0
def get_job_timestamp(self):
- return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
state = State()
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index 41fda5a7..9bd454c6 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -5,6 +5,7 @@ import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
+from tqdm import tqdm
from modules import modelloader
from modules.paths import models_path
@@ -122,18 +123,20 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=device)
- for h_idx in h_idx_list:
- for w_idx in w_idx_list:
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
+ with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
+ for h_idx in h_idx_list:
+ for w_idx in w_idx_list:
+ in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
+ out_patch = model(in_patch)
+ out_patch_mask = torch.ones_like(out_patch)
+
+ E[
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+ ].add_(out_patch)
+ W[
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+ ].add_(out_patch_mask)
+ pbar.update(1)
output = E.div_(W)
return output
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
new file mode 100644
index 00000000..7e134a08
--- /dev/null
+++ b/modules/textual_inversion/dataset.py
@@ -0,0 +1,76 @@
+import os
+import numpy as np
+import PIL
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+import random
+import tqdm
+
+
+class PersonalizedBase(Dataset):
+ def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+
+ self.placeholder_token = placeholder_token
+
+ self.size = size
+ self.width = width
+ self.height = height
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+ self.dataset = []
+
+ with open(template_file, "r") as file:
+ lines = [x.strip() for x in file.readlines()]
+
+ self.lines = lines
+
+ assert data_root, 'dataset directory not specified'
+
+ self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+ print("Preparing dataset...")
+ for path in tqdm.tqdm(self.image_paths):
+ image = Image.open(path)
+ image = image.convert('RGB')
+ image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+
+ filename = os.path.basename(path)
+ filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
+ filename_tokens = [token for token in filename_tokens if token.isalpha()]
+
+ npimage = np.array(image).astype(np.uint8)
+ npimage = (npimage / 127.5 - 1.0).astype(np.float32)
+
+ torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
+ torchdata = torch.moveaxis(torchdata, 2, 0)
+
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
+
+ self.dataset.append((init_latent, filename_tokens))
+
+ self.length = len(self.dataset) * repeats
+
+ self.initial_indexes = np.arange(self.length) % len(self.dataset)
+ self.indexes = None
+ self.shuffle()
+
+ def shuffle(self):
+ self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, i):
+ if i % len(self.dataset) == 0:
+ self.shuffle()
+
+ index = self.indexes[i % len(self.indexes)]
+ x, filename_tokens = self.dataset[index]
+
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ text = text.replace("[filewords]", ' '.join(filename_tokens))
+
+ return x, text
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
new file mode 100644
index 00000000..c0baaace
--- /dev/null
+++ b/modules/textual_inversion/textual_inversion.py
@@ -0,0 +1,258 @@
+import os
+import sys
+import traceback
+
+import torch
+import tqdm
+import html
+import datetime
+
+from modules import shared, devices, sd_hijack, processing
+import modules.textual_inversion.dataset
+
+
+class Embedding:
+ def __init__(self, vec, name, step=None):
+ self.vec = vec
+ self.name = name
+ self.step = step
+ self.cached_checksum = None
+
+ def save(self, filename):
+ embedding_data = {
+ "string_to_token": {"*": 265},
+ "string_to_param": {"*": self.vec},
+ "name": self.name,
+ "step": self.step,
+ }
+
+ torch.save(embedding_data, filename)
+
+ def checksum(self):
+ if self.cached_checksum is not None:
+ return self.cached_checksum
+
+ def const_hash(a):
+ r = 0
+ for v in a:
+ r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+ return r
+
+ self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
+ return self.cached_checksum
+
+class EmbeddingDatabase:
+ def __init__(self, embeddings_dir):
+ self.ids_lookup = {}
+ self.word_embeddings = {}
+ self.dir_mtime = None
+ self.embeddings_dir = embeddings_dir
+
+ def register_embedding(self, embedding, model):
+
+ self.word_embeddings[embedding.name] = embedding
+
+ ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+
+ first_id = ids[0]
+ if first_id not in self.ids_lookup:
+ self.ids_lookup[first_id] = []
+ self.ids_lookup[first_id].append((ids, embedding))
+
+ return embedding
+
+ def load_textual_inversion_embeddings(self):
+ mt = os.path.getmtime(self.embeddings_dir)
+ if self.dir_mtime is not None and mt <= self.dir_mtime:
+ return
+
+ self.dir_mtime = mt
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+
+ def process_file(path, filename):
+ name = os.path.splitext(filename)[0]
+
+ data = torch.load(path, map_location="cpu")
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ self.register_embedding(embedding, shared.sd_model)
+
+ for fn in os.listdir(self.embeddings_dir):
+ try:
+ fullfn = os.path.join(self.embeddings_dir, fn)
+
+ if os.stat(fullfn).st_size == 0:
+ continue
+
+ process_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading emedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ continue
+
+ print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+
+ def find_embedding_at_position(self, tokens, offset):
+ token = tokens[offset]
+ possible_matches = self.ids_lookup.get(token, None)
+
+ if possible_matches is None:
+ return None
+
+ for ids, embedding in possible_matches:
+ if tokens[offset:offset + len(ids)] == ids:
+ return embedding
+
+ return None
+
+
+
+def create_embedding(name, num_vectors_per_token):
+ init_text = '*'
+
+ cond_model = shared.sd_model.cond_stage_model
+ embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
+
+ ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+ vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
+
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+
+ fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ embedding = Embedding(vec, name)
+ embedding.step = 0
+ embedding.save(fn)
+
+ return fn
+
+
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+ assert embedding_name, 'embedding not selected'
+
+ shared.state.textinfo = "Initializing textual inversion training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
+
+ if save_embedding_every > 0:
+ embedding_dir = os.path.join(log_directory, "embeddings")
+ os.makedirs(embedding_dir, exist_ok=True)
+ else:
+ embedding_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ cond_model = shared.sd_model.cond_stage_model
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+
+ hijack = sd_hijack.model_hijack
+
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+ embedding.vec.requires_grad = True
+
+ optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+
+ losses = torch.zeros((32,))
+
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+
+ ititial_step = embedding.step or 0
+ if ititial_step > steps:
+ return embedding, filename
+
+ pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ for i, (x, text) in pbar:
+ embedding.step = i + ititial_step
+
+ if embedding.step > steps:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([text])
+ loss = shared.sd_model(x.unsqueeze(0), c)[0]
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ pbar.set_description(f"loss: {losses.mean():.7f}")
+
+ if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+ embedding.save(last_saved_file)
+
+ if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
+ last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ prompt=text,
+ steps=20,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ processed = processing.process_images(p)
+ image = processed.images[0]
+
+ shared.state.current_image = image
+ image.save(last_saved_image)
+
+ last_saved_image += f", prompt: {text}"
+
+ shared.state.job_no = embedding.step
+
+ shared.state.textinfo = f"""
+<p>
+Loss: {losses.mean():.7f}<br/>
+Step: {embedding.step}<br/>
+Last prompt: {html.escape(text)}<br/>
+Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved image: {html.escape(last_saved_image)}<br/>
+</p>
+"""
+
+ embedding.cached_checksum = None
+ embedding.save(filename)
+
+ return embedding, filename
+
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
new file mode 100644
index 00000000..ce3677a9
--- /dev/null
+++ b/modules/textual_inversion/ui.py
@@ -0,0 +1,32 @@
+import html
+
+import gradio as gr
+
+import modules.textual_inversion.textual_inversion as ti
+from modules import sd_hijack, shared
+
+
+def create_embedding(name, nvpt):
+ filename = ti.create_embedding(name, nvpt)
+
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+ return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
+
+
+def train_embedding(*args):
+
+ try:
+ sd_hijack.undo_optimizations()
+
+ embedding, filename = ti.train_embedding(*args)
+
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps.
+Embedding saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception:
+ raise
+ finally:
+ sd_hijack.apply_optimizations()
diff --git a/modules/ui.py b/modules/ui.py
index 249b3eea..57aef6ff 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -21,6 +21,7 @@ import gradio as gr
import gradio.utils
import gradio.routes
+from modules import sd_hijack
from modules.paths import script_path
from modules.shared import opts, cmd_opts
import modules.shared as shared
@@ -32,6 +33,7 @@ import modules.gfpgan_model
import modules.codeformer_model
import modules.styles
import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
@@ -142,8 +144,8 @@ def save_files(js_data, images, index):
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
-def wrap_gradio_call(func):
- def f(*args, **kwargs):
+def wrap_gradio_call(func, extra_outputs=None):
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
if run_memmon:
shared.mem_mon.monitor()
@@ -159,7 +161,10 @@ def wrap_gradio_call(func):
shared.state.job = ""
shared.state.job_count = 0
- res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
+ if extra_outputs_array is None:
+ extra_outputs_array = [None, '']
+
+ res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t
@@ -179,6 +184,7 @@ def wrap_gradio_call(func):
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
shared.state.interrupted = False
+ shared.state.job_count = 0
return tuple(res)
@@ -187,7 +193,7 @@ def wrap_gradio_call(func):
def check_progress_call(id_part):
if shared.state.job_count == 0:
- return "", gr_show(False), gr_show(False)
+ return "", gr_show(False), gr_show(False), gr_show(False)
progress = 0
@@ -219,13 +225,19 @@ def check_progress_call(id_part):
else:
preview_visibility = gr_show(True)
- return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
+ if shared.state.textinfo is not None:
+ textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
+ else:
+ textinfo_result = gr_show(False)
+
+ return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
def check_progress_call_initial(id_part):
shared.state.job_count = -1
shared.state.current_latent = None
shared.state.current_image = None
+ shared.state.textinfo = None
return check_progress_call(id_part)
@@ -380,7 +392,7 @@ def create_toprow(is_img2img):
with gr.Column(scale=1):
with gr.Row():
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
- submit = gr.Button('Generate', elem_id="generate", variant='primary')
+ submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
interrupt.click(
fn=lambda: shared.state.interrupt(),
@@ -399,13 +411,16 @@ def create_toprow(is_img2img):
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
-def setup_progressbar(progressbar, preview, id_part):
+def setup_progressbar(progressbar, preview, id_part, textinfo=None):
+ if textinfo is None:
+ textinfo = gr.HTML(visible=False)
+
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
check_progress.click(
fn=lambda: check_progress_call(id_part),
show_progress=False,
inputs=[],
- outputs=[progressbar, preview, preview],
+ outputs=[progressbar, preview, preview, textinfo],
)
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
@@ -413,11 +428,14 @@ def setup_progressbar(progressbar, preview, id_part):
fn=lambda: check_progress_call_initial(id_part),
show_progress=False,
inputs=[],
- outputs=[progressbar, preview, preview],
+ outputs=[progressbar, preview, preview, textinfo],
)
-def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
+def create_ui(wrap_gradio_gpu_call):
+ import modules.img2img
+ import modules.txt2img
+
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
@@ -483,7 +501,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
- fn=txt2img,
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit",
inputs=[
txt2img_prompt,
@@ -675,7 +693,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
)
img2img_args = dict(
- fn=img2img,
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img),
_js="submit_img2img",
inputs=[
dummy_component,
@@ -828,7 +846,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click(
- fn=run_extras,
+ fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index",
inputs=[
dummy_component,
@@ -878,7 +896,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
pnginfo_send_to_img2img = gr.Button('Send to img2img')
image.change(
- fn=wrap_gradio_call(run_pnginfo),
+ fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)
@@ -887,7 +905,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
-
+
with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
@@ -896,10 +914,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
-
+
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+ with gr.Blocks() as textual_inversion_interface:
+ with gr.Row().style(equal_height=False):
+ with gr.Column():
+ with gr.Group():
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
+
+ new_embedding_name = gr.Textbox(label="Name")
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_embedding = gr.Button(value="Create", variant='primary')
+
+ with gr.Group():
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
+ train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
+ learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
+ template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
+ steps = gr.Number(label='Max steps', value=100000, precision=0)
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
+
+ with gr.Row():
+ with gr.Column(scale=2):
+ gr.HTML(value="")
+
+ with gr.Column():
+ with gr.Row():
+ interrupt_training = gr.Button(value="Interrupt")
+ train_embedding = gr.Button(value="Train", variant='primary')
+
+ with gr.Column():
+ progressbar = gr.HTML(elem_id="ti_progressbar")
+ ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
+
+ ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
+ ti_preview = gr.Image(elem_id='ti_preview', visible=False)
+ ti_progress = gr.HTML(elem_id="ti_progress", value="")
+ ti_outcome = gr.HTML(elem_id="ti_error", value="")
+ setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
+
+ create_embedding.click(
+ fn=modules.textual_inversion.ui.create_embedding,
+ inputs=[
+ new_embedding_name,
+ nvpt,
+ ],
+ outputs=[
+ train_embedding_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ train_embedding.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ train_embedding_name,
+ learn_rate,
+ dataset_directory,
+ log_directory,
+ steps,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ interrupt_training.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1011,6 +1115,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+ (textual_inversion_interface, "Textual inversion", "ti"),
(settings_interface, "Settings", "settings"),
]
@@ -1044,11 +1149,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
def modelmerger(*args):
try:
- results = run_modelmerger(*args)
+ results = modules.extras.run_modelmerger(*args)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- modules.sd_models.list_models() #To remove the potentially missing models from the list
+ modules.sd_models.list_models() # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index 9719bb8f..11613ca3 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -11,46 +11,8 @@ from modules import images, processing, devices
from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state
-# https://github.com/parlance-zz/g-diffuser-bot
-def expand(x, dir, amount, power=0.75):
- is_left = dir == 3
- is_right = dir == 1
- is_up = dir == 0
- is_down = dir == 2
-
- if is_left or is_right:
- noise = np.zeros((x.shape[0], amount, 3), dtype=float)
- indexes = np.random.random((x.shape[0], amount)) ** power * (1 - np.arange(amount) / amount)
- if is_right:
- indexes = 1 - indexes
- indexes = (indexes * (x.shape[1] - 1)).astype(int)
-
- for row in range(x.shape[0]):
- if is_left:
- noise[row] = x[row][indexes[row]]
- else:
- noise[row] = np.flip(x[row][indexes[row]], axis=0)
-
- x = np.concatenate([noise, x] if is_left else [x, noise], axis=1)
- return x
-
- if is_up or is_down:
- noise = np.zeros((amount, x.shape[1], 3), dtype=float)
- indexes = np.random.random((x.shape[1], amount)) ** power * (1 - np.arange(amount) / amount)
- if is_down:
- indexes = 1 - indexes
- indexes = (indexes * x.shape[0] - 1).astype(int)
-
- for row in range(x.shape[1]):
- if is_up:
- noise[:, row] = x[:, row][indexes[row]]
- else:
- noise[:, row] = np.flip(x[:, row][indexes[row]], axis=0)
-
- x = np.concatenate([noise, x] if is_up else [x, noise], axis=0)
- return x
-
+# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data):
diff --git a/style.css b/style.css
index 9709c4ee..39586bf1 100644
--- a/style.css
+++ b/style.css
@@ -23,7 +23,7 @@
text-align: right;
}
-#generate{
+#txt2img_generate, #img2img_generate {
min-height: 4.5em;
}
@@ -157,7 +157,7 @@ button{
max-width: 10em;
}
-#txt2img_preview, #img2img_preview{
+#txt2img_preview, #img2img_preview, #ti_preview{
position: absolute;
width: 320px;
left: 0;
@@ -172,18 +172,18 @@ button{
}
@media screen and (min-width: 768px) {
- #txt2img_preview, #img2img_preview {
+ #txt2img_preview, #img2img_preview, #ti_preview {
position: absolute;
}
}
@media screen and (max-width: 767px) {
- #txt2img_preview, #img2img_preview {
+ #txt2img_preview, #img2img_preview, #ti_preview {
position: relative;
}
}
-#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
+#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
display: none;
}
@@ -247,7 +247,7 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{
}
-#txt2img_progressbar, #img2img_progressbar{
+#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute;
z-index: 1000;
right: 0;
diff --git a/textual_inversion_templates/style.txt b/textual_inversion_templates/style.txt
new file mode 100644
index 00000000..15af2d6b
--- /dev/null
+++ b/textual_inversion_templates/style.txt
@@ -0,0 +1,19 @@
+a painting, art by [name]
+a rendering, art by [name]
+a cropped painting, art by [name]
+the painting, art by [name]
+a clean painting, art by [name]
+a dirty painting, art by [name]
+a dark painting, art by [name]
+a picture, art by [name]
+a cool painting, art by [name]
+a close-up painting, art by [name]
+a bright painting, art by [name]
+a cropped painting, art by [name]
+a good painting, art by [name]
+a close-up painting, art by [name]
+a rendition, art by [name]
+a nice painting, art by [name]
+a small painting, art by [name]
+a weird painting, art by [name]
+a large painting, art by [name]
diff --git a/textual_inversion_templates/style_filewords.txt b/textual_inversion_templates/style_filewords.txt
new file mode 100644
index 00000000..b3a8159a
--- /dev/null
+++ b/textual_inversion_templates/style_filewords.txt
@@ -0,0 +1,19 @@
+a painting of [filewords], art by [name]
+a rendering of [filewords], art by [name]
+a cropped painting of [filewords], art by [name]
+the painting of [filewords], art by [name]
+a clean painting of [filewords], art by [name]
+a dirty painting of [filewords], art by [name]
+a dark painting of [filewords], art by [name]
+a picture of [filewords], art by [name]
+a cool painting of [filewords], art by [name]
+a close-up painting of [filewords], art by [name]
+a bright painting of [filewords], art by [name]
+a cropped painting of [filewords], art by [name]
+a good painting of [filewords], art by [name]
+a close-up painting of [filewords], art by [name]
+a rendition of [filewords], art by [name]
+a nice painting of [filewords], art by [name]
+a small painting of [filewords], art by [name]
+a weird painting of [filewords], art by [name]
+a large painting of [filewords], art by [name]
diff --git a/textual_inversion_templates/subject.txt b/textual_inversion_templates/subject.txt
new file mode 100644
index 00000000..79f36aa0
--- /dev/null
+++ b/textual_inversion_templates/subject.txt
@@ -0,0 +1,27 @@
+a photo of a [name]
+a rendering of a [name]
+a cropped photo of the [name]
+the photo of a [name]
+a photo of a clean [name]
+a photo of a dirty [name]
+a dark photo of the [name]
+a photo of my [name]
+a photo of the cool [name]
+a close-up photo of a [name]
+a bright photo of the [name]
+a cropped photo of a [name]
+a photo of the [name]
+a good photo of the [name]
+a photo of one [name]
+a close-up photo of the [name]
+a rendition of the [name]
+a photo of the clean [name]
+a rendition of a [name]
+a photo of a nice [name]
+a good photo of a [name]
+a photo of the nice [name]
+a photo of the small [name]
+a photo of the weird [name]
+a photo of the large [name]
+a photo of a cool [name]
+a photo of a small [name]
diff --git a/textual_inversion_templates/subject_filewords.txt b/textual_inversion_templates/subject_filewords.txt
new file mode 100644
index 00000000..008652a6
--- /dev/null
+++ b/textual_inversion_templates/subject_filewords.txt
@@ -0,0 +1,27 @@
+a photo of a [name], [filewords]
+a rendering of a [name], [filewords]
+a cropped photo of the [name], [filewords]
+the photo of a [name], [filewords]
+a photo of a clean [name], [filewords]
+a photo of a dirty [name], [filewords]
+a dark photo of the [name], [filewords]
+a photo of my [name], [filewords]
+a photo of the cool [name], [filewords]
+a close-up photo of a [name], [filewords]
+a bright photo of the [name], [filewords]
+a cropped photo of a [name], [filewords]
+a photo of the [name], [filewords]
+a good photo of the [name], [filewords]
+a photo of one [name], [filewords]
+a close-up photo of the [name], [filewords]
+a rendition of the [name], [filewords]
+a photo of the clean [name], [filewords]
+a rendition of a [name], [filewords]
+a photo of a nice [name], [filewords]
+a good photo of a [name], [filewords]
+a photo of the nice [name], [filewords]
+a photo of the small [name], [filewords]
+a photo of the weird [name], [filewords]
+a photo of the large [name], [filewords]
+a photo of a cool [name], [filewords]
+a photo of a small [name], [filewords]
diff --git a/webui.py b/webui.py
index ebe39a17..424ab975 100644
--- a/webui.py
+++ b/webui.py
@@ -7,6 +7,7 @@ import modules.extras
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img
+
import modules.lowvram
import modules.paths
import modules.scripts
@@ -14,6 +15,7 @@ import modules.sd_hijack
import modules.sd_models
import modules.shared as shared
import modules.txt2img
+
import modules.ui
from modules import devices
from modules import modelloader
@@ -39,7 +41,7 @@ def wrap_queued_call(func):
return f
-def wrap_gradio_gpu_call(func):
+def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
devices.torch_gc()
@@ -51,6 +53,7 @@ def wrap_gradio_gpu_call(func):
shared.state.current_image = None
shared.state.current_image_sampling_step = 0
shared.state.interrupted = False
+ shared.state.textinfo = None
with queue_lock:
res = func(*args, **kwargs)
@@ -62,7 +65,7 @@ def wrap_gradio_gpu_call(func):
return res
- return modules.ui.wrap_gradio_call(f)
+ return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
@@ -79,13 +82,7 @@ def webui():
signal.signal(signal.SIGINT, sigint_handler)
- demo = modules.ui.create_ui(
- txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
- img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
- run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
- run_pnginfo=modules.extras.run_pnginfo,
- run_modelmerger=modules.extras.run_modelmerger
- )
+ demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
demo.launch(
share=cmd_opts.share,