aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py48
-rw-r--r--modules/api/models.py4
-rw-r--r--modules/generation_parameters_copypaste.py17
-rw-r--r--modules/hypernetworks/hypernetwork.py9
-rw-r--r--modules/processing.py30
-rw-r--r--modules/sd_hijack.py7
-rw-r--r--modules/sd_hijack_clip.py4
-rw-r--r--modules/sd_vae.py20
-rw-r--r--modules/shared.py2
-rw-r--r--modules/sub_quadratic_attention.py15
-rw-r--r--modules/textual_inversion/dataset.py10
-rw-r--r--modules/textual_inversion/textual_inversion.py202
-rw-r--r--modules/ui.py37
13 files changed, 285 insertions, 120 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 1c121ff0..6c564ad8 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -11,7 +11,7 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras
@@ -28,8 +28,13 @@ def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
+def script_name_to_index(name, scripts):
+ try:
+ return [script.title().lower() for script in scripts].index(name.lower())
+ except:
+ raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
@@ -144,7 +149,21 @@ class Api:
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
+ def get_script(self, script_name, script_runner):
+ if script_name is None:
+ return None, None
+
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(False)
+ ui.create_ui()
+
+ script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
+ script = script_runner.selectable_scripts[script_idx]
+ return script, script_idx
+
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+ script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
+
populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
@@ -154,14 +173,22 @@ class Api:
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
+ args = vars(populate)
+ args.pop('script_name', None)
+
with self.queue_lock:
- p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
shared.state.begin()
- processed = process_images(p)
+ if script is not None:
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
+ p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args)
+ else:
+ processed = process_images(p)
shared.state.end()
-
b64images = list(map(encode_pil_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
@@ -171,6 +198,8 @@ class Api:
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
+ script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
+
mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)
@@ -187,13 +216,20 @@ class Api:
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+ args.pop('script_name', None)
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.begin()
- processed = process_images(p)
+ if script is not None:
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
+ p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args)
+ else:
+ processed = process_images(p)
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
diff --git a/modules/api/models.py b/modules/api/models.py
index 49bf1e7a..880edde6 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -100,13 +100,13 @@ class PydanticModelGenerator:
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model()
class TextToImageResponse(BaseModel):
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 12a9de3d..f7f68b67 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -197,6 +197,15 @@ def restore_old_hires_fix_params(res):
firstpass_width = res.get('First pass size-1', None)
firstpass_height = res.get('First pass size-2', None)
+ if shared.opts.use_old_hires_fix_width_height:
+ hires_width = int(res.get("Hires resize-1", None))
+ hires_height = int(res.get("Hires resize-2", None))
+
+ if hires_width is not None and hires_height is not None:
+ res['Size-1'] = hires_width
+ res['Size-2'] = hires_height
+ return
+
if firstpass_width is None or firstpass_height is None:
return
@@ -205,12 +214,8 @@ def restore_old_hires_fix_params(res):
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
- # old algorithm for auto-calculating first pass size
- desired_pixel_count = 512 * 512
- actual_pixel_count = width * height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- firstpass_width = math.ceil(scale * width / 64) * 64
- firstpass_height = math.ceil(scale * height / 64) * 64
+ from modules import processing
+ firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
res['Size-1'] = firstpass_width
res['Size-2'] = firstpass_height
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b0cfbe71..ea3f1db9 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -24,6 +24,7 @@ from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
- textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@@ -456,7 +459,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
pin_memory = shared.opts.pin_memory
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
saved_params = dict(
diff --git a/modules/processing.py b/modules/processing.py
index 82157bc9..f04a0e1e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -98,7 +98,7 @@ class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -149,7 +149,7 @@ class StableDiffusionProcessing():
self.seed_resize_from_w = 0
self.scripts = None
- self.script_args = None
+ self.script_args = script_args
self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None
@@ -687,6 +687,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
return res
+def old_hires_fix_first_pass_dimensions(width, height):
+ """old algorithm for auto-calculating first pass size"""
+
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = width * height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ width = math.ceil(scale * width / 64) * 64
+ height = math.ceil(scale * height / 64) * 64
+
+ return width, height
+
+
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
@@ -703,16 +715,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_upscale_to_y = hr_resize_y
if firstphase_width != 0 or firstphase_height != 0:
- print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
- self.hr_scale = self.width / firstphase_width
+ self.hr_upscale_to_x = self.width
+ self.hr_upscale_to_y = self.height
self.width = firstphase_width
self.height = firstphase_height
self.truncate_x = 0
self.truncate_y = 0
+ self.applied_old_hires_behavior_to = None
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
+ self.hr_resize_x = self.width
+ self.hr_resize_y = self.height
+ self.hr_upscale_to_x = self.width
+ self.hr_upscale_to_y = self.height
+
+ self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
+ self.applied_old_hires_behavior_to = (self.width, self.height)
+
if self.hr_resize_x == 0 and self.hr_resize_y == 0:
self.extra_generation_params["Hires upscale"] = self.hr_scale
self.hr_upscale_to_x = int(self.width * self.hr_scale)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index cfdb09d6..6b0d95af 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -83,10 +83,12 @@ class StableDiffusionModelHijack:
clip = None
optimization_method = None
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
- def hijack(self, m):
+ def __init__(self):
+ self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
+ def hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
@@ -117,7 +119,6 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
-
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 5520c9b2..852afc66 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z *= original_mean / new_mean
+ z = z * (original_mean / new_mean)
return z
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index ac71d62d..0a49daa1 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,8 +1,9 @@
import torch
+import safetensors.torch
import os
import collections
from collections import namedtuple
-from modules import shared, devices, script_callbacks
+from modules import shared, devices, script_callbacks, sd_models
from modules.paths import models_path
import glob
from copy import deepcopy
@@ -72,8 +73,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
+ *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
+ *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
+ *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
@@ -137,6 +140,12 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print(f"Using VAE found similar to selected model: {vae_file}")
+ # if still not found, try look for ".vae.safetensors" beside model
+ if vae_file == "auto":
+ vae_file_try = model_path + ".vae.safetensors"
+ if os.path.isfile(vae_file_try):
+ vae_file = vae_file_try
+ print(f"Using VAE found similar to selected model: {vae_file}")
# No more fallbacks for auto
if vae_file == "auto":
vae_file = None
@@ -163,8 +172,9 @@ def load_vae(model, vae_file=None):
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}")
store_base_vae(model)
- vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+
+ vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
_load_vae_dict(model, vae_dict_1)
if cache_enabled:
@@ -195,10 +205,12 @@ def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.to(devices.dtype_vae)
+
def clear_loaded_vae():
global loaded_vae_file
loaded_vae_file = None
+
def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack
diff --git a/modules/shared.py b/modules/shared.py
index a6712dae..aa37c8ce 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
@@ -398,6 +399,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+ "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index fea7aaac..55052815 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -15,7 +15,8 @@ import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
-from typing import Optional, NamedTuple, Protocol, List
+from typing import Optional, NamedTuple, List
+
def narrow_trunc(
input: Tensor,
@@ -25,12 +26,14 @@ def narrow_trunc(
) -> Tensor:
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
+
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
-class SummarizeChunk(Protocol):
+
+class SummarizeChunk:
@staticmethod
def __call__(
query: Tensor,
@@ -38,7 +41,8 @@ class SummarizeChunk(Protocol):
value: Tensor,
) -> AttnChunk: ...
-class ComputeQueryChunkAttn(Protocol):
+
+class ComputeQueryChunkAttn:
@staticmethod
def __call__(
query: Tensor,
@@ -46,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol):
value: Tensor,
) -> Tensor: ...
+
def _summarize_chunk(
query: Tensor,
key: Tensor,
@@ -66,6 +71,7 @@ def _summarize_chunk(
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
def _query_chunk_attention(
query: Tensor,
key: Tensor,
@@ -106,6 +112,7 @@ def _query_chunk_attention(
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
+
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
@@ -125,10 +132,12 @@ def _get_attention_scores_no_kv_chunking(
hidden_states_slice = torch.bmm(attn_probs, value)
return hidden_states_slice
+
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
+
def efficient_dot_product_attention(
query: Tensor,
key: Tensor,
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 88d68c76..fa48708e 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -28,13 +28,11 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
- self.width = width
- self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
@@ -47,10 +45,10 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
+ assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
-
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
@@ -59,7 +57,9 @@ class PersonalizedBase(Dataset):
if shared.state.interrupted:
raise Exception("interrupted")
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ image = Image.open(path).convert('RGB')
+ if not varsize:
+ image = image.resize((width, height), PIL.Image.BICUBIC)
except Exception:
continue
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 45882ed6..5420903f 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -2,6 +2,7 @@ import os
import sys
import traceback
import inspect
+from collections import namedtuple
import torch
import tqdm
@@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
-from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
- insert_image_data_embed, extract_image_data_embed,
- caption_image_overlay)
+from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
from modules.textual_inversion.logging import save_settings_to_file
+TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
+textual_inversion_templates = {}
+
+
+def list_textual_inversion_templates():
+ textual_inversion_templates.clear()
+
+ for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for fn in fns:
+ path = os.path.join(root, fn)
+
+ textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
+
+ return textual_inversion_templates
+
+
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
@@ -66,17 +81,41 @@ class Embedding:
return self.cached_checksum
+class DirWithTextualInversionEmbeddings:
+ def __init__(self, path):
+ self.path = path
+ self.mtime = None
+
+ def has_changed(self):
+ if not os.path.isdir(self.path):
+ return False
+
+ mt = os.path.getmtime(self.path)
+ if self.mtime is None or mt > self.mtime:
+ return True
+
+ def update(self):
+ if not os.path.isdir(self.path):
+ return
+
+ self.mtime = os.path.getmtime(self.path)
+
+
class EmbeddingDatabase:
- def __init__(self, embeddings_dir):
+ def __init__(self):
self.ids_lookup = {}
self.word_embeddings = {}
self.skipped_embeddings = {}
- self.dir_mtime = None
- self.embeddings_dir = embeddings_dir
self.expected_shape = -1
+ self.embedding_dirs = {}
- def register_embedding(self, embedding, model):
+ def add_embedding_dir(self, path):
+ self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
+
+ def clear_embedding_dirs(self):
+ self.embedding_dirs.clear()
+ def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenize([embedding.name])[0]
@@ -93,65 +132,62 @@ class EmbeddingDatabase:
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
- def load_textual_inversion_embeddings(self, force_reload = False):
- mt = os.path.getmtime(self.embeddings_dir)
- if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
- return
+ def load_from_file(self, path, filename):
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- self.dir_mtime = mt
- self.ids_lookup.clear()
- self.word_embeddings.clear()
- self.skipped_embeddings.clear()
- self.expected_shape = self.get_expected_shape()
-
- def process_file(path, filename):
- name, ext = os.path.splitext(filename)
- ext = ext.upper()
-
- if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
- embed_image = Image.open(path)
- if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
- data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
- name = data.get('name', name)
- else:
- data = extract_image_data_embed(embed_image)
- name = data.get('name', name)
- elif ext in ['.BIN', '.PT']:
- data = torch.load(path, map_location="cpu")
- else:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ _, second_ext = os.path.splitext(name)
+ if second_ext.upper() == '.PREVIEW':
return
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
- else:
- raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- embedding = Embedding(vec, name)
- embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('sd_checkpoint', None)
- embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- embedding.vectors = vec.shape[0]
- embedding.shape = vec.shape[-1]
-
- if self.expected_shape == -1 or self.expected_shape == embedding.shape:
- self.register_embedding(embedding, shared.sd_model)
+ embed_image = Image.open(path)
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name', name)
else:
- self.skipped_embeddings[name] = embedding
+ data = extract_image_data_embed(embed_image)
+ name = data.get('name', name)
+ elif ext in ['.BIN', '.PT']:
+ data = torch.load(path, map_location="cpu")
+ else:
+ return
- for root, dirs, fns in os.walk(self.embeddings_dir):
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
+
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings[name] = embedding
+
+ def load_from_dir(self, embdir):
+ if not os.path.isdir(embdir.path):
+ return
+
+ for root, dirs, fns in os.walk(embdir.path):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
@@ -159,12 +195,32 @@ class EmbeddingDatabase:
if os.stat(fullfn).st_size == 0:
continue
- process_file(fullfn, fn)
+ self.load_from_file(fullfn, fn)
except Exception:
print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
+ def load_textual_inversion_embeddings(self, force_reload=False):
+ if not force_reload:
+ need_reload = False
+ for path, embdir in self.embedding_dirs.items():
+ if embdir.has_changed():
+ need_reload = True
+ break
+
+ if not need_reload:
+ return
+
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+ self.skipped_embeddings.clear()
+ self.expected_shape = self.get_expected_shape()
+
+ for path, embdir in self.embedding_dirs.items():
+ self.load_from_dir(embdir)
+ embdir.update()
+
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
@@ -233,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
-def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer"
@@ -243,22 +299,26 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
- assert template_file, "Prompt template file is empty"
- assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert template_filename, "Prompt template file not selected"
+ assert template_file, f"Prompt template file {template_filename} not found"
+ assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer"
- assert steps > 0 , "Max steps must be positive"
+ assert steps > 0, "Max steps must be positive"
assert isinstance(save_model_every, int), "Save {name} must be integer"
- assert save_model_every >= 0 , "Save {name} must be positive or 0"
+ assert save_model_every >= 0, "Save {name} must be positive or 0"
assert isinstance(create_image_every, int), "Create image must be integer"
- assert create_image_every >= 0 , "Create image must be positive or 0"
+ assert create_image_every >= 0, "Create image must be positive or 0"
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
- validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = textual_inversion_templates.get(template_filename, None)
+ validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = template_file.path
shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
@@ -309,7 +369,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
pin_memory = shared.opts.pin_memory
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
diff --git a/modules/ui.py b/modules/ui.py
index 99483130..b6079aec 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -37,7 +37,7 @@ from modules import prompt_parser
from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
-import modules.textual_inversion.ui
+from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
@@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
with devices.autocast():
p.init([""], [0], [0])
- return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"
+ return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
def apply_styles(prompt, prompt_neg, style1_name, style2_name):
@@ -745,15 +745,20 @@ def create_ui():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- hr_resolution_preview_args = dict(
- fn=calc_resolution_hires,
- inputs=hr_resolution_preview_inputs,
- outputs=[hr_final_resolution],
- show_progress=False
- )
-
for input in hr_resolution_preview_inputs:
- input.change(**hr_resolution_preview_args)
+ input.change(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False,
+ )
+ input.change(
+ None,
+ _js="onCalcResolutionHires",
+ inputs=hr_resolution_preview_inputs,
+ outputs=[],
+ show_progress=False,
+ )
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@@ -1317,6 +1322,9 @@ def create_ui():
outputs=[process_focal_crop_row],
)
+ def get_textual_inversion_template_names():
+ return sorted([x for x in textual_inversion.textual_inversion_templates])
+
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with FormRow():
@@ -1340,9 +1348,14 @@ def create_ui():
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
- template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
+
+ with FormRow():
+ template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
+ create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
+
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
with FormRow():
@@ -1449,6 +1462,7 @@ def create_ui():
log_directory,
training_width,
training_height,
+ varsize,
steps,
clip_grad_mode,
clip_grad_value,
@@ -1480,6 +1494,7 @@ def create_ui():
log_directory,
training_width,
training_height,
+ varsize,
steps,
clip_grad_mode,
clip_grad_value,