aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/devices.py71
-rw-r--r--modules/hashes.py4
-rw-r--r--modules/hypernetworks/hypernetwork.py2
-rw-r--r--modules/images.py23
-rw-r--r--modules/img2img.py3
-rw-r--r--modules/mac_specific.py53
-rw-r--r--modules/modelloader.py3
-rw-r--r--modules/processing.py7
-rw-r--r--modules/sd_disable_initialization.py17
-rw-r--r--modules/sd_models.py16
-rw-r--r--modules/sd_samplers_common.py16
-rw-r--r--modules/sd_samplers_kdiffusion.py53
-rw-r--r--modules/shared.py4
-rw-r--r--modules/ui.py16
-rw-r--r--modules/ui_extra_networks.py7
15 files changed, 173 insertions, 122 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 655ca1d3..52c3e7cd 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,21 +1,17 @@
-import sys, os, shlex
+import sys
import contextlib
import torch
from modules import errors
-from packaging import version
+
+if sys.platform == "darwin":
+ from modules import mac_specific
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
def has_mps() -> bool:
- if not getattr(torch, 'has_mps', False):
- return False
- try:
- torch.zeros(1).to(torch.device("mps"))
- return True
- except Exception:
+ if sys.platform != "darwin":
return False
-
+ else:
+ return mac_specific.has_mps
def extract_device_id(args, name):
for x in range(len(args)):
@@ -154,56 +150,3 @@ def test_for_nans(x, where):
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-orig_tensor_to = torch.Tensor.to
-def tensor_to_fix(self, *args, **kwargs):
- if self.device.type != 'mps' and \
- ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
- (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
- self = self.contiguous()
- return orig_tensor_to(self, *args, **kwargs)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
-orig_layer_norm = torch.nn.functional.layer_norm
-def layer_norm_fix(*args, **kwargs):
- if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
- args = list(args)
- args[0] = args[0].contiguous()
- return orig_layer_norm(*args, **kwargs)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
-orig_tensor_numpy = torch.Tensor.numpy
-def numpy_fix(self, *args, **kwargs):
- if self.requires_grad:
- self = self.detach()
- return orig_tensor_numpy(self, *args, **kwargs)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
-orig_cumsum = torch.cumsum
-orig_Tensor_cumsum = torch.Tensor.cumsum
-def cumsum_fix(input, cumsum_func, *args, **kwargs):
- if input.device.type == 'mps':
- output_dtype = kwargs.get('dtype', input.dtype)
- if output_dtype == torch.int64:
- return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
- elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
- return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
- return cumsum_func(input, *args, **kwargs)
-
-
-if has_mps():
- if version.parse(torch.__version__) < version.parse("1.13"):
- # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
- torch.Tensor.to = tensor_to_fix
- torch.nn.functional.layer_norm = layer_norm_fix
- torch.Tensor.numpy = numpy_fix
- elif version.parse(torch.__version__) > version.parse("1.13.1"):
- cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
- cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
- torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
- torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
diff --git a/modules/hashes.py b/modules/hashes.py
index 819362a3..83272a07 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -4,6 +4,7 @@ import os.path
import filelock
+from modules import shared
from modules.paths import data_path
@@ -68,6 +69,9 @@ def sha256(filename, title):
if sha256_value is not None:
return sha256_value
+ if shared.cmd_opts.no_hashing:
+ return None
+
print(f"Calculating sha256 for {filename}: ", end='')
sha256_value = calculate_sha256(filename)
print(f"{sha256_value}")
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 503534e2..825a93b2 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -307,7 +307,7 @@ class Hypernetwork:
def shorthash(self):
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
- return sha256[0:10]
+ return sha256[0:10] if sha256 else None
def list_hypernetworks(path):
diff --git a/modules/images.py b/modules/images.py
index ae3cdaf4..c2ca8849 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -16,6 +16,7 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
import json
+import hashlib
from modules import sd_samplers, shared, script_callbacks
from modules.shared import opts, cmd_opts
@@ -130,7 +131,7 @@ class GridAnnotation:
self.size = None
-def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
+def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
def wrap(drawing, text, font, line_length):
lines = ['']
for word in text.split():
@@ -194,32 +195,35 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
line.allowed_width = allowed_width
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
- ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
- ver_texts]
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
- result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
- result.paste(im, (pad_left, pad_top))
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+
+ for row in range(rows):
+ for col in range(cols):
+ cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
+ result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
d = ImageDraw.Draw(result)
for col in range(cols):
- x = pad_left + width * col + width / 2
+ x = pad_left + (width + margin) * col + width / 2
y = pad_top / 2 - hor_text_heights[col] / 2
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
for row in range(rows):
x = pad_left / 2
- y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
+ y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
return result
-def draw_prompt_matrix(im, width, height, all_prompts):
+def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
prompts = all_prompts[1:]
boundary = math.ceil(len(prompts) / 2)
@@ -229,7 +233,7 @@ def draw_prompt_matrix(im, width, height, all_prompts):
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
- return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
def resize_image(resize_mode, im, width, height, upscaler_name=None):
@@ -340,6 +344,7 @@ class FilenameGenerator:
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
+ 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
'prompt': lambda self: sanitize_filename_part(self.prompt),
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
diff --git a/modules/img2img.py b/modules/img2img.py
index f813299c..bcc158dc 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -76,7 +76,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -142,6 +142,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
inpainting_fill=inpainting_fill,
resize_mode=resize_mode,
denoising_strength=denoising_strength,
+ image_cfg_scale=image_cfg_scale,
inpaint_full_res=inpaint_full_res,
inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert,
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
new file mode 100644
index 00000000..ddcea53b
--- /dev/null
+++ b/modules/mac_specific.py
@@ -0,0 +1,53 @@
+import torch
+from modules import paths
+from modules.sd_hijack_utils import CondFunc
+from packaging import version
+
+
+# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
+# check `getattr` and try it for compatibility
+def check_for_mps() -> bool:
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+has_mps = check_for_mps()
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+ if input.device.type == 'mps':
+ output_dtype = kwargs.get('dtype', input.dtype)
+ if output_dtype == torch.int64:
+ return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+ elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+ return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
+ return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps:
+ # MPS fix for randn in torchsde
+ CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
+
+ if version.parse(torch.__version__) < version.parse("1.13"):
+ # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+ CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
+ lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
+ lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+ CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
+ elif version.parse(torch.__version__) > version.parse("1.13.1"):
+ cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
+ cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
+ cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
+ CondFunc('torch.cumsum', cumsum_fix_func, None)
+ CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
+ CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
+
diff --git a/modules/modelloader.py b/modules/modelloader.py
index e9aa514e..fc3f6249 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -45,6 +45,9 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
full_path = file
if os.path.isdir(full_path):
continue
+ if os.path.islink(full_path) and not os.path.exists(full_path):
+ print(f"Skipping broken symlink: {full_path}")
+ continue
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
continue
if len(ext_filter) != 0:
diff --git a/modules/processing.py b/modules/processing.py
index 49b1b4ea..61e6dfcd 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -186,7 +186,7 @@ class StableDiffusionProcessing:
return conditioning
def edit_image_conditioning(self, source_image):
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+ conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
return conditioning_image
@@ -268,6 +268,7 @@ class Processed:
self.height = p.height
self.sampler_name = p.sampler_name
self.cfg_scale = p.cfg_scale
+ self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.steps = p.steps
self.batch_size = p.batch_size
self.restore_faces = p.restore_faces
@@ -445,6 +446,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Steps": p.steps,
"Sampler": p.sampler_name,
"CFG scale": p.cfg_scale,
+ "Image CFG scale": getattr(p, 'image_cfg_scale', None),
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
@@ -961,12 +963,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
+ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
self.resize_mode: int = resize_mode
self.denoising_strength: float = denoising_strength
+ self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
self.init_latent = None
self.image_mask = mask
self.latent_mask = None
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index e90aa9fe..c4a09d15 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -20,8 +20,9 @@ class DisableInitialization:
```
"""
- def __init__(self):
+ def __init__(self, disable_clip=True):
self.replaced = []
+ self.disable_clip = disable_clip
def replace(self, obj, field, func):
original = getattr(obj, field, None)
@@ -75,12 +76,14 @@ class DisableInitialization:
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
- self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
- self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
- self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
- self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
- self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
- self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
+
+ if self.disable_clip:
+ self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
+ self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
+ self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
+ self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
+ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
+ self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
for obj, field, original in self.replaced:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 300387a9..d847d358 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -59,13 +59,17 @@ class CheckpointInfo:
def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
+ if self.sha256 is None:
+ return
+
self.shorthash = self.sha256[0:10]
if self.shorthash not in self.ids:
- self.ids += [self.shorthash, self.sha256]
- self.register()
+ self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
+ checkpoints_list.pop(self.title)
self.title = f'{self.name} [{self.shorthash}]'
+ self.register()
return self.shorthash
@@ -158,7 +162,7 @@ def select_checkpoint():
print(f" - directory {model_path}", file=sys.stderr)
if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
+ print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)
checkpoint_info = next(iter(checkpoints_list.values()))
@@ -350,6 +354,9 @@ def repair_config(sd_config):
sd_config.model.params.unet_config.params.use_fp16 = True
+sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
+sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -370,6 +377,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+ clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
timer.record("find config")
@@ -382,7 +390,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
sd_model = None
try:
- with sd_disable_initialization.DisableInitialization():
+ with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
except Exception as e:
pass
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 3c03d442..a1aac7cf 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,7 +2,6 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
-import torchsde._brownian.brownian_interval
from modules import devices, processing, images, sd_vae_approx
from modules.shared import opts, state
@@ -61,18 +60,3 @@ def store_latent(decoded):
class InterruptedException(BaseException):
pass
-
-
-# MPS fix for randn in torchsde
-# XXX move this to separate file for MPS
-def torchsde_randn(size, dtype, device, seed):
- if device.type == 'mps':
- generator = torch.Generator(devices.cpu).manual_seed(int(seed))
- return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
- else:
- generator = torch.Generator(device).manual_seed(int(seed))
- return torch.randn(size, dtype=dtype, device=device, generator=generator)
-
-
-torchsde._brownian.brownian_interval._randn = torchsde_randn
-
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index aa7f106b..f076fc55 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,6 +1,7 @@
from collections import deque
import torch
import inspect
+import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common
@@ -56,6 +57,7 @@ class CFGDenoiser(torch.nn.Module):
self.nmask = None
self.init_latent = None
self.step = 0
+ self.image_cfg_scale = None
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
@@ -67,19 +69,36 @@ class CFGDenoiser(torch.nn.Module):
return denoised
+ def combine_denoised_for_edit_model(self, x_out, cond_scale):
+ out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+ denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+
+ return denoised
+
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
+ # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
+ # so is_edit_model is set to False to support AND composition.
+ is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+ assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ if not is_edit_model:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
+ else:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params)
@@ -88,7 +107,10 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]:
- cond_in = torch.cat([tensor, uncond])
+ if not is_edit_model:
+ cond_in = torch.cat([tensor, uncond])
+ else:
+ cond_in = torch.cat([tensor, uncond, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
@@ -104,7 +126,13 @@ class CFGDenoiser(torch.nn.Module):
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
+
+ if not is_edit_model:
+ c_crossattn = [tensor[a:b]]
+ else:
+ c_crossattn = torch.cat([tensor[a:b]], uncond)
+
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
@@ -115,7 +143,10 @@ class CFGDenoiser(torch.nn.Module):
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ if not is_edit_model:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ else:
+ denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
@@ -198,6 +229,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
+ self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -260,13 +292,14 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
self.last_latent = x
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ 'cond_scale': p.cfg_scale,
+ }
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
diff --git a/modules/shared.py b/modules/shared.py
index 69634fd8..79fbf724 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -106,7 +106,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
-
+parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
script_loading.preload_extensions(extensions.extensions_dir, parser)
@@ -327,7 +327,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
- "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
+ "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
diff --git a/modules/ui.py b/modules/ui.py
index dca67fed..ffcf2769 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -479,8 +479,8 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
if opts.dimensions_and_batch_together:
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
with gr.Column(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
@@ -775,15 +775,17 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
if opts.dimensions_and_batch_together:
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
with gr.Column(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
elif category == "cfg":
with FormGroup():
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+ with FormRow():
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+ image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
elif category == "seed":
@@ -879,6 +881,7 @@ def create_ui():
batch_count,
batch_size,
cfg_scale,
+ image_cfg_scale,
denoising_strength,
seed,
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
@@ -965,6 +968,7 @@ def create_ui():
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"),
+ (image_cfg_scale, "Image CFG scale"),
(seed, "Seed"),
(width, "Size-1"),
(height, "Size-2"),
@@ -1609,6 +1613,12 @@ def create_ui():
outputs=[component, text_settings],
)
+ text_settings.change(
+ fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
+ inputs=[],
+ outputs=[image_cfg_scale],
+ )
+
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(
fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 83367968..90abec0a 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -26,11 +26,12 @@ def add_pages_to_demo(app):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
- if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
+ if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
- if os.path.splitext(filename)[1].lower() != ".png":
- raise ValueError(f"File cannot be fetched: {filename}. Only png.")
+ ext = os.path.splitext(filename)[1].lower()
+ if ext not in (".png", ".jpg"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})