-rw-r--r--  .gitignore                                     |  1
-rw-r--r--  javascript/contextMenus.js                     |  2
-rw-r--r--  javascript/hints.js                            | 10
-rw-r--r--  javascript/notification.js                     |  2
-rw-r--r--  javascript/ui.js                               |  8
-rw-r--r--  modules/deepbooru.py                           | 73
-rw-r--r--  modules/devices.py                             |  2
-rw-r--r--  modules/extras.py                              | 29
-rw-r--r--  modules/generation_parameters_copypaste.py     |  9
-rw-r--r--  modules/hypernetworks/hypernetwork.py          | 48
-rw-r--r--  modules/interrogate.py                         | 21
-rw-r--r--  modules/processing.py                          | 75
-rw-r--r--  modules/safe.py                                |  9
-rw-r--r--  modules/shared.py                              | 27
-rw-r--r--  modules/textual_inversion/dataset.py           | 47
-rw-r--r--  modules/textual_inversion/learn_schedule.py    | 37
-rw-r--r--  modules/textual_inversion/preprocess.py        | 94
-rw-r--r--  modules/textual_inversion/textual_inversion.py | 35
-rw-r--r--  modules/txt2img.py                             |  5
-rw-r--r--  modules/ui.py                                  | 84
-rw-r--r--  scripts/img2imgalt.py                          | 36
-rw-r--r--  scripts/xy_grid.py                             | 64
-rw-r--r--  style.css                                      | 23
-rw-r--r--  webui.py                                       |  2
24 files changed, 495 insertions, 248 deletions
diff --git a/.gitignore b/.gitignore
index 7afc9395..69785b3e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,7 @@ __pycache__
/webui.settings.bat
/embeddings
/styles.csv
+/params.txt
/styles.csv.bak
/webui-user.bat
/webui-user.sh
diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js
index 7636c4b3..fe67c42e 100644
--- a/javascript/contextMenus.js
+++ b/javascript/contextMenus.js
@@ -94,7 +94,7 @@ contextMenuInit = function(){
}
gradioApp().addEventListener("click", function(e) {
let source = e.composedPath()[0]
- if(source.id && source.indexOf('check_progress')>-1){
+ if(source.id && source.id.indexOf('check_progress')>-1){
return
}
diff --git a/javascript/hints.js b/javascript/hints.js
index b81c181b..af010a59 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -14,7 +14,7 @@ titles = {
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.",
- "\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
+ "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
@@ -81,6 +81,14 @@ titles = {
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
+
+ "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
+ "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
+
+ "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.",
+
+ "Weighted Sum": "Result = A * (1 - M) + B * M",
+ "Add difference": "Result = A + (B - C) * (1 - M)",
}
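
Note: the two formula tooltips added above map one-to-one onto the merge functions added to modules/extras.py further down. A minimal sketch with plain floats standing in for checkpoint tensors (illustrative only, not the module code):

    def weighted_sum(a, b, m):
        # "Weighted Sum": Result = A * (1 - M) + B * M
        return a * (1 - m) + b * m

    def add_difference(a, b, c, m):
        # "Add difference": Result = A + (B - C) * (1 - M)
        return a + (b - c) * (1 - m)

    print(weighted_sum(1.0, 3.0, 0.25))         # 1.5 -> 75% of A, 25% of B
    print(add_difference(1.0, 3.0, 2.0, 0.25))  # 1.75 -> A plus 75% of (B - C)
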
diff --git a/javascript/notification.js b/javascript/notification.js
index bdf614ad..f96de313 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -36,7 +36,7 @@ onUiUpdate(function(){
const notification = new Notification(
'Stable Diffusion',
{
- body: `Generated ${imgs.size > 1 ? imgs.size - 1 : 1} image${imgs.size > 1 ? 's' : ''}`,
+ body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
icon: headImg,
image: headImg,
}
diff --git a/javascript/ui.js b/javascript/ui.js
index 4100944e..0f8fe68e 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -33,27 +33,27 @@ function args_to_array(args){
}
function switch_to_txt2img(){
- gradioApp().querySelectorAll('button')[0].click();
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
return args_to_array(arguments);
}
function switch_to_img2img_img2img(){
- gradioApp().querySelectorAll('button')[1].click();
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
return args_to_array(arguments);
}
function switch_to_img2img_inpaint(){
- gradioApp().querySelectorAll('button')[1].click();
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
return args_to_array(arguments);
}
function switch_to_extras(){
- gradioApp().querySelectorAll('button')[2].click();
+ gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
return args_to_array(arguments);
}
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 29529949..f34f3788 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -2,33 +2,46 @@ import os.path
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
import time
+import re
+
+re_special = re.compile(r'([\\()])')
def get_deepbooru_tags(pil_image):
"""
This method is for running only one image at a time for simple use. Used by the img2img interrogate.
"""
from modules import shared # prevents circular reference
- create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha)
- shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process_queue.put(pil_image)
- while shared.deepbooru_process_return["value"] == -1:
- time.sleep(0.2)
- tags = shared.deepbooru_process_return["value"]
- release_process()
- return tags
+
+ try:
+ create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
+ return get_tags_from_process(pil_image)
+ finally:
+ release_process()
-def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort):
+OPT_INCLUDE_RANKS = "include_ranks"
+def create_deepbooru_opts():
+ from modules import shared
+
+ return {
+ "use_spaces": shared.opts.deepbooru_use_spaces,
+ "use_escape": shared.opts.deepbooru_escape,
+ "alpha_sort": shared.opts.deepbooru_sort_alpha,
+ OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
+ }
+
+
+def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
model, tags = get_deepbooru_tags_model()
while True: # while process is running, keep monitoring queue for new image
pil_image = queue.get()
if pil_image == "QUIT":
break
else:
- deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort)
+ deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
-def create_deepbooru_process(threshold, alpha_sort):
+def create_deepbooru_process(threshold, deepbooru_opts):
"""
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
@@ -41,10 +54,23 @@ def create_deepbooru_process(threshold, alpha_sort):
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort))
+ shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
shared.deepbooru_process.start()
+def get_tags_from_process(image):
+ from modules import shared
+
+ shared.deepbooru_process_return["value"] = -1
+ shared.deepbooru_process_queue.put(image)
+ while shared.deepbooru_process_return["value"] == -1:
+ time.sleep(0.2)
+ caption = shared.deepbooru_process_return["value"]
+ shared.deepbooru_process_return["value"] = -1
+
+ return caption
+
+
def release_process():
"""
Stops the deepbooru process to return used memory
@@ -81,10 +107,16 @@ def get_deepbooru_tags_model():
return model, tags
-def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort):
+def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
+
+ alpha_sort = deepbooru_opts['alpha_sort']
+ use_spaces = deepbooru_opts['use_spaces']
+ use_escape = deepbooru_opts['use_escape']
+ include_ranks = deepbooru_opts['include_ranks']
+
width = model.input_shape[2]
height = model.input_shape[1]
image = np.array(pil_image)
@@ -122,11 +154,20 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort)
if alpha_sort:
sort_ndx = 1
- # sort by reverse by likelihood and normal for alpha
+ # sort by likelihood (descending) or alphabetically when alpha_sort is set, and format tag text as requested
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold:
- result_tags_out.append(tag)
+ # note: tag_outformat will still have a colon if include_ranks is True
+ tag_outformat = tag.replace(':', ' ')
+ if use_spaces:
+ tag_outformat = tag_outformat.replace('_', ' ')
+ if use_escape:
+ tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
+ if include_ranks:
+ tag_outformat = f"({tag_outformat}:{weight:.3f})"
+
+ result_tags_out.append(tag_outformat)
print('\n'.join(sorted(result_tags_print, reverse=True)))
- return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
+ return ', '.join(result_tags_out)
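
Note: tag post-processing moved from a blanket .replace() on the joined string to per-tag formatting. A standalone sketch of the same steps (the function name here is hypothetical; the logic mirrors the loop above):

    import re

    re_special = re.compile(r'([\\()])')

    def format_tag(tag, weight, use_spaces=True, use_escape=True, include_ranks=False):
        out = tag.replace(':', ' ')        # colons would read as attention syntax
        if use_spaces:
            out = out.replace('_', ' ')    # deepbooru tags use underscores
        if use_escape:
            out = re.sub(re_special, r'\\\1', out)  # literal ( ) rather than emphasis
        if include_ranks:
            out = f"({out}:{weight:.3f})"
        return out

    print(format_tag("long_hair", 0.9721, include_ranks=True))  # (long hair:0.972)
    print(format_tag("hand_(symbol)", 0.85))                    # hand \(symbol\)
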
diff --git a/modules/devices.py b/modules/devices.py
index 03ef58f1..eb422583 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -34,7 +34,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16
dtype_vae = torch.float16
diff --git a/modules/extras.py b/modules/extras.py
index b24d7de3..532d869f 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -159,48 +159,61 @@ def run_pnginfo(image):
return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
+def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
- def weighted_sum(theta0, theta1, alpha):
+ def weighted_sum(theta0, theta1, theta2, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
# Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
- def sigmoid(theta0, theta1, alpha):
+ def sigmoid(theta0, theta1, theta2, alpha):
alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha)
# Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
- def inv_sigmoid(theta0, theta1, alpha):
+ def inv_sigmoid(theta0, theta1, theta2, alpha):
import math
alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
return theta0 + ((theta1 - theta0) * alpha)
+ def add_difference(theta0, theta1, theta2, alpha):
+ return theta0 + (theta1 - theta2) * (1.0 - alpha)
+
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
+ teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+ theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
-
- theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
+ if teritary_model_info is not None:
+ print(f"Loading {teritary_model_info.filename}...")
+ teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
+ theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
+ else:
+ theta_2 = None
+
theta_funcs = {
"Weighted Sum": weighted_sum,
"Sigmoid": sigmoid,
"Inverse Sigmoid": inv_sigmoid,
+ "Add difference": add_difference,
}
theta_func = theta_funcs[interp_method]
print(f"Merging...")
+
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
- theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ratio in the merged checkpoint
if save_as_half:
theta_0[key] = theta_0[key].half()
+ # I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
@@ -219,4 +232,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
sd_models.list_models()
print(f"Checkpoint saved.")
- return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
+ return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index ac1ba7f4..c27826b6 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,5 +1,8 @@
+import os
import re
import gradio as gr
+from modules.shared import script_path
+from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param = re.compile(re_param_code)
@@ -61,6 +64,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
def connect_paste(button, paste_fields, input_comp, js=None):
def paste_func(prompt):
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config:
+ filename = os.path.join(script_path, "params.txt")
+ if os.path.exists(filename):
+ with open(filename, "r", encoding="utf8") as file:
+ prompt = file.read()
+
params = parse_generation_parameters(prompt)
res = []
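
Note: this is the read half of the params.txt round trip: process_images() (below) writes the current infotext there on every run, .gitignore now excludes the file, and an empty prompt makes the paste button fall back to it. A self-contained sketch of the fallback:

    import os

    script_path = "."  # stand-in for modules.paths.script_path

    def resolve_prompt(prompt):
        # empty prompt -> parameters of the last generation, if recorded
        if not prompt:
            filename = os.path.join(script_path, "params.txt")
            if os.path.exists(filename):
                with open(filename, "r", encoding="utf8") as file:
                    return file.read()
        return prompt
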
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8314450a..f1248bb7 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -14,10 +14,12 @@ import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
+ multiplier = 1.0
+
def __init__(self, dim, state_dict=None):
super().__init__()
@@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module):
self.to(devices.device)
def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
+ return x + (self.linear2(self.linear1(x))) * self.multiplier
+
+
+def apply_strength(value=None):
+ HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
class Hypernetwork:
@@ -223,31 +229,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps:
return hypernetwork, filename
- schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(schedules)
- print(f'Training at rate of {learn_rate} until step {end_step}')
-
- optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, (x, text, cond) in pbar:
+ for i, entry in pbar:
hypernetwork.step = i + ititial_step
- if hypernetwork.step > end_step:
- try:
- (learn_rate, end_step) = next(schedules)
- except Exception:
- break
- tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
- for pg in optimizer.param_groups:
- pg['lr'] = learn_rate
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- cond = cond.to(devices.device)
- x = x.to(devices.device)
+ cond = entry.cond.to(devices.device)
+ x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x
del cond
@@ -267,7 +265,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
- preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
@@ -282,16 +280,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
)
processed = processing.process_images(p)
- image = processed.images[0]
+ image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
- shared.state.current_image = image
- image.save(last_saved_image)
-
- last_saved_image += f", prompt: {preview_text}"
+ if image is not None:
+ shared.state.current_image = image
+ image.save(last_saved_image)
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
@@ -299,7 +297,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
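
Note: the class-level multiplier is how the new sd_hypernetwork_strength option takes effect: it scales only the learned residual, so 0 disables the hypernetwork and 1 applies it fully. A toy module showing the pattern (not the real HypernetworkModule):

    import torch

    class TinyResidualModule(torch.nn.Module):
        multiplier = 1.0  # set on the class, as apply_strength() does

        def __init__(self, dim):
            super().__init__()
            self.linear1 = torch.nn.Linear(dim, dim * 2)
            self.linear2 = torch.nn.Linear(dim * 2, dim)

        def forward(self, x):
            return x + self.linear2(self.linear1(x)) * self.multiplier

    m = TinyResidualModule(4)
    x = torch.randn(1, 4)
    TinyResidualModule.multiplier = 0.0
    assert torch.allclose(m(x), x)  # residual fully suppressed
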
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 635e266e..9263d65a 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -55,7 +55,7 @@ class InterrogateModels:
model, preprocess = clip.load(clip_model_name)
model.eval()
- model = model.to(shared.device)
+ model = model.to(devices.device_interrogate)
return model, preprocess
@@ -65,14 +65,14 @@ class InterrogateModels:
if not shared.cmd_opts.no_half:
self.blip_model = self.blip_model.half()
- self.blip_model = self.blip_model.to(shared.device)
+ self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half:
self.clip_model = self.clip_model.half()
- self.clip_model = self.clip_model.to(shared.device)
+ self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = next(self.clip_model.parameters()).dtype
@@ -99,11 +99,11 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
+ text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
- similarity = torch.zeros((1, len(text_array))).to(shared.device)
+ similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]
@@ -116,14 +116,14 @@ class InterrogateModels:
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
- ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
return caption[0]
- def interrogate(self, pil_image):
+ def interrogate(self, pil_image, include_ranks=False):
res = None
try:
@@ -140,7 +140,7 @@ class InterrogateModels:
res = caption
- clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
with torch.no_grad(), precision_scope("cuda"):
@@ -156,7 +156,10 @@ class InterrogateModels:
for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
- res += ", " + match
+ if include_ranks:
+ res += f", ({match}:{score})"
+ else:
+ res += ", " + match
except Exception:
print(f"Error interrogating", file=sys.stderr)
diff --git a/modules/processing.py b/modules/processing.py
index 698b3069..100a259f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -324,6 +324,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else:
assert p.prompt is not None
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
devices.torch_gc()
seed = get_fixed_seed(p.seed)
@@ -502,11 +506,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
firstphase_width_truncated = 0
firstphase_height_truncated = 0
- def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
+ def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=512, firstphase_height=512, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
- self.scale_latent = scale_latent
self.denoising_strength = denoising_strength
+ self.firstphase_width = firstphase_width
+ self.firstphase_height = firstphase_height
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
@@ -515,15 +520,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
-
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- self.firstphase_width_truncated = int(scale * self.width)
- self.firstphase_height_truncated = int(scale * self.height)
-
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
@@ -532,39 +528,46 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
return samples
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
- truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
- truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
+ truncate_x = 0
+ truncate_y = 0
+ width_ratio = self.width/self.firstphase_width
+ height_ratio = self.height/self.firstphase_height
+
+ if width_ratio > height_ratio:
+ truncate_y = int((self.width - self.firstphase_width) / width_ratio / height_ratio / opt_f)
+ elif width_ratio < height_ratio:
+ truncate_x = int((self.height - self.firstphase_height) / width_ratio / height_ratio / opt_f)
+
samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
- if self.scale_latent:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ decoded_samples = decode_first_stage(self.sd_model, samples)
+
+ if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
+ decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
else:
- decoded_samples = decode_first_stage(self.sd_model, samples)
+ lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
- decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
- else:
- lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
- batch_images = []
- for i, x_sample in enumerate(lowres_samples):
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- image = Image.fromarray(x_sample)
- image = images.resize_image(0, image, self.width, self.height)
- image = np.array(image).astype(np.float32) / 255.0
- image = np.moveaxis(image, 2, 0)
- batch_images.append(image)
-
- decoded_samples = torch.from_numpy(np.array(batch_images))
- decoded_samples = decoded_samples.to(shared.device)
- decoded_samples = 2. * decoded_samples - 1.
-
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ batch_images = []
+ for i, x_sample in enumerate(lowres_samples):
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ image = Image.fromarray(x_sample)
+ image = images.resize_image(0, image, self.width, self.height)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+ batch_images.append(image)
+
+ decoded_samples = torch.from_numpy(np.array(batch_images))
+ decoded_samples = decoded_samples.to(shared.device)
+ decoded_samples = 2. * decoded_samples - 1.
+
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
shared.state.nextjob()
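
Note: the removed scale_latent branch is replaced by an aspect-ratio crop of the first-pass latent, followed by decode, upscale, and re-encode. A worked example of the crop arithmetic for a hypothetical 768x512 target from a 512x512 first pass:

    opt_f = 8  # latent-to-pixel factor used throughout processing.py

    width, height = 768, 512
    firstphase_width, firstphase_height = 512, 512

    truncate_x = truncate_y = 0
    width_ratio = width / firstphase_width     # 1.5
    height_ratio = height / firstphase_height  # 1.0

    if width_ratio > height_ratio:
        truncate_y = int((width - firstphase_width) / width_ratio / height_ratio / opt_f)
    elif width_ratio < height_ratio:
        truncate_x = int((height - firstphase_height) / width_ratio / height_ratio / opt_f)

    print(truncate_x, truncate_y)  # 0 21: ~21 of 64 latent rows are cropped away
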
diff --git a/modules/safe.py b/modules/safe.py
index 20be16a5..399165a1 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -96,11 +96,18 @@ def load(filename, *args, **kwargs):
if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename)
+ except pickle.UnpicklingError:
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ return None
+
except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr)
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None
return unsafe_torch_load(filename, *args, **kwargs)
diff --git a/modules/shared.py b/modules/shared.py
index 42e99741..b6a5c1a8 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers
+from modules import sd_samplers, sd_models
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@@ -54,7 +54,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
+parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -76,8 +76,8 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl
cmd_opts = parser.parse_args()
-devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
device = devices.device
@@ -145,14 +145,14 @@ def realesrgan_models_names():
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = None
- self.show_on_main_page = show_on_main_page
+ self.refresh = refresh
def options_section(section_identifier, options_dict):
@@ -231,11 +231,15 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+ "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
+ "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+ "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
- "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+ "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
@@ -247,16 +251,21 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
+ 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
+ "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+ "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
+ "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
+ "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
}))
options_templates.update(options_section(('ui', "User interface"), {
@@ -340,6 +349,8 @@ class Options:
item = self.data_labels.get(key)
item.onchange = func
+ func()
+
def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d)
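
Note: --use-cpu choices are now lower-case and 'all' acts as a wildcard across every module. A string-based sketch of the same selection logic (strings stand in for torch devices):

    def pick_devices(use_cpu):
        names = ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer']
        return {x: ('cpu' if any(y in use_cpu for y in [x, 'all']) else 'cuda')
                for x in names}

    print(pick_devices(['interrogate']))  # only interrogation forced to CPU
    print(pick_devices(['all']))          # every module falls back to CPU
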
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index f61f40d3..67e90afe 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -11,11 +11,21 @@ import tqdm
from modules import devices, shared
import re
-re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
+re_numbers_at_start = re.compile(r"^[-\d]+\s*")
+
+
+class DatasetEntry:
+ def __init__(self, filename=None, latent=None, filename_text=None):
+ self.filename = filename
+ self.latent = latent
+ self.filename_text = filename_text
+ self.cond = None
+ self.cond_text = None
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
self.placeholder_token = placeholder_token
@@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
except Exception:
continue
+ text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
- filename_tokens = os.path.splitext(filename)[0]
- filename_tokens = re_tag.findall(filename_tokens)
+
+ if os.path.exists(text_filename):
+ with open(text_filename, "r", encoding="utf8") as file:
+ filename_text = file.read()
+ else:
+ filename_text = os.path.splitext(filename)[0]
+ filename_text = re.sub(re_numbers_at_start, '', filename_text)
+ if re_word:
+ tokens = re_word.findall(filename_text)
+ filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
@@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
+
if include_cond:
- text = self.create_text(filename_tokens)
- cond = cond_model([text]).to(devices.cpu)
- else:
- cond = None
+ entry.cond_text = self.create_text(filename_text)
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
- self.dataset.append((init_latent, filename_tokens, cond))
+ self.dataset.append(entry)
self.length = len(self.dataset) * repeats
@@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
- def create_text(self, filename_tokens):
+ def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
+ text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
@@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
self.shuffle()
index = self.indexes[i % len(self.indexes)]
- x, filename_tokens, cond = self.dataset[index]
+ entry = self.dataset[index]
+
+ if entry.cond is None:
+ entry.cond_text = self.create_text(entry.filename_text)
- text = self.create_text(filename_tokens)
- return x, text, cond
+ return entry
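
Note: when no .txt sidecar exists, the label text is now derived from the filename: the leading index is stripped and, if "Filename word regex" is set, the matches are rejoined with "Filename join string". A sketch with a hypothetical regex (the option itself defaults to empty, which keeps the text unchanged):

    import os
    import re

    re_numbers_at_start = re.compile(r"^[-\d]+\s*")

    def label_from_filename(path, word_regex=r"[a-zA-Z]+", join_string=" "):
        text = os.path.splitext(os.path.basename(path))[0]
        text = re.sub(re_numbers_at_start, '', text)  # drop "00012-0-" style prefixes
        if word_regex:
            text = join_string.join(re.findall(word_regex, text))
        return text

    print(label_from_filename("00012-0-a_photo_of_sks_dog.png"))
    # -> "a photo of sks dog"
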
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index db720271..2062726a 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -1,6 +1,12 @@
+import tqdm
-class LearnSchedule:
+
+class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
+ """
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until step 1000, and 1e-5 until step 10000
+ """
+
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
@@ -32,3 +38,32 @@ class LearnSchedule:
return self.rates[self.it - 1]
else:
raise StopIteration
+
+
+class LearnRateScheduler:
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ self.verbose = verbose
+
+ if self.verbose:
+ print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ self.finished = False
+
+ def apply(self, optimizer, step_number):
+ if step_number <= self.end_step:
+ return
+
+ try:
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ except Exception:
+ self.finished = True
+ return
+
+ if self.verbose:
+ tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ for pg in optimizer.param_groups:
+ pg['lr'] = self.learn_rate
+
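
Note: a usage sketch for the schedule syntax documented above; this mini parser re-implements only the "rate:step" pair splitting for illustration (the real LearnScheduleIterator is an iterator that also tracks the current step):

    def parse_schedule(learn_rate, max_steps):
        # "0.001:100, 1e-5:1000" -> [(0.001, 100), (1e-05, 1000)]
        out = []
        for pair in learn_rate.split(','):
            if ':' in pair:
                rate, end = pair.split(':')
                out.append((float(rate), int(end)))
            else:
                out.append((float(pair), max_steps))  # open-ended tail entry
        return out

    print(parse_schedule("0.001:100, 1e-5:1000", max_steps=10000))
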
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 113cecf1..886cf0c3 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -10,7 +10,30 @@ from modules.shared import opts, cmd_opts
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
+
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+ try:
+ if process_caption:
+ shared.interrogator.load()
+
+ if process_caption_deepbooru:
+ db_opts = deepbooru.create_deepbooru_opts()
+ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
+ deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
+
+ preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+
+ finally:
+
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
+
+ if process_caption_deepbooru:
+ deepbooru.release_process()
+
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -25,30 +48,28 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- if process_caption:
- shared.interrogator.load()
-
- if process_caption_deepbooru:
- deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha)
-
def save_pic_with_caption(image, index):
+ caption = ""
+
if process_caption:
- caption = "-" + shared.interrogator.generate_caption(image)
- caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
- elif process_caption_deepbooru:
- shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process_queue.put(image)
- while shared.deepbooru_process_return["value"] == -1:
- time.sleep(0.2)
- caption = "-" + shared.deepbooru_process_return["value"]
- caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
- shared.deepbooru_process_return["value"] = -1
- else:
- caption = filename
- caption = os.path.splitext(caption)[0]
- caption = os.path.basename(caption)
+ caption += shared.interrogator.generate_caption(image)
+
+ if process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.get_tags_from_process(image)
+
+ filename_part = filename
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
+
+ basename = f"{index:05}-{subindex[0]}-{filename_part}"
+ image.save(os.path.join(dst, f"{basename}.png"))
+
+ if len(caption) > 0:
+ with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
- image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1
def save_pic(image, index):
@@ -93,34 +114,3 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
save_pic(img, index)
shared.state.nextjob()
-
- if process_caption:
- shared.interrogator.send_blip_to_ram()
-
- if process_caption_deepbooru:
- deepbooru.release_process()
-
-
-def sanitize_caption(base_path, original_caption, suffix):
- operating_system = platform.system().lower()
- if (operating_system == "windows"):
- invalid_path_characters = "\\/:*?\"<>|"
- max_path_length = 259
- else:
- invalid_path_characters = "/" #linux/macos
- max_path_length = 1023
- caption = original_caption
- for invalid_character in invalid_path_characters:
- caption = caption.replace(invalid_character, "")
- fixed_path_length = len(base_path) + len(suffix)
- if fixed_path_length + len(caption) <= max_path_length:
- return caption
- caption_tokens = caption.split()
- new_caption = ""
- for token in caption_tokens:
- last_caption = new_caption
- new_caption = new_caption + token + " "
- if (len(new_caption) + fixed_path_length - 1 > max_path_length):
- break
- print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
- return last_caption.strip()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c5153e4a..fa0e33a2 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed,
@@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
- schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(schedules)
- print(f'Training at rate of {learn_rate} until step {end_step}')
-
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, (x, text, _) in pbar:
+ for i, entry in pbar:
embedding.step = i + ititial_step
- if embedding.step > end_step:
- try:
- (learn_rate, end_step) = next(schedules)
- except:
- break
- tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
- for pg in optimizer.param_groups:
- pg['lr'] = learn_rate
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- c = cond_model([text])
+ c = cond_model([entry.cond_text])
- x = x.to(devices.device)
+ x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x
@@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
- preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e985242b..2381347f 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,7 +6,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -30,8 +30,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
restore_faces=restore_faces,
tiling=tiling,
enable_hr=enable_hr,
- scale_latent=scale_latent if enable_hr else None,
denoising_strength=denoising_strength if enable_hr else None,
+ firstphase_width=firstphase_width if enable_hr else None,
+ firstphase_height=firstphase_height if enable_hr else None,
)
if cmd_opts.enable_console_prompts:
diff --git a/modules/ui.py b/modules/ui.py
index dd793c39..0a3ee887 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -78,6 +78,8 @@ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
+refresh_symbol = '\U0001f504' # 🔄
+
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@@ -538,10 +540,11 @@ def create_ui(wrap_gradio_gpu_call):
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options:
- scale_latent = gr.Checkbox(label='Scale latent', value=False)
+ firstphase_width = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass width", value=512)
+ firstphase_height = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass height", value=512)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
- with gr.Row():
+ with gr.Row(equal_height=True):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
@@ -600,8 +603,9 @@ def create_ui(wrap_gradio_gpu_call):
height,
width,
enable_hr,
- scale_latent,
denoising_strength,
+ firstphase_width,
+ firstphase_height,
] + custom_inputs,
outputs=[
txt2img_gallery,
@@ -678,6 +682,8 @@ def create_ui(wrap_gradio_gpu_call):
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+ (firstphase_width, "First pass size-1"),
+ (firstphase_height, "First pass size-2"),
]
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
@@ -1050,11 +1056,12 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row():
- primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
- secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
+ primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+ secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+ tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
custom_name = gr.Textbox(label="Custom Name (Optional)")
- interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
- interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation amount (1 - M)', value=0.3)
+ interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid", "Add difference"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Save as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
@@ -1102,11 +1109,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images into two')
- process_caption = gr.Checkbox(label='Use BLIP caption as filename')
- if cmd_opts.deepdanbooru:
- process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename')
- else:
- process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename', visible=False)
+ process_caption = gr.Checkbox(label='Use BLIP for caption')
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
with gr.Row():
with gr.Column(scale=3):
@@ -1126,7 +1130,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
- num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
@@ -1204,7 +1207,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width,
training_height,
steps,
- num_repeats,
create_image_every,
save_embedding_every,
template_file,
@@ -1243,8 +1245,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[],
)
-
- def create_setting_component(key):
+ def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1264,7 +1265,34 @@ def create_ui(wrap_gradio_gpu_call):
else:
raise Exception(f'bad options item type: {str(t)} for key {key}')
- return comp(label=info.label, value=fun, **(args or {}))
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun, **(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key)
+ else:
+ with gr.Row(variant="compact"):
+ res = comp(label=info.label, value=fun, **(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key)
+
+ def refresh():
+ info.refresh()
+ refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
+
+ for k, v in refreshed_args.items():
+ setattr(res, k, v)
+
+ return gr.update(**(refreshed_args or {}))
+
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[res],
+ )
+ else:
+ res = comp(label=info.label, value=fun, **(args or {}))
+
+
+ return res
components = []
component_dict = {}
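Note on the hunk above: settings entries that declare a refresh callback now get a paired button whose click handler re-runs the provider and pushes the new values back through gr.update, so the component is updated in place instead of being rebuilt. A minimal standalone sketch of the same pattern (the list_models provider and labels are illustrative, not part of the commit):

    import gradio as gr

    def list_models():
        # stand-in for info.refresh() followed by re-reading component_args()
        return ["model-a.ckpt", "model-b.ckpt"]

    with gr.Blocks() as demo:
        dropdown = gr.Dropdown(choices=list_models(), label="Checkpoint")
        refresh_button = gr.Button("Refresh")

        def refresh():
            # returning gr.update(...) swaps the choices without recreating the widget
            return gr.update(choices=list_models())

        refresh_button.click(fn=refresh, inputs=[], outputs=[dropdown])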
@@ -1338,6 +1366,9 @@ Requested path was: {f}
settings_cols = 3
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
+ quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
+ quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
+
quicksettings_list = []
cols_displayed = 0
@@ -1362,7 +1393,7 @@ Requested path was: {f}
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
- if item.show_on_main_page:
+ if k in quicksettings_names:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
else:
@@ -1371,7 +1402,11 @@ Requested path was: {f}
components.append(component)
items_displayed += 1
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ with gr.Row():
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+            reload_script_bodies = gr.Button(value='Reload custom script bodies (no UI updates, no restart)', variant='secondary')
+            restart_gradio = gr.Button(value='Restart Gradio and refresh components (custom scripts, ui.py, js and css only)', variant='primary')
+
request_notifications.click(
fn=lambda: None,
inputs=[],
@@ -1379,10 +1414,6 @@ Requested path was: {f}
_js='function(){}'
)
- with gr.Row():
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
- restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
-
def reload_scripts():
modules.scripts.reload_script_body_only()
@@ -1397,7 +1428,6 @@ Requested path was: {f}
shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
-
restart_gradio.click(
fn=request_restart,
inputs=[],
@@ -1432,12 +1462,12 @@ Requested path was: {f}
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
- component = create_setting_component(k)
+ component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
settings_interface.gradio_ref = demo
- with gr.Tabs() as tabs:
+ with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()
@@ -1476,6 +1506,7 @@ Requested path was: {f}
inputs=[
primary_model_name,
secondary_model_name,
+ tertiary_model_name,
interp_method,
interp_amount,
save_as_half,
@@ -1485,6 +1516,7 @@ Requested path was: {f}
submit_result,
primary_model_name,
secondary_model_name,
+ tertiary_model_name,
component_dict['sd_model_checkpoint'],
]
)
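The merge wiring above adds a third checkpoint. Going by the interpolation hints, "Weighted Sum" blends A and B and ignores C, while "Add difference" grafts the delta between B and C onto A. A sketch of the per-tensor arithmetic, assumed from the UI labels (the real implementation lives in modules/extras.py):

    import torch

    def weighted_sum(a: torch.Tensor, b: torch.Tensor, m: float) -> torch.Tensor:
        # Result = A * (1 - M) + B * M
        return a * (1 - m) + b * m

    def add_difference(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, m: float) -> torch.Tensor:
        # Result = A + (B - C) * M, with the tertiary model C as the baseline
        return a + (b - c) * m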
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
index 313a55d2..d438175c 100644
--- a/scripts/img2imgalt.py
+++ b/scripts/img2imgalt.py
@@ -120,15 +120,45 @@ class Script(scripts.Script):
return is_img2img
def ui(self, is_img2img):
+ info = gr.Markdown('''
+ * `CFG Scale` should be 2 or lower.
+ ''')
+
+        override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler? (this method is built for it)", value=True)
+
+        override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`? (and `negative prompt` to `original negative prompt`)", value=True)
original_prompt = gr.Textbox(label="Original prompt", lines=1)
original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1)
- cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
+
+ override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True)
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50)
+
+ override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True)
+
+ cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False)
- return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment]
- def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment):
+ return [
+ info,
+ override_sampler,
+ override_prompt, original_prompt, original_negative_prompt,
+ override_steps, st,
+ override_strength,
+ cfg, randomness, sigma_adjustment,
+ ]
+
+ def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
+ # Override
+ if override_sampler:
+ p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler")
+ if override_prompt:
+ p.prompt = original_prompt
+ p.negative_prompt = original_negative_prompt
+ if override_steps:
+ p.steps = st
+ if override_strength:
+ p.denoising_strength = 1.0
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
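The widened run() signature above works because a script receives one positional argument per component returned from ui(), in order; the leading `_` only absorbs the value of the informational Markdown block. Schematically (a sketch of the contract, not code from the commit):

    class Script(scripts.Script):
        def ui(self, is_img2img):
            info = gr.Markdown("notes")    # contributes a value too
            flag = gr.Checkbox(label="Flag")
            return [info, flag]            # order defines run()'s arguments

        def run(self, p, _, flag):         # `_` swallows the Markdown value
            ...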
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 3bb080bf..efb63af5 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -107,6 +107,10 @@ def apply_hypernetwork(p, x, xs):
hypernetwork.load_hypernetwork(name)
+def apply_hypernetwork_strength(p, x, xs):
+ hypernetwork.apply_strength(x)
+
+
def confirm_hypernetworks(p, xs):
for x in xs:
if x.lower() in ["", "none"]:
@@ -165,6 +169,7 @@ axis_options = [
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
+ AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
@@ -175,13 +180,17 @@ axis_options = [
]
-def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
- res = []
-
+def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images):
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
- first_processed = None
+    # Temporary list of all the images generated, to be assembled into the grid.
+    # Filled with an empty placeholder image for any cell that fails to process.
+ image_cache = []
+
+ processed_result = None
+ cell_mode = "P"
+    cell_size = (1, 1)
state.job_count = len(xs) * len(ys) * p.n_iter
@@ -189,22 +198,39 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
for ix, x in enumerate(xs):
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
- processed = cell(x, y)
- if first_processed is None:
- first_processed = processed
-
+            processed: Processed = cell(x, y)
try:
- res.append(processed.images[0])
+ # this dereference will throw an exception if the image was not processed
+ # (this happens in cases such as if the user stops the process from the UI)
+ processed_image = processed.images[0]
+
+ if processed_result is None:
+ # Use our first valid processed result as a template container to hold our full results
+ processed_result = copy(processed)
+ cell_mode = processed_image.mode
+ cell_size = processed_image.size
+ processed_result.images = [Image.new(cell_mode, cell_size)]
+
+ image_cache.append(processed_image)
+ if include_lone_images:
+ processed_result.images.append(processed_image)
+ processed_result.all_prompts.append(processed.prompt)
+ processed_result.all_seeds.append(processed.seed)
+ processed_result.infotexts.append(processed.infotexts[0])
except:
- res.append(Image.new(res[0].mode, res[0].size))
+ image_cache.append(Image.new(cell_mode, cell_size))
+
+ if not processed_result:
+ print("Unexpected error: draw_xy_grid failed to return even a single processed image")
+        return Processed(p, [])
- grid = images.image_grid(res, rows=len(ys))
+ grid = images.image_grid(image_cache, rows=len(ys))
if draw_legend:
- grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
+ grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
- first_processed.images = [grid]
+ processed_result.images[0] = grid
- return first_processed
+ return processed_result
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
@@ -229,11 +255,12 @@ class Script(scripts.Script):
y_values = gr.Textbox(label="Y values", visible=False, lines=1)
draw_legend = gr.Checkbox(label='Draw legend', value=True)
+        include_lone_images = gr.Checkbox(label='Include separate images', value=False)
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)
- return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds]
+ return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
- def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds):
+ def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
@@ -344,7 +371,8 @@ class Script(scripts.Script):
x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
cell=cell,
- draw_legend=draw_legend
+ draw_legend=draw_legend,
+ include_lone_images=include_lone_images
)
if opts.grid_save:
@@ -354,6 +382,8 @@ class Script(scripts.Script):
modules.sd_models.reload_model_weights(shared.sd_model)
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+ hypernetwork.apply_strength()
+
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
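The reworked draw_xy_grid keeps the grid rectangular when individual cells are skipped, for example after the user interrupts generation: the first successful cell fixes the image mode and size, and every failed cell is backfilled with a blank placeholder. The fallback in isolation (PIL only; `results` stands in for the cell(x, y) calls):

    from PIL import Image

    image_cache = []
    cell_mode, cell_size = "P", (1, 1)   # defaults until the first cell succeeds
    seen_success = False

    for result in results:
        try:
            img = result.images[0]       # raises if the cell was never rendered
            if not seen_success:         # first valid cell fixes the geometry
                cell_mode, cell_size = img.mode, img.size
                seen_success = True
            image_cache.append(img)
        except Exception:
            image_cache.append(Image.new(cell_mode, cell_size))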
diff --git a/style.css b/style.css
index e6fa10b4..aa3d379c 100644
--- a/style.css
+++ b/style.css
@@ -228,6 +228,8 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
border-top: 1px solid #eee;
border-left: 1px solid #eee;
border-right: 1px solid #eee;
+
+ z-index: 300;
}
.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
@@ -480,16 +482,30 @@ input[type="range"]{
background: #a55000;
}
+#quicksettings {
+ gap: 0.4em;
+}
+
#quicksettings > div{
border: none;
background: none;
+ flex: unset;
+ gap: 0.5em;
}
#quicksettings > div > div{
max-width: 32em;
+ min-width: 24em;
padding: 0;
}
+#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork{
+ max-width: 2.5em;
+ min-width: 2.5em;
+ height: 2.4em;
+}
+
+
canvas[key="mask"] {
z-index: 12 !important;
filter: invert();
@@ -506,3 +522,10 @@ canvas[key="mask"] {
z-index: 200;
width: 8em;
}
+#quicksettings .gr-box > div > div > input.gr-text-input {
+ top: -1.12em;
+}
+
+.row.gr-compact{
+ overflow: visible;
+}
diff --git a/webui.py b/webui.py
index 33ba7905..fe0ce321 100644
--- a/webui.py
+++ b/webui.py
@@ -72,7 +72,6 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
-
def initialize():
modelloader.cleanup_models()
modules.sd_models.setup_model()
@@ -86,6 +85,7 @@ def initialize():
shared.sd_model = modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+ shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
def webui():
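The extra onchange registration means changing the hypernetwork-strength setting takes effect immediately, with no model reload. The contract this assumes is a simple per-key callback table; a toy sketch of that idea, not the shared.opts implementation itself:

    callbacks = {}
    data = {}

    def onchange(key, func):
        callbacks[key] = func

    def set_option(key, value):
        data[key] = value
        if key in callbacks:
            callbacks[key]()   # e.g. hypernetwork.apply_strength

    # mirrors webui.py:
    # onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)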