author     WDevelopsWebApps <97454358+WDevelopsWebApps@users.noreply.github.com>  2022-09-29 12:19:13 +0200
committer  GitHub <noreply@github.com>  2022-09-29 12:19:13 +0200
commit     f28ce3e3a17ccd9b4a03317031a4e3caa1a3088f (patch)
tree       9f57cde73305695cce558cd8a172b4974a02ee1d
parent     03ee67bfd34b9e872b33eb05fef5db83410b16f3 (diff)
parent     be5555fce4612fdfb4a06e831e3f1a8d055fdf9a (diff)
Merge branch 'master' into saving
-rw-r--r--  .gitignore                 |   2
-rw-r--r--  javascript/hints.js        |   1
-rw-r--r--  javascript/ui.js           |  19
-rw-r--r--  launch.py                  |   8
-rw-r--r--  modules/extras.py          |  46
-rw-r--r--  modules/img2img.py         |   2
-rw-r--r--  modules/paths.py           |   1
-rw-r--r--  modules/processing.py      |  20
-rw-r--r--  modules/prompt_parser.py   |  90
-rw-r--r--  modules/scripts.py         |   2
-rw-r--r--  modules/sd_hijack.py       | 125
-rw-r--r--  modules/sd_models.py       |  14
-rw-r--r--  modules/sd_samplers.py     |  83
-rw-r--r--  modules/shared.py          |  12
-rw-r--r--  modules/ui.py              | 128
-rw-r--r--  requirements.txt           |  10
-rw-r--r--  requirements_versions.txt  |   9
-rw-r--r--  scripts/xy_grid.py         |  12
-rw-r--r--  style.css                  |  16
-rw-r--r--  webui.py                   |   5
20 files changed, 473 insertions(+), 132 deletions(-)
diff --git a/.gitignore b/.gitignore
index 69ea78c5..b71e1875 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,7 +4,7 @@ __pycache__
/venv
/tmp
/model.ckpt
-/models/*.ckpt
+/models/**/*.ckpt
/GFPGANv1.3.pth
/gfpgan/weights/*.pth
/ui-config.json
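
The broadened ignore pattern matches checkpoints in nested subdirectories, mirroring the recursive scan added to modules/sd_models.py below. For reference, Python's glob only expands '**' across directory levels when recursive=True is passed; a minimal sketch (paths illustrative):

    import glob

    # Matches models/foo.ckpt as well as models/sub/dir/foo.ckpt; without
    # recursive=True, '**' degrades to a single-level '*'.
    print(glob.glob('models/**/*.ckpt', recursive=True))
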
diff --git a/javascript/hints.js b/javascript/hints.js
index 59dd770c..96cd24d5 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -15,6 +15,7 @@ titles = {
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
+ "\uD83D\uDCC2": "Open images output directory",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
diff --git a/javascript/ui.js b/javascript/ui.js
index 7db4db48..562d2552 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -182,4 +182,23 @@ onUiUpdate(function(){
});
json_elem.parentElement.style.display="none"
+
+ if (!txt2img_textarea) {
+ txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
+ txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
+ }
+ if (!img2img_textarea) {
+ img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
+ img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
+ }
})
+
+let txt2img_textarea, img2img_textarea = undefined;
+let wait_time = 800
+let token_timeout;
+
+function update_token_counter(button_id) {
+ if (token_timeout)
+ clearTimeout(token_timeout);
+ token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
+}
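
The new update_token_counter debounces keystrokes: every input event cancels the pending timer and arms a fresh 800 ms one, so the hidden Gradio button (and with it the server-side token count) fires once per pause in typing rather than once per keypress. The same debounce pattern in Python, using threading.Timer (names illustrative):

    import threading

    WAIT_SECONDS = 0.8
    _pending = None

    def debounce(callback):
        # Cancel the previous timer so only the last call in a burst survives.
        global _pending
        if _pending is not None:
            _pending.cancel()
        _pending = threading.Timer(WAIT_SECONDS, callback)
        _pending.start()

    # Every keystroke calls debounce(recount); recount actually runs once,
    # 0.8 s after typing stops.
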
diff --git a/launch.py b/launch.py
index 58e28f94..0e6b64ab 100644
--- a/launch.py
+++ b/launch.py
@@ -15,14 +15,14 @@ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
-k_diffusion_package = os.environ.get('K_DIFFUSION_PACKAGE', "git+https://github.com/crowsonkb/k-diffusion.git@1a0703dfb7d24d8806267c3e7ccc4caf67fd1331")
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
+k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "9e3002b7cd64df7870e08527b7664eb2f2f5f3f5")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
-ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH',"abf33e7002d59d9085081bce93ec798dcabd49af")
+ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH', "abf33e7002d59d9085081bce93ec798dcabd49af")
args = shlex.split(commandline_args)
@@ -110,9 +110,6 @@ if not is_installed("torch") or not is_installed("torchvision"):
if not skip_torch_cuda_test:
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
-if not is_installed("k_diffusion.sampling"):
- run_pip(f"install {k_diffusion_package}", "k-diffusion")
-
if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan")
@@ -120,6 +117,7 @@ os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
+git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
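
launch.py stops pip-installing k-diffusion as a package and instead vendors it like the other repos: clone into the repositories directory, then pin to a fixed commit. A minimal sketch of that clone-and-pin pattern, assuming a plain subprocess wrapper rather than the script's own run()/git_clone() helpers:

    import os
    import subprocess

    def git_clone_pinned(url, target_dir, commit_hash):
        # Clone once, then detach the working tree to the pinned revision so
        # upstream changes cannot silently alter behavior.
        if not os.path.isdir(target_dir):
            subprocess.run(["git", "clone", url, target_dir], check=True)
        subprocess.run(["git", "-C", target_dir, "checkout", commit_hash], check=True)

    git_clone_pinned(
        "https://github.com/crowsonkb/k-diffusion.git",
        "repositories/k-diffusion",
        "9e3002b7cd64df7870e08527b7664eb2f2f5f3f5",
    )
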
diff --git a/modules/extras.py b/modules/extras.py
index b8ebc619..c2543fcf 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -6,13 +6,14 @@ from PIL import Image
import torch
import tqdm
-from modules import processing, shared, images, devices
+from modules import processing, shared, images, devices, sd_models
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model
import piexif
import piexif.helper
+import gradio as gr
cached_images = {}
@@ -140,7 +141,7 @@ def run_pnginfo(image):
return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount):
+def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -150,23 +151,20 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha)
- if os.path.exists(primary_model_name):
- primary_model_filename = primary_model_name
- primary_model_name = os.path.splitext(os.path.basename(primary_model_name))[0]
- else:
- primary_model_filename = 'models/' + primary_model_name + '.ckpt'
+ # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
+ def inv_sigmoid(theta0, theta1, alpha):
+ import math
+ alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
+ return theta0 + ((theta1 - theta0) * alpha)
- if os.path.exists(secondary_model_name):
- secondary_model_filename = secondary_model_name
- secondary_model_name = os.path.splitext(os.path.basename(secondary_model_name))[0]
- else:
- secondary_model_filename = 'models/' + secondary_model_name + '.ckpt'
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
- print(f"Loading {primary_model_filename}...")
- primary_model = torch.load(primary_model_filename, map_location='cpu')
+ print(f"Loading {primary_model_info.filename}...")
+ primary_model = torch.load(primary_model_info.filename, map_location='cpu')
- print(f"Loading {secondary_model_filename}...")
- secondary_model = torch.load(secondary_model_filename, map_location='cpu')
+ print(f"Loading {secondary_model_info.filename}...")
+ secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
theta_0 = primary_model['state_dict']
theta_1 = secondary_model['state_dict']
@@ -174,21 +172,31 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
theta_funcs = {
"Weighted Sum": weighted_sum,
"Sigmoid": sigmoid,
+ "Inverse Sigmoid": inv_sigmoid,
}
theta_func = theta_funcs[interp_method]
print(f"Merging...")
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
- theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ratio in the merged checkpoint
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount))  # Need to reverse the interp_amount to match the desired mix ratio in the merged checkpoint
+ if save_as_half:
+ theta_0[key] = theta_0[key].half()
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
+ if save_as_half:
+ theta_0[key] = theta_0[key].half()
+
+ filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
+ filename = filename if custom_name == '' else (custom_name + '.ckpt')
+ output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)
- output_modelname = 'models/' + primary_model_name + '_' + str(round(interp_amount,2)) + '-' + secondary_model_name + '_' + str(round((float(1.0) - interp_amount),2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
print(f"Saving to {output_modelname}...")
torch.save(primary_model, output_modelname)
+ sd_models.list_models()
+
print(f"Checkpoint saved.")
- return "Checkpoint saved to " + output_modelname \ No newline at end of file
+ return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
diff --git a/modules/img2img.py b/modules/img2img.py
index d80b3e75..03e934e9 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -124,4 +124,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.samples_log_stdout:
print(generation_info_js)
- return processed.images, generation_info_js, plaintext_to_html(processed.info)
\ No newline at end of file
+ return processed.images, generation_info_js, plaintext_to_html(processed.info)
diff --git a/modules/paths.py b/modules/paths.py
index 3a19f9e5..df7b9d9a 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -20,6 +20,7 @@ path_dirs = [
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'),
+ (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion'),
]
paths = {}
diff --git a/modules/processing.py b/modules/processing.py
index 90e00bf8..4ecdfcd2 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -49,7 +49,7 @@ def apply_color_correction(correction, image):
class StableDiffusionProcessing:
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -75,15 +75,15 @@ class StableDiffusionProcessing:
self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params or {}
self.overlay_images = overlay_images
+ self.eta = eta
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = 0
-
- self.eta = opts.eta
+
self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn
self.s_tmin = opts.s_tmin
- self.s_tmax = float('inf') # not representable as a standard ui option
+ self.s_tmax = float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise
if not seed_enable_extras:
@@ -100,7 +100,7 @@ class StableDiffusionProcessing:
class Processed:
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0):
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
@@ -139,6 +139,7 @@ class Processed:
self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed]
self.all_subseeds = all_subseeds or [self.subseed]
+ self.infotexts = infotexts or [info]
def js(self):
obj = {
@@ -165,6 +166,7 @@ class Processed:
"denoising_strength": self.denoising_strength,
"extra_generation_params": self.extra_generation_params,
"index_of_first_image": self.index_of_first_image,
+ "infotexts": self.infotexts,
}
return json.dumps(obj)
@@ -269,6 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
+ "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
}
generation_params.update(p.extra_generation_params)
@@ -277,7 +280,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
- return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
+ return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -322,6 +325,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
+ infotexts = []
output_images = []
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
@@ -404,6 +408,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+ infotexts.append(infotext(n, i))
output_images.append(image)
state.nextjob()
@@ -416,6 +421,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid:
+ infotexts.insert(0, infotext())
output_images.insert(0, grid)
index_of_first_image = 1
@@ -423,7 +429,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image)
+ return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index a6a25b28..e811eb9e 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -126,5 +126,93 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
return res
+re_attention = re.compile(r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""", re.X)
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \( - literal character '('
+ \[ - literal character '['
+ \) - literal character ')'
+ \] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+
+ Example:
+
+ 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
+
+ produces:
+
+ [
+ ['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]
+ ]
+ """
-#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith('\\'):
+ res.append([text[1:], 1.0])
+ elif text == '(':
+ round_brackets.append(len(res))
+ elif text == '[':
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ')' and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == ']' and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ return res
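
Worth noting why the docstring's numbers come out the way they do: the example prompt's leading '(((' is closed only twice, and the parser applies any '(' still on the stack at end-of-text to everything after its position. Checking the arithmetic:

    # Each enclosing '(' contributes x1.1, each '[' x(1/1.1), and an explicit
    # ':1.3' replaces the default multiplier for its own bracket pair.
    print(1.3 * 1.1 * 1.1)    # house: explicit 1.3, one plain ')', unmatched '(' -> 1.573
    print(0.5 * 1.1)          # hill:  explicit 0.5, unmatched '('                -> 0.55
    print(1.1 ** 3 * 1.1)     # sky:   three '()' pairs, unmatched '('            -> 1.4641
    print((1 / 1.1) * 1.1)    # on:    one '[]' pair cancels the unmatched '('    -> 1.0
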
diff --git a/modules/scripts.py b/modules/scripts.py
index 202374e6..7c3bd5e7 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -55,7 +55,7 @@ def load_scripts(basedir):
if not os.path.exists(basedir):
return
- for filename in os.listdir(basedir):
+ for filename in sorted(os.listdir(basedir)):
path = os.path.join(basedir, filename)
if not os.path.isfile(path):
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 7b2030d4..2848a251 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -6,6 +6,7 @@ import torch
import numpy as np
from torch import einsum
+from modules import prompt_parser
from modules.shared import opts, device, cmd_opts
from ldm.util import default
@@ -180,6 +181,7 @@ class StableDiffusionModelHijack:
dir_mtime = None
layers = None
circular_enabled = False
+ clip = None
def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dirname)
@@ -210,6 +212,7 @@ class StableDiffusionModelHijack:
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
@@ -235,13 +238,14 @@ class StableDiffusionModelHijack:
print(traceback.format_exc(), file=sys.stderr)
continue
- print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
+ print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+ self.clip = m.cond_stage_model
if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
@@ -268,6 +272,11 @@ class StableDiffusionModelHijack:
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
layer.padding_mode = 'circular' if enable else 'zeros'
+ def tokenize(self, text):
+ max_length = self.clip.max_length - 2
+ _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+ return remade_batch_tokens[0], token_count, max_length
+
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
@@ -294,14 +303,101 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
- def forward(self, text):
- self.hijack.fixes = []
- self.hijack.comments = []
+
+ def tokenize_line(self, line, used_custom_terms, hijack_comments):
+ id_start = self.wrapped.tokenizer.bos_token_id
+ id_end = self.wrapped.tokenizer.eos_token_id
+ maxlen = self.wrapped.max_length
+
+ if opts.enable_emphasis:
+ parsed = prompt_parser.parse_prompt_attention(line)
+ else:
+ parsed = [[line, 1.0]]
+
+ tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
+
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ possible_matches = self.hijack.ids_lookup.get(token, None)
+
+ if possible_matches is None:
+ remade_tokens.append(token)
+ multipliers.append(weight)
+ else:
+ found = False
+ for ids, word in possible_matches:
+ if tokens[i:i + len(ids)] == ids:
+ emb_len = int(self.hijack.word_embeddings[word].shape[0])
+ fixes.append((len(remade_tokens), word))
+ remade_tokens += [0] * emb_len
+ multipliers += [weight] * emb_len
+ i += len(ids) - 1
+ found = True
+ used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
+ break
+
+ if not found:
+ remade_tokens.append(token)
+ multipliers.append(weight)
+ i += 1
+
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+ token_count = len(remade_tokens)
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+ return remade_tokens, fixes, multipliers, token_count
+
+ def process_text(self, texts):
+ used_custom_terms = []
remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_multipliers = []
+ for line in texts:
+ if line in cache:
+ remade_tokens, fixes, multipliers = cache[line]
+ else:
+ remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+
+ cache[line] = (remade_tokens, fixes, multipliers)
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+ def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length
used_custom_terms = []
+ remade_batch_tokens = []
+ overflowing_words = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@@ -353,9 +449,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-
- self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+ token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@@ -364,11 +459,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
- self.hijack.fixes.append(fixes)
+ hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+ def forward(self, text):
+
+ if opts.use_old_emphasis_implementation:
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
+ else:
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+
+
+ self.hijack.fixes = hijack_fixes
+ self.hijack.comments = hijack_comments
if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens)
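
process_text caches per-line tokenization, so a batch full of identical prompts pays the embedding lookup and truncation work once. The shape of that pattern, reduced to a stand-alone sketch with a stand-in tokenizer:

    def process_lines(lines, tokenize_line):
        cache = {}
        out = []
        for line in lines:
            if line not in cache:
                # First occurrence does the real work; repeats are dict hits.
                cache[line] = tokenize_line(line)
            out.append(cache[line])
        return out

    # A batch of four identical prompts tokenizes once:
    print(process_lines(["a photo of a cat"] * 4, str.split))

One quirk visible in the diff: token_count is only refreshed on a cache miss, so a fully cached batch reports the count from the last uncached line.
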
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 9decc911..7a5edced 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -23,6 +23,10 @@ except Exception:
pass
+def checkpoint_tiles():
+ return sorted([x.title for x in checkpoints_list.values()])
+
+
def list_models():
checkpoints_list.clear()
@@ -39,13 +43,14 @@ def list_models():
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
- return f'{name} [{h}]'
+ shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+
+ return f'{name} [{h}]', shortname
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
- title = modeltitle(cmd_ckpt, h)
- model_name = title.rsplit(".",1)[0] # remove extension if present
+ title, model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
@@ -53,8 +58,7 @@ def list_models():
if os.path.exists(model_dir):
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
h = model_hash(filename)
- title = modeltitle(filename, h)
- model_name = title.rsplit(".",1)[0] # remove extension if present
+ title, model_name = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
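
modeltitle now returns both the display title ('relative/path.ckpt [hash]') and a flattened short name used for merged-checkpoint filenames. A minimal sketch of the derivation (hash and paths illustrative; the real code resolves against the configured model directory):

    import os

    def modeltitle(path, h, model_dir="models"):
        name = os.path.relpath(path, model_dir).replace("\\", "/").lstrip("/")
        shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        return f"{name} [{h}]", shortname

    print(modeltitle("models/sub/v1-5.ckpt", "abc12345"))
    # ('sub/v1-5.ckpt [abc12345]', 'sub_v1-5')
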
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index a1183997..fc0c94b4 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -40,10 +40,8 @@ samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
- 'sample_euler_ancestral': ['eta'],
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
- 'sample_dpm_2_ancestral': ['eta'],
}
def setup_img2img_steps(p, steps=None):
@@ -101,6 +99,8 @@ class VanillaStableDiffusionSampler:
self.init_latent = None
self.sampler_noises = None
self.step = 0
+ self.eta = None
+ self.default_eta = 0.0
def number_of_needed_noises(self, p):
return 0
@@ -123,20 +123,29 @@ class VanillaStableDiffusionSampler:
self.step += 1
return res
+ def initialize(self, p):
+ self.eta = p.eta or opts.eta_ddim
+
+ for fieldname in ['p_sample_ddim', 'p_sample_plms']:
+ if hasattr(self.sampler, fieldname):
+ setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
+
+ self.mask = p.mask if hasattr(p, 'mask') else None
+ self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
steps, t_enc = setup_img2img_steps(p, steps)
+ self.initialize(p)
+
# existing code fails with certain step counts, like 9
try:
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception:
- self.sampler.make_schedule(ddim_num_steps=steps+1,ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
+ self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
- self.sampler.p_sample_ddim = self.p_sample_ddim_hook
- self.mask = p.mask if hasattr(p, 'mask') else None
- self.nmask = p.nmask if hasattr(p, 'nmask') else None
self.init_latent = x
self.step = 0
@@ -145,11 +154,8 @@ class VanillaStableDiffusionSampler:
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
- for fieldname in ['p_sample_ddim', 'p_sample_plms']:
- if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
- self.mask = None
- self.nmask = None
+ self.initialize(p)
+
self.init_latent = None
self.step = 0
@@ -157,9 +163,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with certain step counts, like 9
try:
- samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.eta)
+ samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
except Exception:
- samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.eta)
+ samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
return samples_ddim
@@ -237,6 +243,8 @@ class KDiffusionSampler:
self.sampler_noises = None
self.sampler_noise_index = 0
self.stop_at = None
+ self.eta = None
+ self.default_eta = 1.0
def callback_state(self, d):
store_latent(d["denoised"])
@@ -255,22 +263,12 @@ class KDiffusionSampler:
self.sampler_noise_index += 1
return res
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
- steps, t_enc = setup_img2img_steps(p, steps)
-
- sigmas = self.model_wrap.get_sigmas(steps)
-
- noise = noise * sigmas[steps - t_enc - 1]
-
- xi = x + noise
-
- sigma_sched = sigmas[steps - t_enc - 1:]
-
+ def initialize(self, p):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
- self.model_wrap_cfg.init_latent = x
self.model_wrap.step = 0
self.sampler_noise_index = 0
+ self.eta = p.eta or opts.eta_ancestral
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
@@ -283,6 +281,25 @@ class KDiffusionSampler:
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
extra_params_kwargs[param_name] = getattr(p, param_name)
+ if 'eta' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['eta'] = self.eta
+
+ return extra_params_kwargs
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+ steps, t_enc = setup_img2img_steps(p, steps)
+
+ sigmas = self.model_wrap.get_sigmas(steps)
+
+ noise = noise * sigmas[steps - t_enc - 1]
+ xi = x + noise
+
+ extra_params_kwargs = self.initialize(p)
+
+ sigma_sched = sigmas[steps - t_enc - 1:]
+
+ self.model_wrap_cfg.init_latent = x
+
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
@@ -291,19 +308,7 @@ class KDiffusionSampler:
sigmas = self.model_wrap.get_sigmas(steps)
x = x * sigmas[0]
- self.model_wrap_cfg.step = 0
- self.sampler_noise_index = 0
-
- if hasattr(k_diffusion.sampling, 'trange'):
- k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
-
- if self.sampler_noises is not None:
- k_diffusion.sampling.torch = TorchHijack(self)
-
- extra_params_kwargs = {}
- for param_name in self.extra_params:
- if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
- extra_params_kwargs[param_name] = getattr(p, param_name)
+ extra_params_kwargs = self.initialize(p)
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
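
Both sampler classes now resolve eta in initialize(): the per-job override if set, otherwise the matching UI default (eta_ddim 0.0 for DDIM/PLMS, eta_ancestral 1.0 for k-diffusion). KDiffusionSampler then forwards eta only when the underlying sampling function actually accepts it. The probe is plain inspect; a self-contained sketch:

    import inspect

    def build_extra_kwargs(func, eta):
        extra = {}
        # Ancestral samplers take eta; deterministic ones (Euler, Heun) do not.
        if 'eta' in inspect.signature(func).parameters:
            extra['eta'] = eta
        return extra

    def sample_euler(x, sigmas): ...
    def sample_euler_ancestral(x, sigmas, eta=1.0): ...

    print(build_extra_kwargs(sample_euler, 0.5))            # {}
    print(build_extra_kwargs(sample_euler_ancestral, 0.5))  # {'eta': 0.5}

One subtlety: 'p.eta or opts.eta_ancestral' treats an explicit eta of 0.0 as unset and falls back to the UI default.
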
diff --git a/modules/shared.py b/modules/shared.py
index 2502fe2d..f88c2b02 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -143,6 +143,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
+ "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -180,7 +181,6 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
- "save_selected_only": OptionInfo(False, "When using 'Save' button, only save a single selected image"),
}))
options_templates.update(options_section(('system', "System"), {
@@ -190,12 +190,13 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
- "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
@@ -221,8 +222,9 @@ options_templates.update(options_section(('ui', "User interface"), {
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
- "eta": OptionInfo(0.0, "DDIM and K Ancestral eta", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform','quad']}),
+ "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
diff --git a/modules/ui.py b/modules/ui.py
index 87a86a45..008bc40d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -9,10 +9,12 @@ import random
import sys
import time
import traceback
+import platform
+import subprocess as sp
import numpy as np
import torch
-from PIL import Image
+from PIL import Image, PngImagePlugin
import gradio as gr
import gradio.utils
@@ -22,6 +24,7 @@ from modules.paths import script_path
from modules.shared import opts, cmd_opts
import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img
+from modules.sd_hijack import model_hijack
import modules.ldsr_model
import modules.scripts
import modules.gfpgan_model
@@ -61,7 +64,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
-
+folder_symbol = '\uD83D\uDCC2'
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@@ -102,6 +105,7 @@ def save_files(js_data, images, index):
setattr(self, key, value)
data = json.loads(js_data)
+
p = MyObject(data)
path = opts.outdir_save
save_to_dirs = opts.save_to_dirs
@@ -111,10 +115,14 @@ def save_files(js_data, images, index):
path = os.path.join(opts.outdir_save, dirname)
os.makedirs(path, exist_ok=True)
-
- if index > -1 and opts.save_selected_only and (index > 0 or not opts.return_grid): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
+
+
+ if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
+
images = [images[index]]
- data["seed"] += (index - 1 if opts.return_grid else index)
+ infotexts = [data["infotexts"][index]]
+ else:
+ infotexts = data["infotexts"]
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0
@@ -137,8 +145,11 @@ def save_files(js_data, images, index):
if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):]
- with open(filepath, "wb") as imgfile:
- imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
+ pnginfo = PngImagePlugin.PngInfo()
+ pnginfo.add_text('parameters', infotexts[i])
+
+ image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
+ image.save(filepath, quality=opts.jpeg_quality, pnginfo=pnginfo)
filenames.append(filename)
@@ -350,6 +361,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component]
)
+def update_token_counter(text):
+ tokens, token_count, max_length = model_hijack.tokenize(text)
+ style_class = ' class="red"' if (token_count > max_length) else ""
+ return f"<span {style_class}>{token_count}/{max_length}</span>"
def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
@@ -359,11 +374,14 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2)
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
with gr.Column(scale=1, elem_id="roll_col"):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste")
+ token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
+ hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
with gr.Column(scale=10, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
@@ -470,6 +488,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
send_to_img2img = gr.Button('Send to img2img')
send_to_inpaint = gr.Button('Send to inpaint')
send_to_extras = gr.Button('Send to extras')
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
+ open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
with gr.Group():
html_info = gr.HTML()
@@ -646,6 +666,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
img2img_send_to_img2img = gr.Button('Send to img2img')
img2img_send_to_inpaint = gr.Button('Send to inpaint')
img2img_send_to_extras = gr.Button('Send to extras')
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
+ open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
with gr.Group():
html_info = gr.HTML()
@@ -818,6 +840,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
html_info = gr.HTML()
extras_send_to_img2img = gr.Button('Send to img2img')
extras_send_to_inpaint = gr.Button('Send to inpaint')
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
+ open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click(
fn=run_extras,
@@ -878,32 +902,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
- gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>/models</b> directory.</p>")
+ gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row():
- ckpt_name_list = sorted([x.model_name for x in modules.sd_models.checkpoints_list.values()])
- primary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_primary_model_name", label="Primary Model Name")
- secondary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
+ primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
+ secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
+ custom_name = gr.Textbox(label="Custom Name (Optional)")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
- interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid"], value="Weighted Sum", label="Interpolation Method")
- submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
+ interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
+ save_as_half = gr.Checkbox(value=False, label="Save as float16")
+ modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
- submit.click(
- fn=run_modelmerger,
- inputs=[
- primary_model_name,
- secondary_model_name,
- interp_method,
- interp_amount
- ],
- outputs=[
- submit_result,
- ]
- )
-
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -927,6 +939,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
return comp(label=info.label, value=fun, **(args or {}))
components = []
+ component_dict = {}
+
+ def open_folder(f):
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ else:
+ sp.Popen(["xdg-open", path])
def run_settings(*args):
changed = 0
@@ -982,7 +1005,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
- components.append(create_setting_component(k))
+ component = create_setting_component(k)
+ component_dict[k] = component
+ components.append(component)
items_displayed += 1
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
@@ -1032,7 +1057,34 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
inputs=components,
outputs=[result, text_settings],
)
-
+
+ def modelmerger(*args):
+ try:
+ results = run_modelmerger(*args)
+ except Exception as e:
+ print("Error loading/saving model file:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ modules.sd_models.list_models() #To remove the potentially missing models from the list
+ return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
+ return results
+
+ modelmerger_merge.click(
+ fn=modelmerger,
+ inputs=[
+ primary_model_name,
+ secondary_model_name,
+ interp_method,
+ interp_amount,
+ save_as_half,
+ custom_name,
+ ],
+ outputs=[
+ submit_result,
+ primary_model_name,
+ secondary_model_name,
+ component_dict['sd_model_checkpoint'],
+ ]
+ )
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
@@ -1071,6 +1123,24 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[extras_image],
)
+ open_txt2img_folder.click(
+ fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
+ inputs=[],
+ outputs=[],
+ )
+
+ open_img2img_folder.click(
+ fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
+ inputs=[],
+ outputs=[],
+ )
+
+ open_extras_folder.click(
+ fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
+ inputs=[],
+ outputs=[],
+ )
+
img2img_send_to_extras.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery_extras",
diff --git a/requirements.txt b/requirements.txt
index 08935506..0d9929ca 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,6 @@ font-roboto
gfpgan
gradio
invisible-watermark
-git+https://github.com/crowsonkb/k-diffusion.git
numpy
omegaconf
piexif
@@ -16,5 +15,12 @@ realesrgan
scikit-image>=0.19
git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
timm==0.4.12
-transformers
+transformers==4.19.2
torch
+einops
+jsonmerge
+clean-fid
+git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
+resize-right
+torchdiffeq
+kornia
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 505498e7..1e8006e0 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -14,4 +14,11 @@ fonts
font-roboto
timm==0.6.7
fairscale==0.4.9
-piexif==1.1.3
\ No newline at end of file
+piexif==1.1.3
+einops==0.4.1
+jsonmerge==1.8.0
+clean-fid==0.1.29
+git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
+resize-right==0.0.2
+torchdiffeq==0.2.3
+kornia==0.6.7
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 7c01231f..24fa5a0a 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -87,12 +87,12 @@ axis_options = [
AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
- AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
- AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
- AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
- AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
- AxisOption("DDIM Eta", float, apply_field("ddim_eta"), format_value_add_label),
- AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),# as it is now all AxisOptionImg2Img items must go after AxisOption ones
+ AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
+ AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
+ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
+ AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
+ AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
+ AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
]
diff --git a/style.css b/style.css
index 4054e2df..9709c4ee 100644
--- a/style.css
+++ b/style.css
@@ -1,5 +1,11 @@
.output-html p {margin: 0 0.5em;}
+.row > *,
+.row > .gr-form > * {
+ min-width: min(120px, 100%);
+ flex: 1 1 0%;
+}
+
.performance {
font-size: 0.85em;
color: #444;
@@ -43,13 +49,17 @@
margin-right: auto;
}
-#random_seed, #random_subseed, #reuse_seed, #reuse_subseed{
+#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
min-width: auto;
flex-grow: 0;
padding-left: 0.25em;
padding-right: 0.25em;
}
+#hidden_element{
+ display: none;
+}
+
#seed_row, #subseed_row{
gap: 0.5rem;
}
@@ -389,3 +399,7 @@ input[type="range"]{
border-radius: 8px;
display: none;
}
+
+.red {
+ color: red;
+}
diff --git a/webui.py b/webui.py
index c70a11c7..39f9ae9a 100644
--- a/webui.py
+++ b/webui.py
@@ -1,6 +1,7 @@
import os
import threading
+from modules import devices
from modules.paths import script_path
import signal
@@ -47,6 +48,8 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func):
def f(*args, **kwargs):
+ devices.torch_gc()
+
shared.state.sampling_step = 0
shared.state.job_count = -1
shared.state.job_no = 0
@@ -62,6 +65,8 @@ def wrap_gradio_gpu_call(func):
shared.state.job = ""
shared.state.job_count = 0
+ devices.torch_gc()
+
return res
return modules.ui.wrap_gradio_call(f)
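
wrap_gradio_gpu_call now garbage-collects CUDA memory on both sides of every job, so one generation's leftovers cannot pin VRAM for the next. A hedged sketch of the wrapper shape; torch_gc below stands in for modules.devices.torch_gc and assumes it empties the CUDA cache when a GPU is present:

    import functools
    import torch

    def torch_gc():
        # Assumption: approximates modules.devices.torch_gc.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    def wrap_gpu_call(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            torch_gc()                # reclaim leftovers from the previous job
            try:
                return func(*args, **kwargs)
            finally:
                torch_gc()            # release this job's intermediates, even on error

        return wrapped
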