Diffstat (limited to 'modules')
-rw-r--r--  modules/deepbooru.py                            |  73
-rw-r--r--  modules/devices.py                              |   2
-rw-r--r--  modules/extras.py                               |  29
-rw-r--r--  modules/generation_parameters_copypaste.py      |   9
-rw-r--r--  modules/hypernetworks/hypernetwork.py           |  48
-rw-r--r--  modules/interrogate.py                          |  21
-rw-r--r--  modules/processing.py                           |  75
-rw-r--r--  modules/safe.py                                 |   9
-rw-r--r--  modules/shared.py                               |  27
-rw-r--r--  modules/textual_inversion/dataset.py            |  47
-rw-r--r--  modules/textual_inversion/learn_schedule.py     |  37
-rw-r--r--  modules/textual_inversion/preprocess.py         |  94
-rw-r--r--  modules/textual_inversion/textual_inversion.py  |  35
-rw-r--r--  modules/txt2img.py                              |   5
-rw-r--r--  modules/ui.py                                   |  84
15 files changed, 375 insertions(+), 220 deletions(-)
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 29529949..f34f3788 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -2,33 +2,46 @@ import os.path
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
import time
+import re
+
+re_special = re.compile(r'([\\()])')
def get_deepbooru_tags(pil_image):
"""
This method is for running only one image at a time for simple use. Used by the img2img interrogate.
"""
from modules import shared # prevents circular reference
- create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha)
- shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process_queue.put(pil_image)
- while shared.deepbooru_process_return["value"] == -1:
- time.sleep(0.2)
- tags = shared.deepbooru_process_return["value"]
- release_process()
- return tags
+
+ try:
+ create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
+ return get_tags_from_process(pil_image)
+ finally:
+ release_process()
-def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort):
+OPT_INCLUDE_RANKS = "include_ranks"
+def create_deepbooru_opts():
+ from modules import shared
+
+ return {
+ "use_spaces": shared.opts.deepbooru_use_spaces,
+ "use_escape": shared.opts.deepbooru_escape,
+ "alpha_sort": shared.opts.deepbooru_sort_alpha,
+ OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
+ }
+
+
+def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
model, tags = get_deepbooru_tags_model()
while True: # while process is running, keep monitoring queue for new image
pil_image = queue.get()
if pil_image == "QUIT":
break
else:
- deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort)
+ deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
-def create_deepbooru_process(threshold, alpha_sort):
+def create_deepbooru_process(threshold, deepbooru_opts):
"""
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
@@ -41,10 +54,23 @@ def create_deepbooru_process(threshold, alpha_sort):
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort))
+ shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
shared.deepbooru_process.start()
+def get_tags_from_process(image):
+ from modules import shared
+
+ shared.deepbooru_process_return["value"] = -1
+ shared.deepbooru_process_queue.put(image)
+ while shared.deepbooru_process_return["value"] == -1:
+ time.sleep(0.2)
+ caption = shared.deepbooru_process_return["value"]
+ shared.deepbooru_process_return["value"] = -1
+
+ return caption
+
+
def release_process():
"""
Stops the deepbooru process to free the memory it uses
@@ -81,10 +107,16 @@ def get_deepbooru_tags_model():
return model, tags
-def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort):
+def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
+
+ alpha_sort = deepbooru_opts['alpha_sort']
+ use_spaces = deepbooru_opts['use_spaces']
+ use_escape = deepbooru_opts['use_escape']
+ include_ranks = deepbooru_opts['include_ranks']
+
width = model.input_shape[2]
height = model.input_shape[1]
image = np.array(pil_image)
@@ -122,11 +154,20 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort)
if alpha_sort:
sort_ndx = 1
- # sort by reverse by likelihood and normal for alpha
+ # sort descending by likelihood (or ascending when sorting alphabetically), and format tag text as requested
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold:
- result_tags_out.append(tag)
+ # note: tag_outformat will still have a colon if include_ranks is True
+ tag_outformat = tag.replace(':', ' ')
+ if use_spaces:
+ tag_outformat = tag_outformat.replace('_', ' ')
+ if use_escape:
+ tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
+ if include_ranks:
+ tag_outformat = f"({tag_outformat}:{weight:.3f})"
+
+ result_tags_out.append(tag_outformat)
print('\n'.join(sorted(result_tags_print, reverse=True)))
- return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
+ return ', '.join(result_tags_out)
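Note: the per-tag formatting added above replaces the old blanket .replace('_', ' ').replace(':', ' ') on the joined string. A minimal standalone sketch of that pipeline (format_tag is a hypothetical name; the flags mirror the deepbooru_opts dict):

    import re

    re_special = re.compile(r'([\\()])')

    def format_tag(tag, weight, use_spaces=False, use_escape=True, include_ranks=False):
        out = tag.replace(':', ' ')                 # colons would break the rank syntax
        if use_spaces:
            out = out.replace('_', ' ')             # "long_hair" -> "long hair"
        if use_escape:
            out = re.sub(re_special, r'\\\1', out)  # escape \ ( ) so prompts treat them literally
        if include_ranks:
            out = f"({out}:{weight:.3f})"           # attention-style rank annotation
        return out

    print(format_tag("long_hair", 0.987, use_spaces=True, include_ranks=True))
    # -> (long hair:0.987)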
diff --git a/modules/devices.py b/modules/devices.py
index 03ef58f1..eb422583 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -34,7 +34,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16
dtype_vae = torch.float16
diff --git a/modules/extras.py b/modules/extras.py
index b24d7de3..532d869f 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -159,48 +159,61 @@ def run_pnginfo(image):
return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
+def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
- def weighted_sum(theta0, theta1, alpha):
+ def weighted_sum(theta0, theta1, theta2, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
# Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
- def sigmoid(theta0, theta1, alpha):
+ def sigmoid(theta0, theta1, theta2, alpha):
alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha)
# Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
- def inv_sigmoid(theta0, theta1, alpha):
+ def inv_sigmoid(theta0, theta1, theta2, alpha):
import math
alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
return theta0 + ((theta1 - theta0) * alpha)
+ def add_difference(theta0, theta1, theta2, alpha):
+ return theta0 + (theta1 - theta2) * (1.0 - alpha)
+
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
+ teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+ theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
-
- theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
+ if teritary_model_info is not None:
+ print(f"Loading {teritary_model_info.filename}...")
+ teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
+ theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
+ else:
+ theta_2 = None
+
theta_funcs = {
"Weighted Sum": weighted_sum,
"Sigmoid": sigmoid,
"Inverse Sigmoid": inv_sigmoid,
+ "Add difference": add_difference,
}
theta_func = theta_funcs[interp_method]
print(f"Merging...")
+
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
- theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ratio in the merged checkpoint
if save_as_half:
theta_0[key] = theta_0[key].half()
+ # I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
@@ -219,4 +232,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
sd_models.list_models()
print(f"Checkpoint saved.")
- return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
+ return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index ac1ba7f4..c27826b6 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,5 +1,8 @@
+import os
import re
import gradio as gr
+from modules.shared import script_path
+from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param = re.compile(re_param_code)
@@ -61,6 +64,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
def connect_paste(button, paste_fields, input_comp, js=None):
def paste_func(prompt):
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config:
+ filename = os.path.join(script_path, "params.txt")
+ if os.path.exists(filename):
+ with open(filename, "r", encoding="utf8") as file:
+ prompt = file.read()
+
params = parse_generation_parameters(prompt)
res = []
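This paste fallback works together with the params.txt write added in modules/processing.py below: every generation dumps its infotext there, so an empty paste buffer now restores the last run's parameters. A minimal sketch of the read side (path assumed relative to the webui root):

    import os

    def read_last_params(script_path="."):
        filename = os.path.join(script_path, "params.txt")
        if os.path.exists(filename):
            with open(filename, "r", encoding="utf8") as file:
                return file.read()
        return ""  # nothing generated yet; fall back to the pasted text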
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8314450a..f1248bb7 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -14,10 +14,12 @@ import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
+ multiplier = 1.0
+
def __init__(self, dim, state_dict=None):
super().__init__()
@@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module):
self.to(devices.device)
def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
+ return x + (self.linear2(self.linear1(x))) * self.multiplier
+
+
+def apply_strength(value=None):
+ HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
class Hypernetwork:
@@ -223,31 +229,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps:
return hypernetwork, filename
- schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(schedules)
- print(f'Training at rate of {learn_rate} until step {end_step}')
-
- optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, (x, text, cond) in pbar:
+ for i, entry in pbar:
hypernetwork.step = i + ititial_step
- if hypernetwork.step > end_step:
- try:
- (learn_rate, end_step) = next(schedules)
- except Exception:
- break
- tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
- for pg in optimizer.param_groups:
- pg['lr'] = learn_rate
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- cond = cond.to(devices.device)
- x = x.to(devices.device)
+ cond = entry.cond.to(devices.device)
+ x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x
del cond
@@ -267,7 +265,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
- preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
@@ -282,16 +280,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
)
processed = processing.process_images(p)
- image = processed.images[0]
+ image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
- shared.state.current_image = image
- image.save(last_saved_image)
-
- last_saved_image += f", prompt: {preview_text}"
+ if image is not None:
+ shared.state.current_image = image
+ image.save(last_saved_image)
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
@@ -299,7 +297,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
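The strength option relies on multiplier being a class attribute: apply_strength assigns it once on HypernetworkModule and every instance picks it up in forward, scaling the residual branch. A minimal self-contained sketch (TinyHypernetModule is a stand-in for the real module):

    import torch

    class TinyHypernetModule(torch.nn.Module):
        multiplier = 1.0  # class-level, shared by all instances

        def __init__(self, dim=320):
            super().__init__()
            self.linear1 = torch.nn.Linear(dim, dim * 2)
            self.linear2 = torch.nn.Linear(dim * 2, dim)

        def forward(self, x):
            # strength 0.0 disables the hypernetwork, 1.0 applies it fully
            return x + self.linear2(self.linear1(x)) * self.multiplier

    TinyHypernetModule.multiplier = 0.5  # what apply_strength(0.5) amounts to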
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 635e266e..9263d65a 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -55,7 +55,7 @@ class InterrogateModels:
model, preprocess = clip.load(clip_model_name)
model.eval()
- model = model.to(shared.device)
+ model = model.to(devices.device_interrogate)
return model, preprocess
@@ -65,14 +65,14 @@ class InterrogateModels:
if not shared.cmd_opts.no_half:
self.blip_model = self.blip_model.half()
- self.blip_model = self.blip_model.to(shared.device)
+ self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half:
self.clip_model = self.clip_model.half()
- self.clip_model = self.clip_model.to(shared.device)
+ self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = next(self.clip_model.parameters()).dtype
@@ -99,11 +99,11 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
+ text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
- similarity = torch.zeros((1, len(text_array))).to(shared.device)
+ similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]
@@ -116,14 +116,14 @@ class InterrogateModels:
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
- ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
return caption[0]
- def interrogate(self, pil_image):
+ def interrogate(self, pil_image, include_ranks=False):
res = None
try:
@@ -140,7 +140,7 @@ class InterrogateModels:
res = caption
- clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
with torch.no_grad(), precision_scope("cuda"):
@@ -156,7 +156,10 @@ class InterrogateModels:
for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
- res += ", " + match
+ if include_ranks:
+ res += f", ({match}:{score})"
+ else:
+ res += ", " + match
except Exception:
print(f"Error interrogating", file=sys.stderr)
diff --git a/modules/processing.py b/modules/processing.py
index 698b3069..100a259f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -324,6 +324,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else:
assert p.prompt is not None
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
devices.torch_gc()
seed = get_fixed_seed(p.seed)
@@ -502,11 +506,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
firstphase_width_truncated = 0
firstphase_height_truncated = 0
- def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
+ def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=512, firstphase_height=512, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
- self.scale_latent = scale_latent
self.denoising_strength = denoising_strength
+ self.firstphase_width = firstphase_width
+ self.firstphase_height = firstphase_height
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
@@ -515,15 +520,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
-
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- self.firstphase_width_truncated = int(scale * self.width)
- self.firstphase_height_truncated = int(scale * self.height)
-
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
@@ -532,39 +528,46 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
return samples
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
- truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
- truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
+ truncate_x = 0
+ truncate_y = 0
+ width_ratio = self.width/self.firstphase_width
+ height_ratio = self.height/self.firstphase_height
+
+ if width_ratio > height_ratio:
+ truncate_y = int((self.width - self.firstphase_width) / width_ratio / height_ratio / opt_f)
+ elif width_ratio < height_ratio:
+ truncate_x = int((self.height - self.firstphase_height) / width_ratio / height_ratio / opt_f)
+
samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
- if self.scale_latent:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ decoded_samples = decode_first_stage(self.sd_model, samples)
+
+ if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
+ decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
else:
- decoded_samples = decode_first_stage(self.sd_model, samples)
+ lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
- decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
- else:
- lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
- batch_images = []
- for i, x_sample in enumerate(lowres_samples):
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- image = Image.fromarray(x_sample)
- image = images.resize_image(0, image, self.width, self.height)
- image = np.array(image).astype(np.float32) / 255.0
- image = np.moveaxis(image, 2, 0)
- batch_images.append(image)
-
- decoded_samples = torch.from_numpy(np.array(batch_images))
- decoded_samples = decoded_samples.to(shared.device)
- decoded_samples = 2. * decoded_samples - 1.
-
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ batch_images = []
+ for i, x_sample in enumerate(lowres_samples):
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ image = Image.fromarray(x_sample)
+ image = images.resize_image(0, image, self.width, self.height)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+ batch_images.append(image)
+
+ decoded_samples = torch.from_numpy(np.array(batch_images))
+ decoded_samples = decoded_samples.to(shared.device)
+ decoded_samples = 2. * decoded_samples - 1.
+
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
shared.state.nextjob()
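Instead of deriving truncated first-pass dimensions from a fixed 512x512 pixel budget, the first pass now renders at the user-chosen firstphase size and the latent is center-cropped to the target aspect ratio before upscaling. A sketch of the crop arithmetic, assuming opt_f = 8 as the latent downscale factor:

    def crop_amounts(width, height, firstphase_width, firstphase_height, opt_f=8):
        truncate_x = truncate_y = 0
        width_ratio = width / firstphase_width
        height_ratio = height / firstphase_height
        if width_ratio > height_ratio:    # target is relatively wider: trim latent rows
            truncate_y = int((width - firstphase_width) / width_ratio / height_ratio / opt_f)
        elif width_ratio < height_ratio:  # target is relatively taller: trim latent columns
            truncate_x = int((height - firstphase_height) / width_ratio / height_ratio / opt_f)
        return truncate_x, truncate_y

    # 512x512 first pass for a 1024x512 target: crop 32 latent rows (16 per side)
    print(crop_amounts(1024, 512, 512, 512))  # (0, 32)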
diff --git a/modules/safe.py b/modules/safe.py
index 20be16a5..399165a1 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -96,11 +96,18 @@ def load(filename, *args, **kwargs):
if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename)
+ except pickle.UnpicklingError:
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ return None
+
except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr)
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None
return unsafe_torch_load(filename, *args, **kwargs)
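A sketch of the resulting control flow, with the verification and loading callables passed in to keep it self-contained (the real module uses check_pt and the saved original torch.load):

    import pickle
    import sys
    import traceback

    def guarded_load(filename, check_pt, unsafe_torch_load):
        try:
            check_pt(filename)  # walk the archive and reject disallowed pickle globals
        except pickle.UnpicklingError:
            # distinct message: the archive could not even be unpickled
            print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
            return None
        except Exception:
            print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            print("The file may be malicious, so the program is not going to read it.", file=sys.stderr)
            return None
        return unsafe_torch_load(filename)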
diff --git a/modules/shared.py b/modules/shared.py
index 42e99741..b6a5c1a8 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers
+from modules import sd_samplers, sd_models
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@@ -54,7 +54,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
+parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -76,8 +76,8 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl
cmd_opts = parser.parse_args()
-devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
device = devices.device
@@ -145,14 +145,14 @@ def realesrgan_models_names():
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = None
- self.show_on_main_page = show_on_main_page
+ self.refresh = refresh
def options_section(section_identifier, options_dict):
@@ -231,11 +231,15 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+ "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
+ "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+ "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
- "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+ "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
@@ -247,16 +251,21 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
+ 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
+ "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+ "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
+ "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
+ "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
}))
options_templates.update(options_section(('ui', "User interface"), {
@@ -340,6 +349,8 @@ class Options:
item = self.data_labels.get(key)
item.onchange = func
+ func()
+
def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d)
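The reworked --use-cpu flag lower-cases its arguments and accepts a catch-all 'all' token; the generator then assigns CPU per module. A minimal sketch of the selection logic (device names stand in for torch devices):

    MODULE_NAMES = ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer']

    def pick_devices(use_cpu, optimal='cuda'):
        return ['cpu' if any(y in use_cpu for y in [name, 'all']) else optimal
                for name in MODULE_NAMES]

    print(pick_devices(['interrogate']))  # only interrogation falls back to CPU
    print(pick_devices(['all']))          # every module on CPU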
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index f61f40d3..67e90afe 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -11,11 +11,21 @@ import tqdm
from modules import devices, shared
import re
-re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
+re_numbers_at_start = re.compile(r"^[-\d]+\s*")
+
+
+class DatasetEntry:
+ def __init__(self, filename=None, latent=None, filename_text=None):
+ self.filename = filename
+ self.latent = latent
+ self.filename_text = filename_text
+ self.cond = None
+ self.cond_text = None
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
self.placeholder_token = placeholder_token
@@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
except Exception:
continue
+ text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
- filename_tokens = os.path.splitext(filename)[0]
- filename_tokens = re_tag.findall(filename_tokens)
+
+ if os.path.exists(text_filename):
+ with open(text_filename, "r", encoding="utf8") as file:
+ filename_text = file.read()
+ else:
+ filename_text = os.path.splitext(filename)[0]
+ filename_text = re.sub(re_numbers_at_start, '', filename_text)
+ if re_word:
+ tokens = re_word.findall(filename_text)
+ filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
@@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
+
if include_cond:
- text = self.create_text(filename_tokens)
- cond = cond_model([text]).to(devices.cpu)
- else:
- cond = None
+ entry.cond_text = self.create_text(filename_text)
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
- self.dataset.append((init_latent, filename_tokens, cond))
+ self.dataset.append(entry)
self.length = len(self.dataset) * repeats
@@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
- def create_text(self, filename_tokens):
+ def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
+ text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
@@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
self.shuffle()
index = self.indexes[i % len(self.indexes)]
- x, filename_tokens, cond = self.dataset[index]
+ entry = self.dataset[index]
+
+ if entry.cond is None:
+ entry.cond_text = self.create_text(entry.filename_text)
- text = self.create_text(filename_tokens)
- return x, text, cond
+ return entry
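Captions now come from a .txt sidecar when one exists; otherwise the filename is cleaned up using the new options. A sketch of the filename path with example settings (the shipped defaults leave the word regex empty, which skips tokenization):

    import re

    re_numbers_at_start = re.compile(r"^[-\d]+\s*")

    def filename_to_text(name, word_regex=r"[a-zA-Z]+", join_string=" "):
        text = re_numbers_at_start.sub('', name)  # strip a leading "00012-" style index
        if word_regex:
            tokens = re.compile(word_regex).findall(text)
            text = (join_string or "").join(tokens)
        return text

    print(filename_to_text("00012-a_cute_cat"))  # -> "a cute cat"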
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index db720271..2062726a 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -1,6 +1,12 @@
+import tqdm
-class LearnSchedule:
+
+class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
+ """
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to train with lr 0.001 until step 100, 0.00001 until step 1000, and 1e-5 until step 10000
+ """
+
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
@@ -32,3 +38,32 @@ class LearnSchedule:
return self.rates[self.it - 1]
else:
raise StopIteration
+
+
+class LearnRateScheduler:
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ self.verbose = verbose
+
+ if self.verbose:
+ print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ self.finished = False
+
+ def apply(self, optimizer, step_number):
+ if step_number <= self.end_step:
+ return
+
+ try:
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ except Exception:
+ self.finished = True
+ return
+
+ if self.verbose:
+ tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ for pg in optimizer.param_groups:
+ pg['lr'] = self.learn_rate
+
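For reference, the schedule string the iterator consumes is comma-separated rate:step pairs, with a bare rate running to max_steps. A sketch of the parsing, assuming well-formed input:

    def parse_schedule(learn_rate, max_steps):
        rates = []
        for pair in learn_rate.split(','):
            if ':' in pair:
                rate, step = pair.split(':')
                rates.append((float(rate), int(step)))
            else:
                rates.append((float(pair), max_steps))  # bare rate: run to the end
        return rates

    print(parse_schedule("0.001:100, 1e-5:10000", 20000))
    # [(0.001, 100), (1e-05, 10000)]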
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 113cecf1..886cf0c3 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -10,7 +10,30 @@ from modules.shared import opts, cmd_opts
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
+
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+ try:
+ if process_caption:
+ shared.interrogator.load()
+
+ if process_caption_deepbooru:
+ db_opts = deepbooru.create_deepbooru_opts()
+ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
+ deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
+
+ preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+
+ finally:
+
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
+
+ if process_caption_deepbooru:
+ deepbooru.release_process()
+
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -25,30 +48,28 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- if process_caption:
- shared.interrogator.load()
-
- if process_caption_deepbooru:
- deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha)
-
def save_pic_with_caption(image, index):
+ caption = ""
+
if process_caption:
- caption = "-" + shared.interrogator.generate_caption(image)
- caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
- elif process_caption_deepbooru:
- shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process_queue.put(image)
- while shared.deepbooru_process_return["value"] == -1:
- time.sleep(0.2)
- caption = "-" + shared.deepbooru_process_return["value"]
- caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
- shared.deepbooru_process_return["value"] = -1
- else:
- caption = filename
- caption = os.path.splitext(caption)[0]
- caption = os.path.basename(caption)
+ caption += shared.interrogator.generate_caption(image)
+
+ if process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.get_tags_from_process(image)
+
+ filename_part = filename
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
+
+ basename = f"{index:05}-{subindex[0]}-{filename_part}"
+ image.save(os.path.join(dst, f"{basename}.png"))
+
+ if len(caption) > 0:
+ with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
- image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1
def save_pic(image, index):
@@ -93,34 +114,3 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
save_pic(img, index)
shared.state.nextjob()
-
- if process_caption:
- shared.interrogator.send_blip_to_ram()
-
- if process_caption_deepbooru:
- deepbooru.release_process()
-
-
-def sanitize_caption(base_path, original_caption, suffix):
- operating_system = platform.system().lower()
- if (operating_system == "windows"):
- invalid_path_characters = "\\/:*?\"<>|"
- max_path_length = 259
- else:
- invalid_path_characters = "/" #linux/macos
- max_path_length = 1023
- caption = original_caption
- for invalid_character in invalid_path_characters:
- caption = caption.replace(invalid_character, "")
- fixed_path_length = len(base_path) + len(suffix)
- if fixed_path_length + len(caption) <= max_path_length:
- return caption
- caption_tokens = caption.split()
- new_caption = ""
- for token in caption_tokens:
- last_caption = new_caption
- new_caption = new_caption + token + " "
- if (len(new_caption) + fixed_path_length - 1 > max_path_length):
- break
- print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
- return last_caption.strip()
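With captions saved to sidecar files instead of being embedded in filenames, the path-length sanitizer above becomes dead code. A sketch of the new save scheme (hypothetical standalone signature; the real function closes over dst, filename and subindex, and image is a PIL image):

    import os

    def save_pic_with_caption(image, dst, index, subindex, src_filename, caption=""):
        filename_part = os.path.splitext(os.path.basename(src_filename))[0]
        basename = f"{index:05}-{subindex}-{filename_part}"
        image.save(os.path.join(dst, f"{basename}.png"))
        if caption:
            # caption goes into a .txt sidecar that the dataset loader picks up
            with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
                file.write(caption)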
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c5153e4a..fa0e33a2 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed,
@@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
- schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(schedules)
- print(f'Training at rate of {learn_rate} until step {end_step}')
-
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, (x, text, _) in pbar:
+ for i, entry in pbar:
embedding.step = i + ititial_step
- if embedding.step > end_step:
- try:
- (learn_rate, end_step) = next(schedules)
- except:
- break
- tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
- for pg in optimizer.param_groups:
- pg['lr'] = learn_rate
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- c = cond_model([text])
+ c = cond_model([entry.cond_text])
- x = x.to(devices.device)
+ x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x
@@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
- preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e985242b..2381347f 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,7 +6,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -30,8 +30,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
restore_faces=restore_faces,
tiling=tiling,
enable_hr=enable_hr,
- scale_latent=scale_latent if enable_hr else None,
denoising_strength=denoising_strength if enable_hr else None,
+ firstphase_width=firstphase_width if enable_hr else None,
+ firstphase_height=firstphase_height if enable_hr else None,
)
if cmd_opts.enable_console_prompts:
diff --git a/modules/ui.py b/modules/ui.py
index dd793c39..0a3ee887 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -78,6 +78,8 @@ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
+refresh_symbol = '\U0001f504' # 🔄
+
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@@ -538,10 +540,11 @@ def create_ui(wrap_gradio_gpu_call):
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options:
- scale_latent = gr.Checkbox(label='Scale latent', value=False)
+ firstphase_width = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass width", value=512)
+ firstphase_height = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass height", value=512)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
- with gr.Row():
+ with gr.Row(equal_height=True):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
@@ -600,8 +603,9 @@ def create_ui(wrap_gradio_gpu_call):
height,
width,
enable_hr,
- scale_latent,
denoising_strength,
+ firstphase_width,
+ firstphase_height,
] + custom_inputs,
outputs=[
txt2img_gallery,
@@ -678,6 +682,8 @@ def create_ui(wrap_gradio_gpu_call):
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+ (firstphase_width, "First pass size-1"),
+ (firstphase_height, "First pass size-2"),
]
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
@@ -1050,11 +1056,12 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row():
- primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
- secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
+ primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+ secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+ tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
custom_name = gr.Textbox(label="Custom Name (Optional)")
- interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
- interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation amount (1 - M)', value=0.3)
+ interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid", "Add difference"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Save as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
@@ -1102,11 +1109,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images into two')
- process_caption = gr.Checkbox(label='Use BLIP caption as filename')
- if cmd_opts.deepdanbooru:
- process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename')
- else:
- process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename', visible=False)
+ process_caption = gr.Checkbox(label='Use BLIP for caption')
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
with gr.Row():
with gr.Column(scale=3):
@@ -1126,7 +1130,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
- num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
@@ -1204,7 +1207,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width,
training_height,
steps,
- num_repeats,
create_image_every,
save_embedding_every,
template_file,
@@ -1243,8 +1245,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[],
)
-
- def create_setting_component(key):
+ def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1264,7 +1265,34 @@ def create_ui(wrap_gradio_gpu_call):
else:
raise Exception(f'bad options item type: {str(t)} for key {key}')
- return comp(label=info.label, value=fun, **(args or {}))
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun, **(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key)
+ else:
+ with gr.Row(variant="compact"):
+ res = comp(label=info.label, value=fun, **(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key)
+
+ def refresh():
+ info.refresh()
+ refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
+
+ for k, v in refreshed_args.items():
+ setattr(res, k, v)
+
+ return gr.update(**(refreshed_args or {}))
+
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[res],
+ )
+ else:
+ res = comp(label=info.label, value=fun, **(args or {}))
+
+
+ return res
components = []
component_dict = {}
@@ -1338,6 +1366,9 @@ Requested path was: {f}
settings_cols = 3
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
+ quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
+ quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
+
quicksettings_list = []
cols_displayed = 0
@@ -1362,7 +1393,7 @@ Requested path was: {f}
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
- if item.show_on_main_page:
+ if k in quicksettings_names:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
else:
@@ -1371,7 +1402,11 @@ Requested path was: {f}
components.append(component)
items_displayed += 1
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ with gr.Row():
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
+ restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
+
request_notifications.click(
fn=lambda: None,
inputs=[],
@@ -1379,10 +1414,6 @@ Requested path was: {f}
_js='function(){}'
)
- with gr.Row():
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
- restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
-
def reload_scripts():
modules.scripts.reload_script_body_only()
@@ -1397,7 +1428,6 @@ Requested path was: {f}
shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
-
restart_gradio.click(
fn=request_restart,
inputs=[],
@@ -1432,12 +1462,12 @@ Requested path was: {f}
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
- component = create_setting_component(k)
+ component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
settings_interface.gradio_ref = demo
- with gr.Tabs() as tabs:
+ with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()
@@ -1476,6 +1506,7 @@ Requested path was: {f}
inputs=[
primary_model_name,
secondary_model_name,
+ tertiary_model_name,
interp_method,
interp_amount,
save_as_half,
@@ -1485,6 +1516,7 @@ Requested path was: {f}
submit_result,
primary_model_name,
secondary_model_name,
+ tertiary_model_name,
component_dict['sd_model_checkpoint'],
]
)