aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/deepbooru.py2
-rw-r--r--modules/devices.py2
-rw-r--r--modules/extras.py46
-rw-r--r--modules/hypernetworks/hypernetwork.py57
-rw-r--r--modules/images.py24
-rw-r--r--modules/images_history.py181
-rw-r--r--modules/interrogate.py14
-rw-r--r--modules/processing.py90
-rw-r--r--modules/safe.py9
-rw-r--r--modules/sd_models.py5
-rw-r--r--modules/shared.py12
-rw-r--r--modules/textual_inversion/dataset.py32
-rw-r--r--modules/textual_inversion/textual_inversion.py70
-rw-r--r--modules/txt2img.py5
-rw-r--r--modules/ui.py142
15 files changed, 554 insertions, 137 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index f34f3788..4ad334a1 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -102,7 +102,7 @@ def get_deepbooru_tags_model():
tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project(
- model_path, compile_model=True
+ model_path, compile_model=False
)
return model, tags
diff --git a/modules/devices.py b/modules/devices.py
index 03ef58f1..eb422583 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -34,7 +34,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16
dtype_vae = torch.float16
diff --git a/modules/extras.py b/modules/extras.py
index b24d7de3..f2f5a7b0 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -159,48 +159,52 @@ def run_pnginfo(image):
return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
- # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
- def weighted_sum(theta0, theta1, alpha):
+def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
+ def weighted_sum(theta0, theta1, theta2, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
- # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
- def sigmoid(theta0, theta1, alpha):
- alpha = alpha * alpha * (3 - (2 * alpha))
- return theta0 + ((theta1 - theta0) * alpha)
-
- # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
- def inv_sigmoid(theta0, theta1, alpha):
- import math
- alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
- return theta0 + ((theta1 - theta0) * alpha)
+ def add_difference(theta0, theta1, theta2, alpha):
+ return theta0 + (theta1 - theta2) * alpha
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
+ teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+ theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
-
- theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
+ if teritary_model_info is not None:
+ print(f"Loading {teritary_model_info.filename}...")
+ teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
+ theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
+ else:
+ theta_2 = None
+
theta_funcs = {
- "Weighted Sum": weighted_sum,
- "Sigmoid": sigmoid,
- "Inverse Sigmoid": inv_sigmoid,
+ "Weighted sum": weighted_sum,
+ "Add difference": add_difference,
}
theta_func = theta_funcs[interp_method]
print(f"Merging...")
+
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
- theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
+ t2 = (theta_2 or {}).get(key)
+ if t2 is None:
+ t2 = torch.zeros_like(theta_0[key])
+
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier)
+
if save_as_half:
theta_0[key] = theta_0[key].half()
+ # I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
@@ -209,7 +213,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
+ filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt')
output_modelname = os.path.join(ckpt_dir, filename)
@@ -219,4 +223,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
sd_models.list_models()
print(f"Checkpoint saved.")
- return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
+ return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index f1248bb7..a2b3bc0a 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -5,6 +5,7 @@ import os
import sys
import traceback
import tqdm
+import csv
import torch
@@ -14,6 +15,7 @@ import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
+from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -180,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
return self.to_out(out)
-def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
+def stack_conds(conds):
+ if len(conds) == 1:
+ return torch.stack(conds)
+
+ # same as in reconstruct_multicond_batch
+ token_count = max([x.shape[0] for x in conds])
+ for i in range(len(conds)):
+ if conds[i].shape[0] != token_count:
+ last_vector = conds[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
+ conds[i] = torch.vstack([conds[i], last_vector_repeated])
+
+ return torch.stack(conds)
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -209,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -233,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, entry in pbar:
+ for i, entries in pbar:
hypernetwork.step = i + ititial_step
scheduler.apply(optimizer, hypernetwork.step)
@@ -244,11 +260,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
break
with torch.autocast("cuda"):
- cond = entry.cond.to(devices.device)
- x = entry.latent.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), cond)[0]
+ c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
del x
- del cond
+ del c
losses[hypernetwork.step % losses.shape[0]] = loss.item()
@@ -262,23 +279,39 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
- preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
-
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=preview_text,
- steps=20,
do_not_save_grid=True,
do_not_save_samples=True,
)
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None
@@ -297,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(entry.cond_text)}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
diff --git a/modules/images.py b/modules/images.py
index c0a90676..b9589563 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,4 +1,5 @@
import datetime
+import io
import math
import os
from collections import namedtuple
@@ -23,6 +24,10 @@ def image_grid(imgs, batch_size=1, rows=None):
rows = opts.n_rows
elif opts.n_rows == 0:
rows = batch_size
+ elif opts.grid_prevent_empty_spots:
+ rows = math.floor(math.sqrt(len(imgs)))
+ while len(imgs) % rows != 0:
+ rows -= 1
else:
rows = math.sqrt(len(imgs))
rows = round(rows)
@@ -463,3 +468,22 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
txt_fullfn = None
return fullfn, txt_fullfn
+
+
+def image_data(data):
+ try:
+ image = Image.open(io.BytesIO(data))
+ textinfo = image.text["parameters"]
+ return textinfo, None
+ except Exception:
+ pass
+
+ try:
+ text = data.decode('utf8')
+ assert len(text) < 10000
+ return text, None
+
+ except Exception:
+ pass
+
+ return '', None
diff --git a/modules/images_history.py b/modules/images_history.py
new file mode 100644
index 00000000..9260df8a
--- /dev/null
+++ b/modules/images_history.py
@@ -0,0 +1,181 @@
+import os
+import shutil
+
+
+def traverse_all_files(output_dir, image_list, curr_dir=None):
+ curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
+ try:
+ f_list = os.listdir(curr_path)
+ except:
+ if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
+ image_list.append(curr_dir)
+ return image_list
+ for file in f_list:
+ file = file if curr_dir is None else os.path.join(curr_dir, file)
+ file_path = os.path.join(curr_path, file)
+ if file[-4:] == ".txt":
+ pass
+ elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
+ image_list.append(file)
+ else:
+ image_list = traverse_all_files(output_dir, image_list, file)
+ return image_list
+
+
+def get_recent_images(dir_name, page_index, step, image_index, tabname):
+ page_index = int(page_index)
+ f_list = os.listdir(dir_name)
+ image_list = []
+ image_list = traverse_all_files(dir_name, image_list)
+ image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
+ num = 48 if tabname != "extras" else 12
+ max_page_index = len(image_list) // num + 1
+ page_index = max_page_index if page_index == -1 else page_index + step
+ page_index = 1 if page_index < 1 else page_index
+ page_index = max_page_index if page_index > max_page_index else page_index
+ idx_frm = (page_index - 1) * num
+ image_list = image_list[idx_frm:idx_frm + num]
+ image_index = int(image_index)
+ if image_index < 0 or image_index > len(image_list) - 1:
+ current_file = None
+ hidden = None
+ else:
+ current_file = image_list[int(image_index)]
+ hidden = os.path.join(dir_name, current_file)
+ return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
+
+
+def first_page_click(dir_name, page_index, image_index, tabname):
+ return get_recent_images(dir_name, 1, 0, image_index, tabname)
+
+
+def end_page_click(dir_name, page_index, image_index, tabname):
+ return get_recent_images(dir_name, -1, 0, image_index, tabname)
+
+
+def prev_page_click(dir_name, page_index, image_index, tabname):
+ return get_recent_images(dir_name, page_index, -1, image_index, tabname)
+
+
+def next_page_click(dir_name, page_index, image_index, tabname):
+ return get_recent_images(dir_name, page_index, 1, image_index, tabname)
+
+
+def page_index_change(dir_name, page_index, image_index, tabname):
+ return get_recent_images(dir_name, page_index, 0, image_index, tabname)
+
+
+def show_image_info(num, image_path, filenames):
+ # print(f"select image {num}")
+ file = filenames[int(num)]
+ return file, num, os.path.join(image_path, file)
+
+
+def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
+ if name == "":
+ return filenames, delete_num
+ else:
+ delete_num = int(delete_num)
+ index = list(filenames).index(name)
+ i = 0
+ new_file_list = []
+ for name in filenames:
+ if i >= index and i < index + delete_num:
+ path = os.path.join(dir_name, name)
+ if os.path.exists(path):
+ print(f"Delete file {path}")
+ os.remove(path)
+ txt_file = os.path.splitext(path)[0] + ".txt"
+ if os.path.exists(txt_file):
+ os.remove(txt_file)
+ else:
+ print(f"Not exists file {path}")
+ else:
+ new_file_list.append(name)
+ i += 1
+ return new_file_list, 1
+
+
+def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
+ if opts.outdir_samples != "":
+ dir_name = opts.outdir_samples
+ elif tabname == "txt2img":
+ dir_name = opts.outdir_txt2img_samples
+ elif tabname == "img2img":
+ dir_name = opts.outdir_img2img_samples
+ elif tabname == "extras":
+ dir_name = opts.outdir_extras_samples
+ d = dir_name.split("/")
+ dir_name = "/" if dir_name.startswith("/") else d[0]
+ for p in d[1:]:
+ dir_name = os.path.join(dir_name, p)
+ with gr.Row():
+ renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
+ first_page = gr.Button('First Page')
+ prev_page = gr.Button('Prev Page')
+ page_index = gr.Number(value=1, label="Page Index")
+ next_page = gr.Button('Next Page')
+ end_page = gr.Button('End Page')
+ with gr.Row(elem_id=tabname + "_images_history"):
+ with gr.Row():
+ with gr.Column(scale=2):
+ history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
+ with gr.Row():
+ delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
+ delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
+ with gr.Column():
+ with gr.Row():
+ pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
+ pnginfo_send_to_img2img = gr.Button('Send to img2img')
+ with gr.Row():
+ with gr.Column():
+ img_file_info = gr.Textbox(label="Generate Info", interactive=False)
+ img_file_name = gr.Textbox(label="File Name", interactive=False)
+ with gr.Row():
+ # hiden items
+
+ img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
+ tabname_box = gr.Textbox(tabname, visible=False)
+ image_index = gr.Textbox(value=-1, visible=False)
+ set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
+ filenames = gr.State()
+ hidden = gr.Image(type="pil", visible=False)
+ info1 = gr.Textbox(visible=False)
+ info2 = gr.Textbox(visible=False)
+
+ # turn pages
+ gallery_inputs = [img_path, page_index, image_index, tabname_box]
+ gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
+
+ first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
+
+ # other funcitons
+ set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
+ img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
+ delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
+ hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
+
+ # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
+ switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
+ switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
+
+
+def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
+ with gr.Blocks(analytics_enabled=False) as images_history:
+ with gr.Tabs() as tabs:
+ with gr.Tab("txt2img history"):
+ with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
+ show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
+ with gr.Tab("img2img history"):
+ with gr.Blocks(analytics_enabled=False) as images_history_img2img:
+ show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
+ with gr.Tab("extras history"):
+ with gr.Blocks(analytics_enabled=False) as images_history_img2img:
+ show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
+ return images_history
diff --git a/modules/interrogate.py b/modules/interrogate.py
index af858cc0..9263d65a 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -55,7 +55,7 @@ class InterrogateModels:
model, preprocess = clip.load(clip_model_name)
model.eval()
- model = model.to(shared.device)
+ model = model.to(devices.device_interrogate)
return model, preprocess
@@ -65,14 +65,14 @@ class InterrogateModels:
if not shared.cmd_opts.no_half:
self.blip_model = self.blip_model.half()
- self.blip_model = self.blip_model.to(shared.device)
+ self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half:
self.clip_model = self.clip_model.half()
- self.clip_model = self.clip_model.to(shared.device)
+ self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = next(self.clip_model.parameters()).dtype
@@ -99,11 +99,11 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
+ text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
- similarity = torch.zeros((1, len(text_array))).to(shared.device)
+ similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]
@@ -116,7 +116,7 @@ class InterrogateModels:
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
- ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
@@ -140,7 +140,7 @@ class InterrogateModels:
res = caption
- clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
with torch.no_grad(), precision_scope("cuda"):
diff --git a/modules/processing.py b/modules/processing.py
index d5172f00..7e2a416d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -140,7 +140,7 @@ class Processed:
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
- self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
+ self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt]
@@ -501,16 +501,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- firstphase_width = 0
- firstphase_height = 0
- firstphase_width_truncated = 0
- firstphase_height_truncated = 0
- def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
+ def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
- self.scale_latent = scale_latent
self.denoising_strength = denoising_strength
+ self.firstphase_width = firstphase_width
+ self.firstphase_height = firstphase_height
+ self.truncate_x = 0
+ self.truncate_y = 0
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
@@ -519,14 +518,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ if self.firstphase_width == 0 or self.firstphase_height == 0:
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = self.width * self.height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ self.firstphase_width = math.ceil(scale * self.width / 64) * 64
+ self.firstphase_height = math.ceil(scale * self.height / 64) * 64
+ firstphase_width_truncated = int(scale * self.width)
+ firstphase_height_truncated = int(scale * self.height)
+
+ else:
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+
+ width_ratio = self.width / self.firstphase_width
+ height_ratio = self.height / self.firstphase_height
+
+ if width_ratio > height_ratio:
+ firstphase_width_truncated = self.firstphase_width
+ firstphase_height_truncated = self.firstphase_width * self.height / self.width
+ else:
+ firstphase_width_truncated = self.firstphase_height * self.width / self.height
+ firstphase_height_truncated = self.firstphase_height
+
+ self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
+ self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- self.firstphase_width_truncated = int(scale * self.width)
- self.firstphase_height_truncated = int(scale * self.height)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
@@ -539,36 +555,30 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
- truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
- truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
+ samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
- samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
+ decoded_samples = decode_first_stage(self.sd_model, samples)
- if self.scale_latent:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
+ decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
else:
- decoded_samples = decode_first_stage(self.sd_model, samples)
+ lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
- decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
- else:
- lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
- batch_images = []
- for i, x_sample in enumerate(lowres_samples):
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- image = Image.fromarray(x_sample)
- image = images.resize_image(0, image, self.width, self.height)
- image = np.array(image).astype(np.float32) / 255.0
- image = np.moveaxis(image, 2, 0)
- batch_images.append(image)
-
- decoded_samples = torch.from_numpy(np.array(batch_images))
- decoded_samples = decoded_samples.to(shared.device)
- decoded_samples = 2. * decoded_samples - 1.
-
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ batch_images = []
+ for i, x_sample in enumerate(lowres_samples):
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ image = Image.fromarray(x_sample)
+ image = images.resize_image(0, image, self.width, self.height)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+ batch_images.append(image)
+
+ decoded_samples = torch.from_numpy(np.array(batch_images))
+ decoded_samples = decoded_samples.to(shared.device)
+ decoded_samples = 2. * decoded_samples - 1.
+
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
shared.state.nextjob()
diff --git a/modules/safe.py b/modules/safe.py
index 20be16a5..399165a1 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -96,11 +96,18 @@ def load(filename, *args, **kwargs):
if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename)
+ except pickle.UnpicklingError:
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ return None
+
except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr)
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None
return unsafe_torch_load(filename, *args, **kwargs)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f3660d8d..3aa21ec1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -136,7 +136,7 @@ def load_model_weights(model, checkpoint_info):
if checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
- pl_sd = torch.load(checkpoint_file, map_location="cpu")
+ pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
@@ -159,9 +159,8 @@ def load_model_weights(model, checkpoint_info):
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
- vae_ckpt = torch.load(vae_file, map_location="cpu")
+ vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
-
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
diff --git a/modules/shared.py b/modules/shared.py
index b2090da1..aa69bedf 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -34,6 +34,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
+parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
@@ -54,7 +55,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
+parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -76,10 +77,11 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl
cmd_opts = parser.parse_args()
-devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
device = devices.device
+weight_load_location = None if cmd_opts.lowram else "cpu"
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
@@ -175,6 +177,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
+ "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
@@ -233,7 +236,8 @@ options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
- "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
+ "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
+ "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 67e90afe..23bb4b6a 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -24,11 +24,12 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
- re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
+ self.batch_size = batch_size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -78,13 +79,14 @@ class PersonalizedBase(Dataset):
if include_cond:
entry.cond_text = self.create_text(filename_text)
- entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry)
- self.length = len(self.dataset) * repeats
+ assert len(self.dataset) > 1, "No images have been found in the dataset."
+ self.length = len(self.dataset) * repeats // batch_size
- self.initial_indexes = np.arange(self.length) % len(self.dataset)
+ self.initial_indexes = np.arange(len(self.dataset))
self.indexes = None
self.shuffle()
@@ -101,13 +103,19 @@ class PersonalizedBase(Dataset):
return self.length
def __getitem__(self, i):
- if i % len(self.dataset) == 0:
- self.shuffle()
+ res = []
- index = self.indexes[i % len(self.indexes)]
- entry = self.dataset[index]
+ for j in range(self.batch_size):
+ position = i * self.batch_size + j
+ if position % len(self.indexes) == 0:
+ self.shuffle()
- if entry.cond is None:
- entry.cond_text = self.create_text(entry.filename_text)
+ index = self.indexes[position % len(self.indexes)]
+ entry = self.dataset[index]
- return entry
+ if entry.cond is None:
+ entry.cond_text = self.create_text(entry.filename_text)
+
+ res.append(entry)
+
+ return res
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fa0e33a2..e754747e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -6,6 +6,7 @@ import torch
import tqdm
import html
import datetime
+import csv
from PIL import Image, PngImagePlugin
@@ -172,7 +173,33 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def write_loss(log_directory, filename, step, epoch_len, values):
+ if shared.opts.training_write_csv_every == 0:
+ return
+
+ if step % shared.opts.training_write_csv_every != 0:
+ return
+
+ write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
+
+ with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ epoch = step // epoch_len
+ epoch_step = step - epoch * epoch_len
+
+ csv_writer.writerow({
+ "step": step + 1,
+ "epoch": epoch + 1,
+ "epoch_step": epoch_step + 1,
+ **values,
+ })
+
+
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -204,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
hijack = sd_hijack.model_hijack
@@ -224,7 +251,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, entry in pbar:
+ for i, entries in pbar:
embedding.step = i + ititial_step
scheduler.apply(optimizer, embedding.step)
@@ -235,10 +262,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
break
with torch.autocast("cuda"):
- c = cond_model([entry.cond_text])
-
- x = entry.latent.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
@@ -256,21 +282,37 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
- preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
-
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=preview_text,
- steps=20,
- height=training_height,
- width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
processed = processing.process_images(p)
image = processed.images[0]
@@ -305,7 +347,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
-Last prompt: {html.escape(entry.cond_text)}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e985242b..2381347f 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,7 +6,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -30,8 +30,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
restore_faces=restore_faces,
tiling=tiling,
enable_hr=enable_hr,
- scale_latent=scale_latent if enable_hr else None,
denoising_strength=denoising_strength if enable_hr else None,
+ firstphase_width=firstphase_width if enable_hr else None,
+ firstphase_height=firstphase_height if enable_hr else None,
)
if cmd_opts.enable_console_prompts:
diff --git a/modules/ui.py b/modules/ui.py
index 7446439d..baf4c397 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -22,7 +22,7 @@ import gradio as gr
import gradio.utils
import gradio.routes
-from modules import sd_hijack
+from modules import sd_hijack, sd_models
from modules.paths import script_path
from modules.shared import opts, cmd_opts
if cmd_opts.deepdanbooru:
@@ -40,6 +40,7 @@ from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
+import modules.images_history as img_his
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -432,7 +433,9 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
with gr.Column(scale=1, elem_id="roll_col"):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
@@ -446,7 +449,10 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=8):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
+
with gr.Column(scale=1, elem_id="roll_col"):
sh = gr.Button(elem_id="sh", visible=True)
@@ -508,13 +514,40 @@ def setup_progressbar(progressbar, preview, id_part, textinfo=None):
)
+def apply_setting(key, value):
+ if value is None:
+ return gr.update()
+
+ if key == "sd_model_checkpoint":
+ ckpt_info = sd_models.get_closet_checkpoint_match(value)
+
+ if ckpt_info is not None:
+ value = ckpt_info.title
+ else:
+ return gr.update()
+
+ comp_args = opts.data_labels[key].component_args
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
+ return
+
+ valtype = type(opts.data_labels[key].default)
+ oldval = opts.data[key]
+ opts.data[key] = valtype(value) if valtype != type(None) else value
+ if oldval != value and opts.data_labels[key].onchange is not None:
+ opts.data_labels[key].onchange()
+
+ opts.save(shared.config_filename)
+ return value
+
+
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
with gr.Row(elem_id='txt2img_progress_row'):
with gr.Column(scale=1):
@@ -540,10 +573,11 @@ def create_ui(wrap_gradio_gpu_call):
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options:
- scale_latent = gr.Checkbox(label='Scale latent', value=False)
+ firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
+ firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
- with gr.Row():
+ with gr.Row(equal_height=True):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
@@ -602,8 +636,9 @@ def create_ui(wrap_gradio_gpu_call):
height,
width,
enable_hr,
- scale_latent,
denoising_strength,
+ firstphase_width,
+ firstphase_height,
] + custom_inputs,
outputs=[
txt2img_gallery,
@@ -616,6 +651,17 @@ def create_ui(wrap_gradio_gpu_call):
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
+ txt_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ txt_prompt_img
+ ],
+ outputs=[
+ txt2img_prompt,
+ txt_prompt_img
+ ]
+ )
+
enable_hr.change(
fn=lambda x: gr_show(x),
inputs=[enable_hr],
@@ -668,14 +714,29 @@ def create_ui(wrap_gradio_gpu_call):
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+ (firstphase_width, "First pass size-1"),
+ (firstphase_height, "First pass size-2"),
]
- modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
+
+ txt2img_preview_params = [
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ steps,
+ sampler_index,
+ cfg_scale,
+ seed,
+ width,
+ height,
+ ]
+
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+
with gr.Column(scale=1):
pass
@@ -770,6 +831,17 @@ def create_ui(wrap_gradio_gpu_call):
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+ img2img_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ img2img_prompt_img
+ ],
+ outputs=[
+ img2img_prompt,
+ img2img_prompt_img
+ ]
+ )
+
mask_mode.change(
lambda mode, img: {
init_img_with_mask: gr_show(mode == 0),
@@ -910,7 +982,6 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
]
- modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as extras_interface:
@@ -958,6 +1029,7 @@ def create_ui(wrap_gradio_gpu_call):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
+
submit.click(
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index",
@@ -1017,6 +1089,14 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[image],
outputs=[html, generation_info, html2],
)
+ #images history
+ images_history_switch_dict = {
+ "fn":modules.generation_parameters_copypaste.connect_paste,
+ "t2i":txt2img_paste_fields,
+ "i2i":img2img_paste_fields
+ }
+
+ images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
@@ -1024,11 +1104,12 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row():
- primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
- secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
+ primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+ secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+ tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
custom_name = gr.Textbox(label="Custom Name (Optional)")
- interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
- interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
+ interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Save as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
@@ -1091,6 +1172,7 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@@ -1100,7 +1182,7 @@ def create_ui(wrap_gradio_gpu_call):
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
- preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
@@ -1169,6 +1251,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
train_embedding_name,
learn_rate,
+ batch_size,
dataset_directory,
log_directory,
training_width,
@@ -1178,7 +1261,8 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every,
template_file,
save_image_with_stored_embedding,
- preview_image_prompt,
+ preview_from_txt2img,
+ *txt2img_preview_params,
],
outputs=[
ti_output,
@@ -1192,13 +1276,15 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
train_hypernetwork_name,
learn_rate,
+ batch_size,
dataset_directory,
log_directory,
steps,
create_image_every,
save_embedding_every,
template_file,
- preview_image_prompt,
+ preview_from_txt2img,
+ *txt2img_preview_params,
],
outputs=[
ti_output,
@@ -1410,6 +1496,7 @@ Requested path was: {f}
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
+ (images_history, "History", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
(settings_interface, "Settings", "settings"),
@@ -1473,6 +1560,7 @@ Requested path was: {f}
inputs=[
primary_model_name,
secondary_model_name,
+ tertiary_model_name,
interp_method,
interp_amount,
save_as_half,
@@ -1482,6 +1570,7 @@ Requested path was: {f}
submit_result,
primary_model_name,
secondary_model_name,
+ tertiary_model_name,
component_dict['sd_model_checkpoint'],
]
)
@@ -1548,8 +1637,22 @@ Requested path was: {f}
outputs=[extras_image],
)
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields, generation_info, 'switch_to_txt2img')
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields, generation_info, 'switch_to_img2img_img2img')
+ settings_map = {
+ 'sd_hypernetwork': 'Hypernet',
+ 'CLIP_stop_at_last_layers': 'Clip skip',
+ 'sd_model_checkpoint': 'Model hash',
+ }
+
+ settings_paste_fields = [
+ (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
+ for k, v in settings_map.items()
+ ]
+
+ modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
+ modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
+
+ modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
+ modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
@@ -1631,3 +1734,4 @@ if 'gradio_routes_templates_response' not in globals():
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
gradio.routes.templates.TemplateResponse = template_response
+