aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py8
-rw-r--r--modules/api/models.py8
-rw-r--r--modules/devices.py6
-rw-r--r--modules/esrgan_model.py2
-rw-r--r--modules/hypernetworks/hypernetwork.py49
-rw-r--r--modules/hypernetworks/ui.py4
-rw-r--r--modules/images.py65
-rw-r--r--modules/img2img.py8
-rw-r--r--modules/processing.py25
-rw-r--r--modules/script_callbacks.py97
-rw-r--r--modules/scunet_model.py3
-rw-r--r--modules/sd_samplers.py4
-rw-r--r--modules/shared.py10
-rw-r--r--modules/swinir_model.py12
-rw-r--r--modules/textual_inversion/autocrop.py341
-rw-r--r--modules/textual_inversion/preprocess.py38
-rw-r--r--modules/textual_inversion/textual_inversion.py22
-rw-r--r--modules/ui.py22
18 files changed, 637 insertions, 87 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 67b783de..ca289d9f 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -54,7 +54,7 @@ class Api:
b64images = list(map(encode_pil_to_base64, processed.images))
- return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info)
+ return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
@@ -93,8 +93,12 @@ class Api:
processed = process_images(p)
b64images = list(map(encode_pil_to_base64, processed.images))
+
+ if (not img2imgreq.include_init_images):
+ img2imgreq.init_images = None
+ img2imgreq.mask = None
- return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info)
+ return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
reqDict = setUpscalers(req)
diff --git a/modules/api/models.py b/modules/api/models.py
index 66d15765..00406368 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -31,6 +31,7 @@ class ModelDef(BaseModel):
field_alias: str
field_type: Any
field_value: Any
+ field_exclude: bool = False
class PydanticModelGenerator:
@@ -78,7 +79,8 @@ class PydanticModelGenerator:
field=underscore(fields["key"]),
field_alias=fields["key"],
field_type=fields["type"],
- field_value=fields["default"]))
+ field_value=fields["default"],
+ field_exclude=fields["exclude"] if "exclude" in fields else False))
def generate_model(self):
"""
@@ -86,7 +88,7 @@ class PydanticModelGenerator:
from the json and overrides provided at initialization
"""
fields = {
- d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
}
DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True
@@ -102,7 +104,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()
class TextToImageResponse(BaseModel):
diff --git a/modules/devices.py b/modules/devices.py
index dc1f3cdd..7511e1dc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -45,7 +45,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
+device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
@@ -81,3 +81,7 @@ def autocast(disable=False):
return contextlib.nullcontext()
return torch.autocast("cuda")
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index a49e2258..a13cf6ac 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -190,7 +190,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index d647ea55..8113b35b 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -5,6 +5,7 @@ import html
import os
import sys
import traceback
+import inspect
import modules.textual_inversion.dataset
import torch
@@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
+from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
from collections import defaultdict, deque
from statistics import stdev, mean
+
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module):
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
+ "tanh": torch.nn.Tanh,
+ "sigmoid": torch.nn.Sigmoid,
}
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module):
else:
for layer in self.linear:
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
- layer.weight.data.normal_(mean=0.0, std=0.01)
- layer.bias.data.zero_()
-
+ w, b = layer.weight.data, layer.bias.data
+ if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
+ normal_(w, mean=0.0, std=0.01)
+ normal_(b, mean=0.0, std=0.005)
+ elif weight_init == 'XavierUniform':
+ xavier_uniform_(w)
+ zeros_(b)
+ elif weight_init == 'XavierNormal':
+ xavier_normal_(w)
+ zeros_(b)
+ elif weight_init == 'KaimingUniform':
+ kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+ zeros_(b)
+ elif weight_init == 'KaimingNormal':
+ kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+ zeros_(b)
+ else:
+ raise KeyError(f"Key {weight_init} is not defined as initialization!")
self.to(devices.device)
def fix_old_state_dict(self, state_dict):
@@ -105,7 +126,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
@@ -114,13 +135,14 @@ class Hypernetwork:
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.activation_func = activation_func
+ self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
)
def weights(self):
@@ -144,6 +166,7 @@ class Hypernetwork:
state_dict['layer_structure'] = self.layer_structure
state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['weight_initialization'] = self.weight_init
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -158,15 +181,21 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ print(self.layer_structure)
self.activation_func = state_dict.get('activation_func', None)
+ print(f"Activation function is {self.activation_func}")
+ self.weight_init = state_dict.get('weight_initialization', 'Normal')
+ print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False)
+ print(f"Dropout usage is set to {self.use_dropout}" )
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
)
self.name = state_dict.get('name', self.name)
@@ -458,7 +487,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None:
shared.state.current_image = image
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 2b472d87..2c6c0470 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -8,8 +8,9 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
+keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
activation_func=activation_func,
+ weight_init=weight_init,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
)
diff --git a/modules/images.py b/modules/images.py
index 848ede75..7870b5b7 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -16,7 +16,7 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
-from modules import sd_samplers, shared
+from modules import sd_samplers, shared, script_callbacks
from modules.shared import opts, cmd_opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -277,7 +277,7 @@ invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
-re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)")
+re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
@@ -343,8 +343,11 @@ class FilenameGenerator:
def datetime(self, *args):
time_datetime = datetime.datetime.now()
- time_format = args[0] if len(args) > 0 else self.default_time_format
- time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
+ time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
+ try:
+ time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
+ except pytz.exceptions.UnknownTimeZoneError as _:
+ time_zone = None
time_zone_time = time_datetime.astimezone(time_zone)
try:
@@ -359,9 +362,9 @@ class FilenameGenerator:
for m in re_pattern.finditer(x):
text, pattern = m.groups()
+ res += text
if pattern is None:
- res += text
continue
pattern_args = []
@@ -382,12 +385,9 @@ class FilenameGenerator:
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- if replacement is None:
- res += f'[{pattern}]'
- else:
+ if replacement is not None:
res += str(replacement)
-
- continue
+ continue
res += f'[{pattern}]'
@@ -451,17 +451,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
"""
namegen = FilenameGenerator(p, seed, prompt)
- if extension == 'png' and opts.enable_pnginfo and info is not None:
- pnginfo = PngImagePlugin.PngInfo()
-
- if existing_info is not None:
- for k, v in existing_info.items():
- pnginfo.add_text(k, str(v))
-
- pnginfo.add_text(pnginfo_section_name, info)
- else:
- pnginfo = None
-
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
@@ -474,8 +463,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if forced_filename is None:
if short_filename or seed is None:
file_decoration = ""
- else:
+ elif opts.save_to_dirs:
file_decoration = opts.samples_filename_pattern or "[seed]"
+ else:
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
add_number = opts.save_images_add_number or file_decoration == ''
@@ -487,19 +478,27 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if add_number:
basecount = get_next_sequence_number(path, basename)
fullfn = None
- fullfn_without_extension = None
for i in range(500):
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
- fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
if not os.path.exists(fullfn):
break
else:
fullfn = os.path.join(path, f"{file_decoration}.{extension}")
- fullfn_without_extension = os.path.join(path, file_decoration)
else:
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
- fullfn_without_extension = os.path.join(path, forced_filename)
+
+ pnginfo = existing_info or {}
+ if info is not None:
+ pnginfo[pnginfo_section_name] = info
+
+ params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
+ script_callbacks.before_image_saved_callback(params)
+
+ image = params.image
+ fullfn = params.filename
+ info = params.pnginfo.get(pnginfo_section_name, None)
+ fullfn_without_extension, extension = os.path.splitext(params.filename)
def exif_bytes():
return piexif.dump({
@@ -508,12 +507,20 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
},
})
- if extension.lower() in ("jpg", "jpeg", "webp"):
+ if extension.lower() == '.png':
+ pnginfo_data = PngImagePlugin.PngInfo()
+ for k, v in params.pnginfo.items():
+ pnginfo_data.add_text(k, str(v))
+
+ image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
+
+ elif extension.lower() in (".jpg", ".jpeg", ".webp"):
image.save(fullfn, quality=opts.jpeg_quality)
+
if opts.enable_pnginfo and info is not None:
piexif.insert(exif_bytes(), fullfn)
else:
- image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
+ image.save(fullfn, quality=opts.jpeg_quality)
target_side_length = 4000
oversize = image.width > target_side_length or image.height > target_side_length
@@ -536,6 +543,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
txt_fullfn = None
+ script_callbacks.image_saved_callback(params)
+
return fullfn, txt_fullfn
diff --git a/modules/img2img.py b/modules/img2img.py
index 8d9f7cf9..9c0cf23e 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -39,6 +39,8 @@ def process_batch(p, input_dir, output_dir, args):
break
img = Image.open(image)
+ # Use the EXIF orientation of photos taken by smartphones.
+ img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
proc = modules.scripts.scripts_img2img.run(p, *args)
@@ -61,19 +63,25 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
is_batch = mode == 2
if is_inpaint:
+ # Drawn mask
if mask_mode == 0:
image = init_img_with_mask['image']
mask = init_img_with_mask['mask']
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
image = image.convert('RGB')
+ # Uploaded mask
else:
image = init_img_inpaint
mask = init_mask_inpaint
+ # No mask
else:
image = init_img
mask = None
+ # Use the EXIF orientation of photos taken by smartphones.
+ image = ImageOps.exif_transpose(image)
+
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img(
diff --git a/modules/processing.py b/modules/processing.py
index c61bbfbd..4efba946 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -77,9 +77,8 @@ def get_correct_sampler(p):
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
-
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -109,13 +108,14 @@ class StableDiffusionProcessing():
self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None
self.color_corrections = None
- self.denoising_strength: float = 0
+ self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None
- self.ddim_discretize = opts.ddim_discretize
+ self.ddim_discretize = ddim_discretize or opts.ddim_discretize
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
+ self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
if not seed_enable_extras:
self.subseed = -1
@@ -129,7 +129,6 @@ class StableDiffusionProcessing():
self.all_seeds = None
self.all_subseeds = None
-
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -351,6 +350,22 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
def process_images(p: StableDiffusionProcessing) -> Processed:
+ stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
+
+ try:
+ for k, v in p.override_settings.items():
+ opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
+
+ res = process_images_inner(p)
+
+ finally:
+ for k, v in stored_opts.items():
+ opts.data[k] = v
+
+ return res
+
+
+def process_images_inner(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if type(p.prompt) == list:
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index f46d3d9a..6ea58d61 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,37 +1,100 @@
+import sys
+import traceback
+from collections import namedtuple
+import inspect
+
+def report_exception(c, job):
+ print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+
+class ImageSaveParams:
+ def __init__(self, image, p, filename, pnginfo):
+ self.image = image
+ """the PIL image itself"""
+
+ self.p = p
+ """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
+
+ self.filename = filename
+ """name of file that the image would be saved to"""
+
+ self.pnginfo = pnginfo
+ """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
+
+
+ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
+callbacks_before_image_saved = []
+callbacks_image_saved = []
def clear_callbacks():
callbacks_model_loaded.clear()
callbacks_ui_tabs.clear()
+ callbacks_ui_settings.clear()
+ callbacks_before_image_saved.clear()
+ callbacks_image_saved.clear()
def model_loaded_callback(sd_model):
- for callback in callbacks_model_loaded:
- callback(sd_model)
+ for c in callbacks_model_loaded:
+ try:
+ c.callback(sd_model)
+ except Exception:
+ report_exception(c, 'model_loaded_callback')
def ui_tabs_callback():
res = []
- for callback in callbacks_ui_tabs:
- res += callback() or []
+ for c in callbacks_ui_tabs:
+ try:
+ res += c.callback() or []
+ except Exception:
+ report_exception(c, 'ui_tabs_callback')
return res
def ui_settings_callback():
- for callback in callbacks_ui_settings:
- callback()
+ for c in callbacks_ui_settings:
+ try:
+ c.callback()
+ except Exception:
+ report_exception(c, 'ui_settings_callback')
+
+
+def before_image_saved_callback(params: ImageSaveParams):
+ for c in callbacks_image_saved:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'before_image_saved_callback')
+
+
+def image_saved_callback(params: ImageSaveParams):
+ for c in callbacks_image_saved:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'image_saved_callback')
+
+
+def add_callback(callbacks, fun):
+ stack = [x for x in inspect.stack() if x.filename != __file__]
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+
+ callbacks.append(ScriptCallback(filename, fun))
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
- callbacks_model_loaded.append(callback)
+ add_callback(callbacks_model_loaded, callback)
def on_ui_tabs(callback):
@@ -44,10 +107,26 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
- callbacks_ui_tabs.append(callback)
+ add_callback(callbacks_ui_tabs, callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
- callbacks_ui_settings.append(callback)
+ add_callback(callbacks_ui_settings, callback)
+
+
+def on_before_image_saved(callback):
+ """register a function to be called before an image is saved to a file.
+ The callback is called with one argument:
+ - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
+ """
+ add_callback(callbacks_before_image_saved, callback)
+
+
+def on_image_saved(callback):
+ """register a function to be called after an image is saved to a file.
+ The callback is called with one argument:
+ - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
+ """
+ add_callback(callbacks_image_saved, callback)
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
index 36a996bf..59532274 100644
--- a/modules/scunet_model.py
+++ b/modules/scunet_model.py
@@ -54,9 +54,8 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), device)
- img = img.to(device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 0b408a70..3670b57d 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -228,7 +228,7 @@ class VanillaStableDiffusionSampler:
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+ samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
@@ -429,7 +429,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
diff --git a/modules/shared.py b/modules/shared.py
index 76cbb1bd..1a9b8289 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -58,7 +58,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -84,7 +84,7 @@ parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load mod
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
cmd_opts = parser.parse_args()
-restricted_opts = [
+restricted_opts = {
"samples_filename_pattern",
"directories_filename_pattern",
"outdir_samples",
@@ -94,10 +94,10 @@ restricted_opts = [
"outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
-]
+}
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index baa02e3d..4253b66d 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -7,8 +7,8 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
-from modules import modelloader
-from modules.shared import cmd_opts, opts, device
+from modules import modelloader, devices
+from modules.shared import cmd_opts, opts
from modules.swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -42,7 +42,7 @@ class UpscalerSwinIR(Upscaler):
model = self.load_model(model_file)
if model is None:
return img
- model = model.to(device)
+ model = model.to(devices.device_swinir)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
@@ -111,7 +111,7 @@ def upscale(
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
@@ -139,8 +139,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
- W = torch.zeros_like(E, dtype=torch.half, device=device)
+ E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img)
+ W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
new file mode 100644
index 00000000..9859974a
--- /dev/null
+++ b/modules/textual_inversion/autocrop.py
@@ -0,0 +1,341 @@
+import cv2
+import requests
+import os
+from collections import defaultdict
+from math import log, sqrt
+import numpy as np
+from PIL import Image, ImageDraw
+
+GREEN = "#0F0"
+BLUE = "#00F"
+RED = "#F00"
+
+
+def crop_image(im, settings):
+ """ Intelligently crop an image to the subject matter """
+
+ scale_by = 1
+ if is_landscape(im.width, im.height):
+ scale_by = settings.crop_height / im.height
+ elif is_portrait(im.width, im.height):
+ scale_by = settings.crop_width / im.width
+ elif is_square(im.width, im.height):
+ if is_square(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_landscape(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_portrait(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_height / im.height
+
+ im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+ im_debug = im.copy()
+
+ focus = focal_point(im_debug, settings)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
+
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
+
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
+
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
+
+ crop = [x1, y1, x2, y2]
+
+ results = []
+
+ results.append(im.crop(tuple(crop)))
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im_debug)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ results.append(im_debug)
+ if settings.destop_view_image:
+ im_debug.show()
+
+ return results
+
+def focal_point(im, settings):
+ corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
+ entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
+ face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
+
+ pois = []
+
+ weight_pref_total = 0
+ if len(corner_points) > 0:
+ weight_pref_total += settings.corner_points_weight
+ if len(entropy_points) > 0:
+ weight_pref_total += settings.entropy_points_weight
+ if len(face_points) > 0:
+ weight_pref_total += settings.face_points_weight
+
+ corner_centroid = None
+ if len(corner_points) > 0:
+ corner_centroid = centroid(corner_points)
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
+ pois.append(corner_centroid)
+
+ entropy_centroid = None
+ if len(entropy_points) > 0:
+ entropy_centroid = centroid(entropy_points)
+ entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
+ pois.append(entropy_centroid)
+
+ face_centroid = None
+ if len(face_points) > 0:
+ face_centroid = centroid(face_points)
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
+ pois.append(face_centroid)
+
+ average_point = poi_average(pois, settings)
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ max_size = min(im.width, im.height) * 0.07
+ if corner_centroid is not None:
+ color = BLUE
+ box = corner_centroid.bounding(max_size * corner_centroid.weight)
+ d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(corner_points) > 1:
+ for f in corner_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if entropy_centroid is not None:
+ color = "#ff0"
+ box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
+ d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(entropy_points) > 1:
+ for f in entropy_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if face_centroid is not None:
+ color = RED
+ box = face_centroid.bounding(max_size * face_centroid.weight)
+ d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(face_points) > 1:
+ for f in face_points:
+ d.rectangle(f.bounding(4), outline=color)
+
+ d.ellipse(average_point.bounding(max_size), outline=GREEN)
+
+ return average_point
+
+
+def image_face_points(im, settings):
+ if settings.dnn_model_path is not None:
+ detector = cv2.FaceDetectorYN.create(
+ settings.dnn_model_path,
+ "",
+ (im.width, im.height),
+ 0.9, # score threshold
+ 0.3, # nms threshold
+ 5000 # keep top k before nms
+ )
+ faces = detector.detect(np.array(im))
+ results = []
+ if faces[1] is not None:
+ for face in faces[1]:
+ x = face[0]
+ y = face[1]
+ w = face[2]
+ h = face[3]
+ results.append(
+ PointOfInterest(
+ int(x + (w * 0.5)), # face focus left/right is center
+ int(y + (h * 0.33)), # face focus up/down is close to the top of the head
+ size = w,
+ weight = 1/len(faces[1])
+ )
+ )
+ return results
+ else:
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+
+ tries = [
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+ ]
+ for t in tries:
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+ except:
+ continue
+
+ if len(faces) > 0:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
+ return []
+
+
+def image_corner_points(im, settings):
+ grayscale = im.convert("L")
+
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
+ gd = ImageDraw.Draw(grayscale)
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+ np_im = np.array(grayscale)
+
+ points = cv2.goodFeaturesToTrack(
+ np_im,
+ maxCorners=100,
+ qualityLevel=0.04,
+ minDistance=min(grayscale.width, grayscale.height)*0.06,
+ useHarrisDetector=False,
+ )
+
+ if points is None:
+ return []
+
+ focal_points = []
+ for point in points:
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
+
+ return focal_points
+
+
+def image_entropy_points(im, settings):
+ landscape = im.height < im.width
+ portrait = im.height > im.width
+ if landscape:
+ move_idx = [0, 2]
+ move_max = im.size[0]
+ elif portrait:
+ move_idx = [1, 3]
+ move_max = im.size[1]
+ else:
+ return []
+
+ e_max = 0
+ crop_current = [0, 0, settings.crop_width, settings.crop_height]
+ crop_best = crop_current
+ while crop_current[move_idx[1]] < move_max:
+ crop = im.crop(tuple(crop_current))
+ e = image_entropy(crop)
+
+ if (e > e_max):
+ e_max = e
+ crop_best = list(crop_current)
+
+ crop_current[move_idx[0]] += 4
+ crop_current[move_idx[1]] += 4
+
+ x_mid = int(crop_best[0] + settings.crop_width/2)
+ y_mid = int(crop_best[1] + settings.crop_height/2)
+
+ return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
+
+
+def image_entropy(im):
+ # greyscale image entropy
+ # band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
+ hist, _ = np.histogram(band, bins=range(0, 256))
+ hist = hist[hist > 0]
+ return -np.log2(hist / hist.sum()).sum()
+
+def centroid(pois):
+ x = [poi.x for poi in pois]
+ y = [poi.y for poi in pois]
+ return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+
+
+def poi_average(pois, settings):
+ weight = 0.0
+ x = 0.0
+ y = 0.0
+ for poi in pois:
+ weight += poi.weight
+ x += poi.x * poi.weight
+ y += poi.y * poi.weight
+ avg_x = round(x / weight)
+ avg_y = round(y / weight)
+
+ return PointOfInterest(avg_x, avg_y)
+
+
+def is_landscape(w, h):
+ return w > h
+
+
+def is_portrait(w, h):
+ return h > w
+
+
+def is_square(w, h):
+ return w == h
+
+
+def download_and_cache_models(dirname):
+ download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
+ model_file_name = 'face_detection_yunet.onnx'
+
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ cache_file = os.path.join(dirname, model_file_name)
+ if not os.path.exists(cache_file):
+ print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
+ response = requests.get(download_url)
+ with open(cache_file, "wb") as f:
+ f.write(response.content)
+
+ if os.path.exists(cache_file):
+ return cache_file
+ return None
+
+
+class PointOfInterest:
+ def __init__(self, x, y, weight=1.0, size=10):
+ self.x = x
+ self.y = y
+ self.weight = weight
+ self.size = size
+
+ def bounding(self, size):
+ return [
+ self.x - size//2,
+ self.y - size//2,
+ self.x + size//2,
+ self.y + size//2
+ ]
+
+
+class Settings:
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
+ self.annotate_image = annotate_image
+ self.destop_view_image = False
+ self.dnn_model_path = dnn_model_path \ No newline at end of file
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 33eaddb6..e13b1894 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -7,12 +7,14 @@ import tqdm
import time
from modules import shared, images
+from modules.paths import models_path
from modules.shared import opts, cmd_opts
+from modules.textual_inversion import autocrop
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
try:
if process_caption:
shared.interrogator.load()
@@ -22,7 +24,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
finally:
@@ -34,7 +36,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -113,6 +115,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
+
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
@@ -137,11 +140,36 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
ratio = (img.height * width) / (img.width * height)
inverse_xy = True
+ process_default_resize = True
+
if process_split and ratio < 1.0 and ratio <= split_threshold:
for splitted in split_pic(img, inverse_xy):
save_pic(splitted, index, existing_caption=existing_caption)
- else:
+ process_default_resize = False
+
+ if process_focal_crop and img.height != img.width:
+
+ dnn_model_path = None
+ try:
+ dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
+ except Exception as e:
+ print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
+
+ autocrop_settings = autocrop.Settings(
+ crop_width = width,
+ crop_height = height,
+ face_points_weight = process_focal_crop_face_weight,
+ entropy_points_weight = process_focal_crop_entropy_weight,
+ corner_points_weight = process_focal_crop_edges_weight,
+ annotate_image = process_focal_crop_debug,
+ dnn_model_path = dnn_model_path,
+ )
+ for focal in autocrop.crop_image(img, autocrop_settings):
+ save_pic(focal, index, existing_caption=existing_caption)
+ process_default_resize = False
+
+ if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, existing_caption=existing_caption)
- shared.state.nextjob()
+ shared.state.nextjob() \ No newline at end of file
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 529ed3e2..4fcebe74 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -10,7 +10,7 @@ import csv
from PIL import Image, PngImagePlugin
-from modules import shared, devices, sd_hijack, processing, sd_models
+from modules import shared, devices, sd_hijack, processing, sd_models, images
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
+ with devices.autocast():
+ cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
+
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
@@ -164,6 +167,8 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
@@ -244,6 +249,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
last_saved_file = "<none>"
last_saved_image = "<none>"
+ forced_filename = "<none>"
embedding_yet_to_be_embedded = False
ititial_step = embedding.step or 0
@@ -283,7 +289,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
- last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+ # Before saving, change name to match current checkpoint.
+ embedding.name = f'{embedding_name}-{embedding.step}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
@@ -293,8 +301,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
-
+ forced_filename = f'{embedding_name}-{embedding.step}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
@@ -350,8 +358,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
- image.save(last_saved_image)
-
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
@@ -371,6 +378,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
embedding.sd_checkpoint = checkpoint.hash
embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
+ # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
+ embedding.name = embedding_name
+ filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt')
embedding.save(filename)
return embedding, filename
diff --git a/modules/ui.py b/modules/ui.py
index 03528968..0a63e357 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1238,7 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
@@ -1260,6 +1261,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images')
+ process_focal_crop = gr.Checkbox(label='Auto focal point crop')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
@@ -1267,6 +1269,12 @@ def create_ui(wrap_gradio_gpu_call):
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
+ with gr.Row(visible=False) as process_focal_crop_row:
+ process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05)
+ process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05)
+ process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
+ process_focal_crop_debug = gr.Checkbox(label='Create debug image')
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1280,6 +1288,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[process_split_extra_row],
)
+ process_focal_crop.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_focal_crop],
+ outputs=[process_focal_crop_row],
+ )
+
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row():
@@ -1342,6 +1356,7 @@ def create_ui(wrap_gradio_gpu_call):
overwrite_old_hypernetwork,
new_hypernetwork_layer_structure,
new_hypernetwork_activation_func,
+ new_hypernetwork_initialization_option,
new_hypernetwork_add_layer_norm,
new_hypernetwork_use_dropout
],
@@ -1367,6 +1382,11 @@ def create_ui(wrap_gradio_gpu_call):
process_caption_deepbooru,
process_split_threshold,
process_overlap_ratio,
+ process_focal_crop,
+ process_focal_crop_face_weight,
+ process_focal_crop_entropy_weight,
+ process_focal_crop_edges_weight,
+ process_focal_crop_debug,
],
outputs=[
ti_output,