From f64af77adcd20fabe00e1e642512db9c6742ed23 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 23 Jan 2023 22:49:20 -0500 Subject: Fix different first gen with Approx NN previews The loading of the model for approx nn live previews can change the internal state of PyTorch, resulting in a different image. This can be avoided by preloading the approx nn model in advance. --- modules/processing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index bc541e2f..3bd590ba 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -568,6 +568,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) + if shared.opts.live_previews_enable and sd_samplers.approximation_indexes.get(shared.opts.show_progress_type, 0) == 1: + # preload approx nn model before sampling for a more deterministic result + sd_vae_approx.model() + if not p.disable_extra_networks: extra_networks.activate(p, extra_network_data) -- cgit v1.2.1 From e46bfa5a9e9b489ae925a9c23880e34fe8d9fffa Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:24:32 -0500 Subject: handling sub grids and merging into one --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 3b1c5f34..0bc3d524 100644 --- a/modules/images.py +++ b/modules/images.py @@ -195,7 +195,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts] - pad_top = max(hor_text_heights) + line_spacing * 2 + pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2 result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white") result.paste(im, (pad_left, pad_top)) -- cgit v1.2.1 From 28189985e6f56dc725938a3f0e4d2462dad74bc5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 20:24:27 +0300 Subject: remove fairscale requirement, add fake fairscale to make BLIP not complain about it --- modules/interrogate.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index 236e6983..9f063197 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -82,9 +82,16 @@ class InterrogateModels: return self.loaded_categories + def create_fake_fairscale(self): + class FakeFairscale: + def checkpoint_wrapper(self): + pass + + sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale + def load_blip_model(self): - with paths.Prioritize("BLIP"): - import models.blip + create_fake_fairscale() + import models.blip files = modelloader.load_models( model_path=os.path.join(paths.models_path, "BLIP"), -- cgit v1.2.1 From 5228ec8bdada50a8d614573e980193ca89192361 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 20:30:43 +0300 Subject: remove fairscale requirement, add fake fairscale to make BLIP not complain about it mk2 --- modules/interrogate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index 9f063197..c72ff694 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -90,7 +90,7 @@ class InterrogateModels: sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale def load_blip_model(self): - create_fake_fairscale() + self.create_fake_fairscale() import models.blip files = modelloader.load_models( -- cgit v1.2.1 From 84d9ce30cb427759547bc7876ed80ab91787d175 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 24 Jan 2023 23:51:45 -0500 Subject: Add option for float32 sampling with float16 UNet This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS) so in sd_models.py model.depth_model is removed for model.half(). --- modules/deepbooru_model.py | 4 +++- modules/devices.py | 2 ++ modules/processing.py | 15 ++++++++------- modules/sd_hijack_unet.py | 29 +++++++++++++++++++++++++++++ modules/sd_hijack_utils.py | 28 ++++++++++++++++++++++++++++ modules/sd_models.py | 10 ++++++++++ modules/shared.py | 1 + 7 files changed, 81 insertions(+), 8 deletions(-) create mode 100644 modules/sd_hijack_utils.py (limited to 'modules') diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py index edd40c81..83d2ff09 100644 --- a/modules/deepbooru_model.py +++ b/modules/deepbooru_model.py @@ -2,6 +2,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +from modules import devices + # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more @@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module): t_358, = inputs t_359 = t_358.permute(*[0, 3, 1, 2]) t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0) - t_360 = self.n_Conv_0(t_359_padded) + t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded) t_361 = F.relu(t_360) t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf')) t_362 = self.n_MaxPool_0(t_361) diff --git a/modules/devices.py b/modules/devices.py index 524ec7af..0981ef80 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -79,6 +79,8 @@ cpu = torch.device("cpu") device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 +dtype_unet = torch.float16 +unet_needs_upcast = False def randn(seed, shape): diff --git a/modules/processing.py b/modules/processing.py index bc541e2f..2d186ba0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,7 +172,8 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image)) + conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -203,7 +204,7 @@ class StableDiffusionProcessing: # Create another latent image, this time with a masked version of the original input. # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter. - conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype) + conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype) conditioning_image = torch.lerp( source_image, source_image * (1.0 - conditioning_mask), @@ -211,7 +212,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -225,10 +226,10 @@ class StableDiffusionProcessing: # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely # identify itself with a field common to all models. The conditioning_key is also hybrid. if isinstance(self.sd_model, LatentDepth2ImageDiffusion): - return self.depth2img_image_conditioning(source_image) + return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) @@ -610,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(): + with devices.autocast(disable=devices.unet_needs_upcast): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] @@ -988,7 +989,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. - image = image.to(shared.device) + image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 18daf8c1..88c94e54 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,4 +1,8 @@ import torch +from packaging import version + +from modules import devices +from modules.sd_hijack_utils import CondFunc class TorchHijackForUnet: @@ -28,3 +32,28 @@ class TorchHijackForUnet: th = TorchHijackForUnet() + + +# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling +def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): + for y in cond.keys(): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + with devices.autocast(): + return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + +class GELUHijack(torch.nn.GELU, torch.nn.Module): + def __init__(self, *args, **kwargs): + torch.nn.GELU.__init__(self, *args, **kwargs) + def forward(self, x): + if devices.unet_needs_upcast: + return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) + else: + return torch.nn.GELU.forward(self, x) + +unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast) +if version.parse(torch.__version__) <= version.parse("1.13.1"): + CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) + CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) + CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py new file mode 100644 index 00000000..f81b169a --- /dev/null +++ b/modules/sd_hijack_utils.py @@ -0,0 +1,28 @@ +import importlib + +class CondFunc: + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-2, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..7c98991a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): if not shared.cmd_opts.no_half: vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 if shared.cmd_opts.no_half_vae: model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None model.half() model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) @@ -372,6 +380,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True timer = Timer() diff --git a/modules/shared.py b/modules/shared.py index 5f713bee..4ce1209b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -45,6 +45,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") -- cgit v1.2.1 From e3b53fd295aca784253dfc8668ec87b537a72f43 Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 25 Jan 2023 00:23:10 -0500 Subject: Add UI setting for upcasting attention to float32 Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers. In order to make upcasting cross attention layer optimizations possible it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast also. --- modules/devices.py | 6 +- modules/processing.py | 2 +- modules/sd_hijack_optimizations.py | 159 +++++++++++++++++++++++-------------- modules/shared.py | 1 + modules/sub_quadratic_attention.py | 4 +- 5 files changed, 108 insertions(+), 64 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 0981ef80..6b36622c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -108,6 +108,10 @@ def autocast(disable=False): return torch.autocast("cuda") +def without_autocast(disable=False): + return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + + class NansException(Exception): pass @@ -125,7 +129,7 @@ def test_for_nans(x, where): message = "A tensor with all NaNs was produced in Unet." if not shared.cmd_opts.no_half: - message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this." + message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this." elif where == "vae": message = "A tensor with all NaNs was produced in VAE." diff --git a/modules/processing.py b/modules/processing.py index 2d186ba0..a850082d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -611,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(disable=devices.unet_needs_upcast): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 74452709..c02d954c 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, errors +from modules import shared, errors, devices from modules.hypernetworks import hypernetwork from .sub_quadratic_attention import efficient_dot_product_attention @@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() - s2 = s1.softmax(dim=-1) - del s1 + with devices.without_autocast(disable=not shared.opts.upcast_attn): + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + del q, k, v - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None): k_in = self.to_k(context_k) v_in = self.to_v(context_v) - k_in *= self.scale - - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - mem_free_total = get_available_vram() - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + dtype = q_in.dtype + if shared.opts.upcast_attn: + q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k_in = k_in * self.scale + + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + mem_free_total = get_available_vram() + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): context = default(context, x) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k = self.to_k(context_k) * self.scale + k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r = einsum_op(q, k, v) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k = k * self.scale + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + r = r.to(dtype) return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) # -- End of code from https://github.com/invoke-ai/InvokeAI -- @@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = x.to(dtype) + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) out_proj, dropout = self.to_out @@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ query_chunk_size = q_tokens kv_chunk_size = k_tokens - return efficient_dot_product_attention( - q, - k, - v, - query_chunk_size=q_chunk_size, - kv_chunk_size=kv_chunk_size, - kv_chunk_size_min = kv_chunk_size_min, - use_checkpoint=use_checkpoint, - ) + with devices.without_autocast(disable=q.dtype == v.dtype): + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) def get_xformers_flash_attention_op(q, k, v): @@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) + out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x): v = self.v(h_) b, c, h, w = q.shape q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() q = q.contiguous() k = k.contiguous() v = v.contiguous() out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 4ce1209b..6a0b96cb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -410,6 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 55052815..05595323 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -67,7 +67,7 @@ def _summarize_chunk( max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() exp_weights = torch.exp(attn_weights - max_score) - exp_values = torch.bmm(exp_weights, value) + exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype) max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) @@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking( ) attn_probs = attn_scores.softmax(dim=-1) del attn_scores - hidden_states_slice = torch.bmm(attn_probs, value) + hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype) return hidden_states_slice -- cgit v1.2.1 From ee0a0da3244123cb6d2ba4097a54a1e9caccb687 Mon Sep 17 00:00:00 2001 From: Kyle Date: Wed, 25 Jan 2023 08:53:23 -0500 Subject: Add instruct-pix2pix hijack Allows loading instruct-pix2pix models via same method as inpainting models in sd_models.py and sd_hijack_ip2p.py Adds ddpm_edit.py necessary for instruct-pix2pix --- modules/models/diffusion/ddpm_edit.py | 1459 +++++++++++++++++++++++++++++++++ modules/sd_hijack_ip2p.py | 13 + modules/sd_models.py | 12 +- 3 files changed, 1483 insertions(+), 1 deletion(-) create mode 100644 modules/models/diffusion/ddpm_edit.py create mode 100644 modules/sd_hijack_ip2p.py (limited to 'modules') diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py new file mode 100644 index 00000000..f3d49c44 --- /dev/null +++ b/modules/models/diffusion/ddpm_edit.py @@ -0,0 +1,1459 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). +# See more details in LICENSE. + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + load_ema=True, + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + + if self.use_ema and load_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + # If initialing from EMA-only checkpoint, create EMA model after loading. + if self.use_ema and not load_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + + # Our model adds additional channels to the first layer to condition on an input image. + # For the first layer, copy existing channel weights and initialize new channel weights to zero. + input_keys = [ + "model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight", + ] + + self_sd = self.state_dict() + for input_key in input_keys: + if input_key not in sd or input_key not in self_sd: + continue + + input_weight = self_sd[input_key] + + if input_weight.size() != sd[input_key].size(): + print(f"Manual init: {input_key}") + input_weight.zero_() + input_weight[:, :4, :, :].copy_(sd[input_key]) + ignore_keys.append(input_key) + + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + return batch[k] + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + load_ema=True, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + if self.use_ema and not load_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, uncond=0.05): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + cond_key = cond_key or self.cond_stage_key + xc = super().get_input(batch, cond_key) + if bs is not None: + xc["c_crossattn"] = xc["c_crossattn"][:bs] + xc["c_concat"] = xc["c_concat"][:bs] + cond = {} + + # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%. + random = torch.rand(x.size(0), device=x.device) + prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1") + input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1") + + null_prompt = self.get_learned_conditioning([""]) + cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())] + cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()] + + out = [z, cond] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False, + plot_diffusion_rows=False, **kwargs): + + use_ddim = False + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, uncond=0) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reals"] = xc["c_concat"] + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py new file mode 100644 index 00000000..635f015f --- /dev/null +++ b/modules/sd_hijack_ip2p.py @@ -0,0 +1,13 @@ +import collections +import os.path +import sys +import gc +import time + +def should_hijack_ip2p(checkpoint_info): + from modules import sd_models + + ckpt_basename = os.path.basename(checkpoint_info.filename).lower() + cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() + + return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..cddc2343 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting +from modules.sd_hijack_ip2p import should_hijack_ip2p model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -365,6 +366,15 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.finetune_keys = None + if should_hijack_ip2p(checkpoint_info): + sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion" + sd_config.model.params.conditioning_key = "hybrid" + sd_config.model.params.first_stage_key = "edited" + sd_config.model.params.cond_stage_key = "edit" + sd_config.model.params.image_size = 16 + sd_config.model.params.unet_config.params.in_channels = 8 + sd_config.model.params.unet_config.params.out_channels = 4 + if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False @@ -429,7 +439,7 @@ def reload_model_weights(sd_model=None, info=None): checkpoint_config = find_checkpoint_config(current_checkpoint_info) - if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) -- cgit v1.2.1 From 57c1baa774d07060af0abbd2974c5f36c8cb63ac Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 18:56:23 +0300 Subject: change to code for live preview fix on OSX to be bit more obvious --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 3bd590ba..57c3db1b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -568,8 +568,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) - if shared.opts.live_previews_enable and sd_samplers.approximation_indexes.get(shared.opts.show_progress_type, 0) == 1: - # preload approx nn model before sampling for a more deterministic result + # for OSX, loading the model during sampling changes the generated picture, so it is loaded here + if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN": sd_vae_approx.model() if not p.disable_extra_networks: -- cgit v1.2.1 From e179b6098ac1b1ce9645fef5bd9fd0bc9b918f30 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Wed, 25 Jan 2023 08:48:40 -0800 Subject: allow symlinks in the textual inversion embeddings folder --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4e90f690..6cf00e65 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -194,7 +194,7 @@ class EmbeddingDatabase: if not os.path.isdir(embdir.path): return - for root, dirs, fns in os.walk(embdir.path): + for root, dirs, fns in os.walk(embdir.path, followlinks=True): for fn in fns: try: fullfn = os.path.join(root, fn) -- cgit v1.2.1 From 789d47f832a5c921dbbdd0a657dff9bca7f78d94 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 19:55:31 +0300 Subject: make clicking extra networks button one more time close the extra networks UI --- modules/ui_extra_networks.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 8b4f97f8..c6ff889a 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -117,8 +117,13 @@ def create_ui(container, button, tabname): ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) - button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) + def toggle_visibility(is_visible): + is_visible = not is_visible + return is_visible, gr.update(visible=is_visible) + + state_visible = gr.State(value=False) + button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) + button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container]) def refresh(): res = [] -- cgit v1.2.1 From 15e89ef0f6f22f823c19592a401b9e4ee477258c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 20:11:01 +0300 Subject: fix for unet hijack breaking the train tab --- modules/sd_hijack_unet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 88c94e54..a6ee577c 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -36,8 +36,11 @@ th = TorchHijackForUnet() # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): - for y in cond.keys(): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + + if isinstance(cond, dict): + for y in cond.keys(): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + with devices.autocast(): return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() -- cgit v1.2.1 From d1d6ce29831d1b067801c3206f314258de88f683 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 23:25:25 +0300 Subject: add edit_image_conditioning from my earlier edits in case there's an attempt to inegrate pix2pix properly this allows to use pix2pix model in img2img though it won't work well this way --- modules/processing.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 9e5a2f38..cb41288a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -185,7 +185,12 @@ class StableDiffusionProcessing: conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1. return conditioning - def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None): + def edit_image_conditioning(self, source_image): + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + + return conditioning_image + + def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None): self.is_using_inpainting_conditioning = True # Handle the different mask inputs @@ -228,6 +233,9 @@ class StableDiffusionProcessing: if isinstance(self.sd_model, LatentDepth2ImageDiffusion): return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) + if self.sd_model.cond_stage_key == "edit": + return self.edit_image_conditioning(source_image) + if self.sampler.conditioning_key in {'hybrid', 'concat'}: return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) -- cgit v1.2.1 From 6cff4401824299a983c8e13424018efc347b4a2b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 23:25:40 +0300 Subject: fix prompt editing break after first batch in img2img --- modules/sd_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 6261d1f7..a7910b56 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -454,7 +454,7 @@ class KDiffusionSampler: def initialize(self, p): self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap.step = 0 + self.model_wrap_cfg.step = 0 self.eta = p.eta or opts.eta_ancestral k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) -- cgit v1.2.1 From e57b5f7c5560c49fbaf05e6bea326478222cb3e6 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 25 Jan 2023 22:36:14 -0500 Subject: re_param captures quotes with commas properly and removes unnecessary regex --- modules/generation_parameters_copypaste.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 46e12dc6..13d0874d 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -11,9 +11,8 @@ from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image -re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' +re_param_code = r'\s*([\w ]+):\s*(\"[^\"]*\"|[^,]+)' re_param = re.compile(re_param_code) -re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") type_of_gr_update = type(gr.update()) @@ -243,7 +242,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model done_with_prompt = False *lines, lastline = x.strip().split("\n") - if not re_params.match(lastline): + if not re_param.match(lastline): lines.append(lastline) lastline = '' @@ -262,6 +261,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model res["Negative prompt"] = negative_prompt for k, v in re_param.findall(lastline): + v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v m = re_imagesize.match(v) if m is not None: res[k+"-1"] = m.group(1) -- cgit v1.2.1 From 4d634dc592ffdbd4ebb2f1acfb9a63f5e26e4deb Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 26 Jan 2023 00:18:41 -0500 Subject: adds components to infotext_fields allows for loading script params --- modules/scripts.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 03907a63..eefdfdd4 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -330,6 +330,20 @@ class ScriptRunner: outputs=[script.group for script in self.selectable_scripts] ) + self.script_load_ctr = 0 + def onload_script_visibility(params): + title = params.get('Script', None) + if title: + title_index = self.titles.index(title) + visibility = title_index == self.script_load_ctr + self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles) + return gr.update(visible=visibility) + else: + return gr.update(visible=False) + + self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) ) + self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] ) + return inputs def run(self, p: StableDiffusionProcessing, *args): -- cgit v1.2.1 From 10421f93c3f7f7ce88cb40391b46d4e6664eff74 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 26 Jan 2023 00:34:38 -0500 Subject: Fix full previews, --no-half-vae --- modules/processing.py | 8 ++++---- modules/sd_hijack_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index cb41288a..92894d67 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,7 +172,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image)) conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), @@ -217,7 +217,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -417,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x): with devices.autocast(disable=x.dtype == devices.dtype_vae): - x = model.decode_first_stage(x) + x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x) return x @@ -1001,7 +1001,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. - image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None) + image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index f81b169a..f8684475 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -5,7 +5,7 @@ class CondFunc: self = super(CondFunc, cls).__new__(cls) if isinstance(orig_func, str): func_path = orig_func.split('.') - for i in range(len(func_path)-2, -1, -1): + for i in range(len(func_path)-1, -1, -1): try: resolved_obj = importlib.import_module('.'.join(func_path[:i])) break -- cgit v1.2.1 From f4ec411f2c9d6bc6817a2eca8a2c00f255ffb386 Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 03:45:16 -0500 Subject: Allow checkpoint merger to merge pix2pix models in the same way that it currently supports inpainting models. --- modules/extras.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 36123aa5..67ffdee3 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -132,6 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False + result_is_pix2pix_model = False if theta_func2: shared.state.textinfo = f"Loading B" @@ -186,13 +187,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if a.shape[1] == 4 and b.shape[1] == 9: raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True + if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model... + print("Detected possible merge of instruct model with non-instruct model.") + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch. + result_is_pix2pix_model = True + else: + assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) + result_is_inpainting_model = True else: theta_0[key] = theta_func2(a, b, multiplier) - + theta_0[key] = to_half(theta_0[key], save_as_half) shared.state.sampling_step += 1 @@ -226,6 +231,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ filename = filename_generator() if custom_name == '' else custom_name filename += ".inpainting" if result_is_inpainting_model else "" + filename += ".pix2pix" if result_is_pix2pix_model else "" filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) -- cgit v1.2.1 From f90798c6b6cc48e514acb08ce02bdb5874bf74d8 Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 04:38:04 -0500 Subject: Added error check for the rare case a user merges a pix2pix model with a normal model using weighted sum. Also removed bad print message that interfered with merging progress bar. --- modules/extras.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 67ffdee3..badd13c7 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -186,9 +186,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: if a.shape[1] == 4 and b.shape[1] == 9: raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") + if a.shape[1] == 4 and b.shape[1] == 8: + raise RuntimeError("When merging pix2pix model with a normal one, A must be the pix2pix model.") if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model... - print("Detected possible merge of instruct model with non-instruct model.") theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch. result_is_pix2pix_model = True else: -- cgit v1.2.1 From 9e72dc743480c8b1ca6aeb8ced3af03f3e3243a3 Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 06:05:40 -0500 Subject: Changed all references to "pix2pix" to the more precise name "instruct pix2pix". Also changed extension to instrpix2pix at least for now. --- modules/extras.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index badd13c7..2bf0d17e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -132,7 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False - result_is_pix2pix_model = False + result_is_instruct_pix2pix_model = False if theta_func2: shared.state.textinfo = f"Loading B" @@ -187,11 +187,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if a.shape[1] == 4 and b.shape[1] == 9: raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") if a.shape[1] == 4 and b.shape[1] == 8: - raise RuntimeError("When merging pix2pix model with a normal one, A must be the pix2pix model.") + raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.") - if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model... + if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model... theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch. - result_is_pix2pix_model = True + result_is_instruct_pix2pix_model = True else: assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) @@ -232,7 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ filename = filename_generator() if custom_name == '' else custom_name filename += ".inpainting" if result_is_inpainting_model else "" - filename += ".pix2pix" if result_is_pix2pix_model else "" + filename += ".instrpix2pix" if result_is_instruct_pix2pix_model else "" filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) -- cgit v1.2.1 From c4b9b07db6272768428fa8efeb7d7a9f22eca0b1 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 26 Jan 2023 09:00:15 -0500 Subject: Fix embeddings dtype mismatch --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f9652d21..531790f3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -171,7 +171,7 @@ class EmbeddingsWithFixes(torch.nn.Module): vecs = [] for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: - emb = embedding.vec + emb = embedding.vec.to(devices.dtype_unet) if devices.unet_needs_upcast else embedding.vec emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) -- cgit v1.2.1 From cdc2fa209a3efdc71a90643a5e7a1df49869cd5f Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 11:27:07 -0500 Subject: Changed filename addition from "instrpix2pix" to the more readable ".instruct-pix2pix" for newly generated instruct pix2pix models. --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 2bf0d17e..466ecc15 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -232,7 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ filename = filename_generator() if custom_name == '' else custom_name filename += ".inpainting" if result_is_inpainting_model else "" - filename += ".instrpix2pix" if result_is_instruct_pix2pix_model else "" + filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else "" filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) -- cgit v1.2.1 From 7a14c8ab45da8a681792a6331d48a88dd684a0a9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 26 Jan 2023 23:29:27 +0300 Subject: add an option to enable sections from extras tab in txt2img/img2img fix some style inconsistenices --- modules/processing.py | 7 +++++- modules/scripts.py | 32 ++++++++++++++++++++++---- modules/scripts_auto_postprocessing.py | 42 ++++++++++++++++++++++++++++++++++ modules/scripts_postprocessing.py | 11 ++++++--- modules/shared.py | 15 ++++-------- modules/shared_items.py | 10 ++++++++ modules/ui_components.py | 8 +++++++ 7 files changed, 107 insertions(+), 18 deletions(-) create mode 100644 modules/scripts_auto_postprocessing.py create mode 100644 modules/shared_items.py (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 92894d67..262806a1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -658,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image = Image.fromarray(x_sample) + if p.scripts is not None: + pp = scripts.PostprocessImageArgs(image) + p.scripts.postprocess_image(p, pp) + image = pp.image + if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) diff --git a/modules/scripts.py b/modules/scripts.py index 03907a63..6e9dc0c0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -6,12 +6,16 @@ from collections import namedtuple import gradio as gr -from modules.processing import StableDiffusionProcessing from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing AlwaysVisible = object() +class PostprocessImageArgs: + def __init__(self, image): + self.image = image + + class Script: filename = None args_from = None @@ -65,7 +69,7 @@ class Script: args contains all values returned by components from ui() """ - raise NotImplementedError() + pass def process(self, p, *args): """ @@ -100,6 +104,13 @@ class Script: pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): + """ + Called for every image after it has been generated. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -247,11 +258,15 @@ class ScriptRunner: self.infotext_fields = [] def initialize_scripts(self, is_img2img): + from modules import scripts_auto_postprocessing + self.scripts.clear() self.alwayson_scripts.clear() self.selectable_scripts.clear() - for script_class, path, basedir, script_module in scripts_data: + auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() + + for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data: script = script_class() script.filename = path script.is_txt2img = not is_img2img @@ -332,7 +347,7 @@ class ScriptRunner: return inputs - def run(self, p: StableDiffusionProcessing, *args): + def run(self, p, *args): script_index = args[0] if script_index == 0: @@ -386,6 +401,15 @@ class ScriptRunner: print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def postprocess_image(self, p, pp: PostprocessImageArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_image(p, pp, *script_args) + except Exception: + print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def before_component(self, component, **kwargs): for script in self.scripts: try: diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py new file mode 100644 index 00000000..30d6d658 --- /dev/null +++ b/modules/scripts_auto_postprocessing.py @@ -0,0 +1,42 @@ +from modules import scripts, scripts_postprocessing, shared + + +class ScriptPostprocessingForMainUI(scripts.Script): + def __init__(self, script_postproc): + self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc + self.postprocessing_controls = None + + def title(self): + return self.script.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + self.postprocessing_controls = self.script.ui() + return self.postprocessing_controls.values() + + def postprocess_image(self, p, script_pp, *args): + args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)} + + pp = scripts_postprocessing.PostprocessedImage(script_pp.image) + pp.info = {} + self.script.process(pp, **args_dict) + p.extra_generation_params.update(pp.info) + script_pp.image = pp.image + + +def create_auto_preprocessing_script_data(): + from modules import scripts + + res = [] + + for name in shared.opts.postprocessing_enable_in_main_ui: + script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None) + if script is None: + continue + + constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) + res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module)) + + return res diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index 25de02d0..ce0ebb61 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -46,6 +46,8 @@ class ScriptPostprocessing: pass + + def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: res = func(*args, **kwargs) @@ -68,6 +70,9 @@ class ScriptPostprocessingRunner: script: ScriptPostprocessing = script_class() script.filename = path + if script.name == "Simple Upscale": + continue + self.scripts.append(script) def create_script_ui(self, script, inputs): @@ -87,12 +92,11 @@ class ScriptPostprocessingRunner: import modules.scripts self.initialize_scripts(modules.scripts.postprocessing_scripts_data) - scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")] + scripts_order = shared.opts.postprocessing_operation_order def script_score(name): - name = name.lower() for i, possible_match in enumerate(scripts_order): - if possible_match in name: + if possible_match == name: return i return len(self.scripts) @@ -145,3 +149,4 @@ class ScriptPostprocessingRunner: def image_changed(self): for script in self.scripts_in_preferred_order(): script.image_changed() + diff --git a/modules/shared.py b/modules/shared.py index 6a0b96cb..cdeed55d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,8 +13,8 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors, ui_components -from modules.paths import models_path, script_path, sd_path +from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items +from modules.paths import models_path, script_path demo = None @@ -264,12 +264,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] - -def realesrgan_models_names(): - import modules.realesrgan_model - return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] - - class OptionInfo: def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): self.default = default @@ -360,7 +354,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), - "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), })) @@ -483,7 +477,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" })) options_templates.update(options_section(('postprocessing', "Postprocessing"), { - 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"), + 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), + 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), })) diff --git a/modules/shared_items.py b/modules/shared_items.py new file mode 100644 index 00000000..b5d480c9 --- /dev/null +++ b/modules/shared_items.py @@ -0,0 +1,10 @@ + + +def realesrgan_models_names(): + import modules.realesrgan_model + return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] + +def postprocessing_scripts(): + import modules.scripts + + return modules.scripts.scripts_postproc.scripts \ No newline at end of file diff --git a/modules/ui_components.py b/modules/ui_components.py index 9aec3097..284ca0cf 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" + +class DropdownMulti(gr.Dropdown): + """Same as gr.Dropdown but always multiselect""" + def __init__(self, **kwargs): + super().__init__(multiselect=True, **kwargs) + + def get_block_name(self): + return "dropdown" -- cgit v1.2.1 From d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 11:28:12 +0300 Subject: remove the need to place configs near models --- modules/api/api.py | 5 +- modules/devices.py | 12 ++- modules/sd_hijack_inpainting.py | 9 -- modules/sd_models.py | 228 ++++++++++++++++++++-------------------- modules/sd_models_config.py | 65 ++++++++++++ modules/shared.py | 7 +- modules/shared_items.py | 15 ++- modules/timer.py | 35 ++++++ 8 files changed, 242 insertions(+), 134 deletions(-) create mode 100644 modules/sd_models_config.py create mode 100644 modules/timer.py (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 25c65e57..eb7b1da5 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, find_checkpoint_config +from modules.sd_models import checkpoints_list +from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List @@ -387,7 +388,7 @@ class Api: ] def get_sd_models(self): - return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..2d5f797a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,14 +34,18 @@ def get_cuda_device_string(): return "cuda" -def get_optimal_device(): +def get_optimal_device_name(): if torch.cuda.is_available(): - return torch.device(get_cuda_device_string()) + return get_cuda_device_string() if has_mps(): - return torch.device("mps") + return "mps" + + return "cpu" - return cpu + +def get_optimal_device(): + return torch.device(get_optimal_device_name()) def get_device_for(task): diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 31d2c898..478cd499 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0, e_t -def should_hijack_inpainting(checkpoint_info): - from modules import sd_models - - ckpt_basename = os.path.basename(checkpoint_info.filename).lower() - cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() - - return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename - - def do_inpainting_hijack(): # p_sample_plms is needed because PLMS can't work with dicts as conditionings diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072eb2e..fa208728 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,8 +2,6 @@ import collections import os.path import sys import gc -import time -from collections import namedtuple import torch import re import safetensors.torch @@ -14,10 +12,10 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules.paths import models_path -from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting -from modules.sd_hijack_ip2p import should_hijack_ip2p +from modules.sd_hijack_inpainting import do_inpainting_hijack +from modules.timer import Timer model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -99,17 +97,6 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) -def find_checkpoint_config(info): - if info is None: - return shared.cmd_opts.config - - config = os.path.splitext(info.filename)[0] + ".yaml" - if os.path.exists(config): - return config - - return shared.cmd_opts.config - - def list_models(): checkpoints_list.clear() checkpoint_alisases.clear() @@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - device = map_location or shared.weight_load_location - if device is None: - device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu" + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) @@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info: CheckpointInfo): +def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): + sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + + if checkpoint_info in checkpoints_loaded: + # use checkpoint cache + print(f"Loading weights [{sd_model_hash}] from cache") + return checkpoints_loaded[checkpoint_info] + + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + res = read_state_dict(checkpoint_info.filename) + timer.record("load weights from disk") + + return res + + +def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + if checkpoint_info.title != title: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title - cache_enabled = shared.opts.sd_checkpoint_cache > 0 + if state_dict is None: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if cache_enabled and checkpoint_info in checkpoints_loaded: - # use checkpoint cache - print(f"Loading weights [{sd_model_hash}] from cache") - model.load_state_dict(checkpoints_loaded[checkpoint_info]) - else: - # load from file - print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + model.load_state_dict(state_dict, strict=False) + del state_dict + timer.record("apply weights to model") - sd = read_state_dict(checkpoint_info.filename) - model.load_state_dict(sd, strict=False) - del sd - - if cache_enabled: - # cache newly loaded model - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + if shared.opts.sd_checkpoint_cache > 0: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + + if shared.cmd_opts.opt_channelslast: + model.to(memory_format=torch.channels_last) + timer.record("apply channels_last") - if shared.cmd_opts.opt_channelslast: - model.to(memory_format=torch.channels_last) + if not shared.cmd_opts.no_half: + vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) - if not shared.cmd_opts.no_half: - vae = model.first_stage_model - depth_model = getattr(model, 'depth_model', None) + # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 + if shared.cmd_opts.no_half_vae: + model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None - # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 - if shared.cmd_opts.no_half_vae: - model.first_stage_model = None - # with --upcast-sampling, don't convert the depth model weights to float16 - if shared.cmd_opts.upcast_sampling and depth_model: - model.depth_model = None + model.half() + model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model - model.half() - model.first_stage_model = vae - if depth_model: - model.depth_model = depth_model + timer.record("apply half()") - devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - devices.dtype_unet = model.model.diffusion_model.dtype - devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 - model.first_stage_model.to(devices.dtype_vae) + model.first_stage_model.to(devices.dtype_vae) + timer.record("apply dtype to VAE") # clean up cache if limit is reached - if cache_enabled: - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model - checkpoints_loaded.popitem(last=False) # LRU + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_info.filename @@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_vae.clear_loaded_vae() vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) sd_vae.load_vae(model, vae_file, vae_source) + timer.record("load VAE") def enable_midas_autodownload(): @@ -340,24 +340,20 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper -class Timer: - def __init__(self): - self.start = time.time() +def repair_config(sd_config): - def elapsed(self): - end = time.time() - res = end - self.start - self.start = end - return res + if not hasattr(sd_config.model.params, "use_ema"): + sd_config.model.params.use_ema = False + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True -def load_model(checkpoint_info=None): + +def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() - checkpoint_config = find_checkpoint_config(checkpoint_info) - - if checkpoint_config != shared.cmd_opts.config: - print(f"Loading config from: {checkpoint_config}") if shared.sd_model: sd_hijack.model_hijack.undo_hijack(shared.sd_model) @@ -365,38 +361,27 @@ def load_model(checkpoint_info=None): gc.collect() devices.torch_gc() - sd_config = OmegaConf.load(checkpoint_config) - - if should_hijack_inpainting(checkpoint_info): - # Hardcoded config for now... - sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" - sd_config.model.params.conditioning_key = "hybrid" - sd_config.model.params.unet_config.params.in_channels = 9 - sd_config.model.params.finetune_keys = None - - if should_hijack_ip2p(checkpoint_info): - sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion" - sd_config.model.params.conditioning_key = "hybrid" - sd_config.model.params.first_stage_key = "edited" - sd_config.model.params.cond_stage_key = "edit" - sd_config.model.params.image_size = 16 - sd_config.model.params.unet_config.params.in_channels = 8 - sd_config.model.params.unet_config.params.out_channels = 4 + do_inpainting_hijack() - if not hasattr(sd_config.model.params, "use_ema"): - sd_config.model.params.use_ema = False + timer = Timer() - do_inpainting_hijack() + if already_loaded_state_dict is not None: + state_dict = already_loaded_state_dict + else: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - timer = Timer() + timer.record("find config") - sd_model = None + sd_config = OmegaConf.load(checkpoint_config) + repair_config(sd_config) + + timer.record("load config") + + print(f"Creating model from config: {checkpoint_config}") + sd_model = None try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) @@ -407,29 +392,35 @@ def load_model(checkpoint_info=None): print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) - elapsed_create = timer.elapsed() + sd_model.used_config = checkpoint_config - load_model_weights(sd_model, checkpoint_info) + timer.record("create model") - elapsed_load_weights = timer.elapsed() + load_model_weights(sd_model, checkpoint_info, state_dict, timer) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) else: sd_model.to(shared.device) + timer.record("move model to device") + sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + sd_model.eval() shared.sd_model = sd_model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + timer.record("load textual inversion embeddings") + script_callbacks.model_loaded_callback(sd_model) - elapsed_the_rest = timer.elapsed() + timer.record("scripts callbacks") - print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).") + print(f"Model loaded in {timer.summary()}.") return sd_model @@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model + if sd_model is None: # previous model load failed current_checkpoint_info = None else: @@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - checkpoint_config = find_checkpoint_config(current_checkpoint_info) - - if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info): - del sd_model - checkpoints_loaded.clear() - load_model(checkpoint_info) - return shared.sd_model - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: @@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None): timer = Timer() + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + + timer.record("find config") + + if sd_model is None or checkpoint_config != sd_model.used_config: + del sd_model + checkpoints_loaded.clear() + load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"]) + return shared.sd_model + try: - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, state_dict, timer) except Exception as e: print("Failed to load checkpoint, restoring previous") - load_model_weights(sd_model, current_checkpoint_info) + load_model_weights(sd_model, current_checkpoint_info, None, timer) raise finally: sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) + timer.record("move model to device") - elapsed = timer.elapsed() - - print(f"Weights loaded in {elapsed:.1f}s.") + print(f"Weights loaded in {timer.summary()}.") return sd_model diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py new file mode 100644 index 00000000..ea773a10 --- /dev/null +++ b/modules/sd_models_config.py @@ -0,0 +1,65 @@ +import re +import os + +from modules import shared, paths + +sd_configs_path = shared.sd_configs_path +sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") + + +config_default = shared.sd_default_config +config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") +config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") +config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") +config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") + +re_parametrization_v = re.compile(r'-v\b') + + +def guess_model_config_from_state_dict(sd, filename): + fn = os.path.basename(filename) + + sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) + diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) + roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) + + if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: + if re.search(re_parametrization_v, fn) or "v2-1_768" in fn: + return config_sd2v + else: + return config_sd2 + + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return config_inpainting + if diffusion_model_input.shape[1] == 8: + return config_instruct_pix2pix + + if roberta_weight is not None: + return config_alt_diffusion + + return config_default + + +def find_checkpoint_config(state_dict, info): + if info is None: + return guess_model_config_from_state_dict(state_dict, "") + + config = find_checkpoint_config_near_filename(info) + if config is not None: + return config + + return guess_model_config_from_state_dict(state_dict, info.filename) + + +def find_checkpoint_config_near_filename(info): + if info is None: + return None + + config = os.path.splitext(info.filename)[0] + ".yaml" + if os.path.exists(config): + return config + + return None + diff --git a/modules/shared.py b/modules/shared.py index cdeed55d..14be993d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,13 +13,14 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items +from modules import localization, extensions, script_loading, errors, ui_components, shared_items from modules.paths import models_path, script_path demo = None -sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") +sd_configs_path = os.path.join(script_path, "configs") +sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file @@ -391,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), + "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), diff --git a/modules/shared_items.py b/modules/shared_items.py index b5d480c9..8b5ec96d 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -4,7 +4,20 @@ def realesrgan_models_names(): import modules.realesrgan_model return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] + def postprocessing_scripts(): import modules.scripts - return modules.scripts.scripts_postproc.scripts \ No newline at end of file + return modules.scripts.scripts_postproc.scripts + + +def sd_vae_items(): + import modules.sd_vae + + return ["Automatic", "None"] + list(modules.sd_vae.vae_dict) + + +def refresh_vae_list(): + import modules.sd_vae + + return modules.sd_vae.refresh_vae_list diff --git a/modules/timer.py b/modules/timer.py new file mode 100644 index 00000000..57a4f17a --- /dev/null +++ b/modules/timer.py @@ -0,0 +1,35 @@ +import time + + +class Timer: + def __init__(self): + self.start = time.time() + self.records = {} + self.total = 0 + + def elapsed(self): + end = time.time() + res = end - self.start + self.start = end + return res + + def record(self, category, extra_time=0): + e = self.elapsed() + if category not in self.records: + self.records[category] = 0 + + self.records[category] += e + extra_time + self.total += e + extra_time + + def summary(self): + res = f"{self.total:.1f}s" + + additions = [x for x in self.records.items() if x[1] >= 0.1] + if not additions: + return res + + res += " (" + res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions]) + res += ")" + + return res -- cgit v1.2.1 From 6f31d2210c189f8db118e6f95add7ba2a64f0238 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 11:54:19 +0300 Subject: support detecting midas model fix broken api for checkpoint list --- modules/api/models.py | 2 +- modules/sd_models.py | 10 +++++----- modules/sd_models_config.py | 7 +++++-- 3 files changed, 11 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index 805bd8f7..cba43d3b 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -228,7 +228,7 @@ class SDModelItem(BaseModel): hash: Optional[str] = Field(title="Short hash") sha256: Optional[str] = Field(title="sha256 hash") filename: str = Field(title="Filename") - config: str = Field(title="Config file") + config: Optional[str] = Field(title="Config file") class HypernetworkItem(BaseModel): name: str = Field(title="Name") diff --git a/modules/sd_models.py b/modules/sd_models.py index fa208728..37dad18d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -439,12 +439,12 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - else: - sd_model.to(devices.cpu) + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + sd_model.to(devices.cpu) - sd_hijack.model_hijack.undo_hijack(sd_model) + sd_hijack.model_hijack.undo_hijack(sd_model) timer = Timer() diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index ea773a10..4d1e92e1 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -10,6 +10,7 @@ sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") @@ -22,7 +23,9 @@ def guess_model_config_from_state_dict(sd, filename): sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) - roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) + + if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: + return config_depth_model if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if re.search(re_parametrization_v, fn) or "v2-1_768" in fn: @@ -36,7 +39,7 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 8: return config_instruct_pix2pix - if roberta_weight is not None: + if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: return config_alt_diffusion return config_default -- cgit v1.2.1 From 9beb794e0b0dc1a0f9e89d8e38bd789a8c608397 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 13:08:00 +0300 Subject: clarify the option to disable NaN check. --- modules/devices.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 2d5f797a..4687944e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -143,6 +143,8 @@ def test_for_nans(x, where): else: message = "A tensor with all NaNs was produced." + message += " Use --disable-nan-check commandline argument to disable this check." + raise NansException(message) -- cgit v1.2.1 From 5eee2ac39863f9e44591b50d0710dd2615416a13 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 17:15:42 +0100 Subject: add data-dir flag and set all user data directories based on it --- modules/extensions.py | 2 +- modules/generation_parameters_copypaste.py | 4 ++-- modules/gfpgan_model.py | 5 ++--- modules/hashes.py | 4 +++- modules/interrogate.py | 2 +- modules/paths.py | 10 +++++++++- modules/processing.py | 3 ++- modules/sd_models.py | 6 +++--- modules/sd_vae.py | 5 ++--- modules/shared.py | 11 ++++++----- modules/textual_inversion/preprocess.py | 5 ++--- modules/ui.py | 6 +++--- modules/ui_extensions.py | 2 +- modules/upscaler.py | 5 ++--- 14 files changed, 39 insertions(+), 31 deletions(-) (limited to 'modules') diff --git a/modules/extensions.py b/modules/extensions.py index b522125c..92ee8144 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -7,7 +7,7 @@ import git from modules import paths, shared extensions = [] -extensions_dir = os.path.join(paths.script_path, "extensions") +extensions_dir = os.path.join(paths.data_path, "extensions") extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 46e12dc6..35f72808 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re from pathlib import Path import gradio as gr -from modules.shared import script_path +from modules.paths import data_path, script_path from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image @@ -289,7 +289,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model def connect_paste(button, paste_fields, input_comp, jsfunc=None): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: - filename = os.path.join(script_path, "params.txt") + filename = os.path.join(data_path, "params.txt") if os.path.exists(filename): with open(filename, "r", encoding="utf8") as file: prompt = file.read() diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 1e2dbc32..fbe6215a 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -6,12 +6,11 @@ import facexlib import gfpgan import modules.face_restoration -from modules import shared, devices, modelloader -from modules.paths import models_path +from modules import paths, shared, devices, modelloader model_dir = "GFPGAN" user_path = None -model_path = os.path.join(models_path, model_dir) +model_path = os.path.join(paths.models_path, model_dir) model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" have_gfpgan = False loaded_gfpgan_model = None diff --git a/modules/hashes.py b/modules/hashes.py index b85a7580..819362a3 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -4,8 +4,10 @@ import os.path import filelock +from modules.paths import data_path -cache_filename = "cache.json" + +cache_filename = os.path.join(data_path, "cache.json") cache_data = None diff --git a/modules/interrogate.py b/modules/interrogate.py index c72ff694..cbb80683 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -12,7 +12,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode import modules.shared as shared -from modules import devices, paths, lowvram, modelloader, errors +from modules import devices, paths, shared, lowvram, modelloader, errors blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' diff --git a/modules/paths.py b/modules/paths.py index 20b3e4d8..08e6f9b9 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -4,7 +4,15 @@ import sys import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -models_path = os.path.join(script_path, "models") + +# Parse the --data-dir flag first so we can use it as a base for our other argument default values +parser = argparse.ArgumentParser() +parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) +cmd_opts_pre = parser.parse_known_args()[0] +data_path = cmd_opts_pre.data_dir +models_path = os.path.join(data_path, "models") + +# data_path = cmd_opts_pre.data sys.path.insert(0, script_path) # search for directory of stable diffusion in following places diff --git a/modules/processing.py b/modules/processing.py index 262806a1..5072fc40 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -17,6 +17,7 @@ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, gener from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared +import modules.paths as paths import modules.face_restoration import modules.images as images import modules.styles @@ -584,7 +585,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if not p.disable_extra_networks: extra_networks.activate(p, extra_network_data) - with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") file.write(processed.infotext(p, 0)) diff --git a/modules/sd_models.py b/modules/sd_models.py index 37dad18d..b2d48a51 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -12,13 +12,13 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) +model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) checkpoints_list = {} checkpoint_alisases = {} @@ -307,7 +307,7 @@ def enable_midas_autodownload(): location automatically. """ - midas_path = os.path.join(models_path, 'midas') + midas_path = os.path.join(paths.models_path, 'midas') # stable-diffusion-stability-ai hard-codes the midas model path to # a location that differs from where other scripts using this model look. diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 4ce238b8..9b00f76e 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -3,13 +3,12 @@ import safetensors.torch import os import collections from collections import namedtuple -from modules import shared, devices, script_callbacks, sd_models -from modules.paths import models_path +from modules import paths, shared, devices, script_callbacks, sd_models import glob from copy import deepcopy -vae_path = os.path.abspath(os.path.join(models_path, "VAE")) +vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} vae_dict = {} diff --git a/modules/shared.py b/modules/shared.py index 14be993d..474fcc42 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.memmon import modules.styles import modules.devices as devices from modules import localization, extensions, script_loading, errors, ui_components, shared_items -from modules.paths import models_path, script_path +from modules.paths import models_path, script_path, data_path demo = None @@ -25,6 +25,7 @@ sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() +parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") @@ -35,7 +36,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") -parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") @@ -74,16 +75,16 @@ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for sp parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) -parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json')) +parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json')) parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) -parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) +parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") -parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) +parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index c0ac11d3..2239cb84 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -6,8 +6,7 @@ import sys import tqdm import time -from modules import shared, images, deepbooru -from modules.paths import models_path +from modules import paths, shared, images, deepbooru from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop @@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre dnn_model_path = None try: - dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv")) + dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv")) except Exception as e: print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..0117df3e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path +from modules.paths import script_path, data_path from modules.shared import opts, cmd_opts, restricted_opts @@ -1497,8 +1497,8 @@ def create_ui(): with open(cssfile, "r", encoding="utf8") as file: css += file.read() + "\n" - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + if os.path.exists(os.path.join(data_path, "user.css")): + with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file: css += file.read() + "\n" if not cmd_opts.no_progressbar_hiding: diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 742e745e..66a41865 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -132,7 +132,7 @@ def install_extension_from_url(dirname, url): normalized_url = normalize_git_url(url) assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' - tmpdir = os.path.join(paths.script_path, "tmp", dirname) + tmpdir = os.path.join(paths.data_path, "tmp", dirname) try: shutil.rmtree(tmpdir, True) diff --git a/modules/upscaler.py b/modules/upscaler.py index a5bf5acb..e2eaa730 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -11,7 +11,6 @@ from modules import modelloader, shared LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) -from modules.paths import models_path class Upscaler: @@ -39,7 +38,7 @@ class Upscaler: self.mod_scale = None if self.model_path is None and self.name: - self.model_path = os.path.join(models_path, self.name) + self.model_path = os.path.join(shared.models_path, self.name) if self.model_path and create_dirs: os.makedirs(self.model_path, exist_ok=True) @@ -143,4 +142,4 @@ class UpscalerNearest(Upscaler): def __init__(self, dirname=None): super().__init__(False) self.name = "Nearest" - self.scalers = [UpscalerData("Nearest", None, self)] \ No newline at end of file + self.scalers = [UpscalerData("Nearest", None, self)] -- cgit v1.2.1 From 14c0884fd0948c478db165989cca7aaffc9a0504 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 17:55:59 +0100 Subject: use python importlib to load and execute extension modules previously module attributes like __file__ where not set correctly, leading to scripts getting the directory of the stable-diffusion repo location instead of their own script. This causes problem when loading user data from an external location using the --data-dir flag, as extensions would look for their own code in the stable-diffusion repo location instead of the data dir location. Using pythons importlib functions sets the modules specs correctly and executes them. But this will break extensions if they build paths based on the previously incorrect __file__ attribute. --- modules/script_loading.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/script_loading.py b/modules/script_loading.py index f93f0951..a7d2203f 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,16 +1,14 @@ import os import sys import traceback +import importlib.util from types import ModuleType def load_module(path): - with open(path, "r", encoding="utf8") as file: - text = file.read() - - compiled = compile(text, path, 'exec') - module = ModuleType(os.path.basename(path)) - exec(compiled, module.__dict__) + module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) return module -- cgit v1.2.1 From 6b3981c0685cd1df750df4eb51823f1cfd70c6d5 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 18:00:09 +0100 Subject: clean up unused script_path imports --- modules/codeformer_model.py | 2 +- modules/generation_parameters_copypaste.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ab40d842..01fb7bd8 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -8,7 +8,7 @@ import torch import modules.face_restoration import modules.shared from modules import shared, devices, modelloader -from modules.paths import script_path, models_path +from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 35f72808..773c5c0e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re from pathlib import Path import gradio as gr -from modules.paths import data_path, script_path +from modules.paths import data_path from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image -- cgit v1.2.1 From 23a9d5e27390846dea0895a02c04aec9583a4d38 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 18:18:55 +0100 Subject: create user extensions directory if not exists --- modules/extensions.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/extensions.py b/modules/extensions.py index 92ee8144..5e12b1aa 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -10,6 +10,8 @@ extensions = [] extensions_dir = os.path.join(paths.data_path, "extensions") extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") +if not os.path.exists(extensions_dir): + os.makedirs(extensions_dir) def active(): return [x for x in extensions if x.enabled] -- cgit v1.2.1 From eafaf14167cf574ad0f918c10f60ef86aea9cd20 Mon Sep 17 00:00:00 2001 From: Gazzoo-byte <73721238+Gazzoo-byte@users.noreply.github.com> Date: Fri, 27 Jan 2023 18:34:41 +0000 Subject: Add button to switch width and height Adds a button to switch width and height, allowing quick and easy switching between landscape and portrait. --- modules/ui.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..fb0e4d5c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -91,6 +91,13 @@ save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001F5D1' # 🗑️ extra_networks_symbol = '\U0001F3B4' # 🎴 +switch_values_symbol = '\U000021C5' # ⇅ + +def switch_width_and_height(width, height): + width_temp = width + width = height + height = width_temp + return width, height def plaintext_to_html(text): @@ -466,6 +473,7 @@ def create_ui(): height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") if opts.dimensions_and_batch_together: + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn") with gr.Column(elem_id="txt2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") @@ -566,6 +574,7 @@ def create_ui(): txt2img_prompt.submit(**txt2img_args) submit.click(**txt2img_args) + res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) txt_prompt_img.change( fn=modules.images.image_data, @@ -728,6 +737,7 @@ def create_ui(): height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") if opts.dimensions_and_batch_together: + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") with gr.Column(elem_id="img2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") @@ -865,6 +875,7 @@ def create_ui(): img2img_prompt.submit(**img2img_args) submit.click(**img2img_args) + res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) img2img_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), -- cgit v1.2.1 From cc8c9b7474d917888a0bd069fcd59a458c67ae4b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 22:43:08 +0300 Subject: fix broken calls to find_checkpoint_config --- modules/extras.py | 4 ++-- modules/sd_hijack_ip2p.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 36123aa5..4f842be9 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -6,7 +6,7 @@ import shutil import torch import tqdm -from modules import shared, images, sd_models, sd_vae +from modules import shared, images, sd_models, sd_vae, sd_models_config from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch @@ -37,7 +37,7 @@ def run_pnginfo(image): def create_config(ckpt_result, config_source, a, b, c): def config(x): - res = sd_models.find_checkpoint_config(x) if x else None + res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None return res if res != shared.sd_default_config else None if config_source == 0: diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py index 635f015f..3c727d3b 100644 --- a/modules/sd_hijack_ip2p.py +++ b/modules/sd_hijack_ip2p.py @@ -5,9 +5,9 @@ import gc import time def should_hijack_ip2p(checkpoint_info): - from modules import sd_models + from modules import sd_models_config ckpt_basename = os.path.basename(checkpoint_info.filename).lower() - cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() + cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower() return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename -- cgit v1.2.1 From 6b82efd737827bbeef202f04ff5a8faec9b64ef8 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Fri, 27 Jan 2023 20:06:19 -0500 Subject: add v2-inpainting model detection, and broaden v-model detection to include anything with 768 in the name --- modules/sd_models_config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 4d1e92e1..73854a45 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -10,6 +10,7 @@ sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") @@ -28,7 +29,9 @@ def guess_model_config_from_state_dict(sd, filename): return config_depth_model if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: - if re.search(re_parametrization_v, fn) or "v2-1_768" in fn: + if diffusion_model_input.shape[1] == 9: + return config_sd2_inpainting + elif re.search(re_parametrization_v, fn) or "768" in fn: return config_sd2v else: return config_sd2 -- cgit v1.2.1 From 2aac1d97782b486f3a4a5209cf399dcdcb7bbb4d Mon Sep 17 00:00:00 2001 From: Andrii Skaliuk Date: Fri, 27 Jan 2023 17:32:31 -0800 Subject: Basic inpainting batch support Modifies batch UI to add optional inpainting support --- modules/img2img.py | 20 +++++++++++++++++--- modules/ui.py | 9 ++++++++- 2 files changed, 25 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 2168c8e2..fe9447c7 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -16,11 +16,16 @@ import modules.images as images import modules.scripts -def process_batch(p, input_dir, output_dir, args): +def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processing.fix_seed(p) images = shared.listfiles(input_dir) + inpaint_masks = shared.listfiles(inpaint_mask_dir) + is_inpaint_batch = inpaint_mask_dir and len(inpaint_masks) > 0 + if is_inpaint_batch: + print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") + print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") save_normally = output_dir == '' @@ -43,6 +48,15 @@ def process_batch(p, input_dir, output_dir, args): img = ImageOps.exif_transpose(img) p.init_images = [img] * p.batch_size + if is_inpaint_batch: + # try to find corresponding mask for an image using simple filename matching + mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image)) + # if not found use first one ("same mask for all images" use-case) + if not mask_image_path in inpaint_masks: + mask_image_path = inpaint_masks[0] + mask_image = Image.open(mask_image_path) + p.image_mask = mask_image + proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: proc = process_images(p) @@ -59,7 +73,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args): is_batch = mode == 5 if mode == 0: # img2img @@ -139,7 +153,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args) + process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args) processed = Processed(p, [], p.seed, "") else: diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..fddb9177 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -691,9 +691,15 @@ def create_ui(): with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") + gr.HTML( + f"

Process images in a directory on the same machine where the server is running." + + f"
Use an empty output directory to save pictures normally instead of writing to the output directory." + + f"
Add inpaint batch mask directory to enable inpaint batch processing." + f"{hidden}

" + ) img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") def copy_image(img): if isinstance(img, dict) and 'image' in img: @@ -838,6 +844,7 @@ def create_ui(): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, + img2img_batch_inpaint_mask_dir ] + custom_inputs, outputs=[ img2img_gallery, -- cgit v1.2.1 From 4c52dfe4ac98c53431ecd267d59f27391d3a63e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 08:30:17 +0300 Subject: make the detection for -v models less broad --- modules/sd_models_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 73854a45..00217990 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -31,7 +31,7 @@ def guess_model_config_from_state_dict(sd, filename): if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if diffusion_model_input.shape[1] == 9: return config_sd2_inpainting - elif re.search(re_parametrization_v, fn) or "768" in fn: + elif re.search(re_parametrization_v, fn): return config_sd2v else: return config_sd2 -- cgit v1.2.1 From 0834d4ce374225131e025540220c727e352a3e43 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 08:41:15 +0300 Subject: simplify #7284 --- modules/ui.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 3c0a4050..ca2c1eb6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -93,12 +93,6 @@ clear_prompt_symbol = '\U0001F5D1' # 🗑️ extra_networks_symbol = '\U0001F3B4' # 🎴 switch_values_symbol = '\U000021C5' # ⇅ -def switch_width_and_height(width, height): - width_temp = width - width = height - height = width_temp - return width, height - def plaintext_to_html(text): return ui_common.plaintext_to_html(text) @@ -574,7 +568,8 @@ def create_ui(): txt2img_prompt.submit(**txt2img_args) submit.click(**txt2img_args) - res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) + + res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height]) txt_prompt_img.change( fn=modules.images.image_data, @@ -882,7 +877,7 @@ def create_ui(): img2img_prompt.submit(**img2img_args) submit.click(**img2img_args) - res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) + res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height]) img2img_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), -- cgit v1.2.1 From 3752aad23d4be4522f9edf3fe79c1122fa5ad509 Mon Sep 17 00:00:00 2001 From: Mackerel Date: Sat, 28 Jan 2023 02:44:12 -0500 Subject: don't replace regular --help with new paths.py parser help --- modules/paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/paths.py b/modules/paths.py index 08e6f9b9..d991cc71 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -6,7 +6,7 @@ import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) # Parse the --data-dir flag first so we can use it as a base for our other argument default values -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(add_help=False) parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) cmd_opts_pre = parser.parse_known_args()[0] data_path = cmd_opts_pre.data_dir -- cgit v1.2.1 From ada17dbd7c4c68a4e559848d2e6f2a7799722806 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 27 Jan 2023 10:19:43 -0500 Subject: Refactor conditional casting, fix upscalers --- modules/devices.py | 8 ++++++++ modules/processing.py | 15 ++++++++------- modules/realesrgan_model.py | 2 +- modules/sd_hijack.py | 2 +- modules/sd_hijack_unet.py | 8 +++++++- 5 files changed, 25 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..0100e4af 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -83,6 +83,14 @@ dtype_unet = torch.float16 unet_needs_upcast = False +def cond_cast_unet(input): + return input.to(dtype_unet) if unet_needs_upcast else input + + +def cond_cast_float(input): + return input.float() if unet_needs_upcast else input + + def randn(seed, shape): torch.manual_seed(seed) if device.type == 'mps': diff --git a/modules/processing.py b/modules/processing.py index 92894d67..a397702b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,8 +172,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image)) - conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -217,7 +216,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -228,16 +227,18 @@ class StableDiffusionProcessing: return image_conditioning def img2img_image_conditioning(self, source_image, latent_image, image_mask=None): + source_image = devices.cond_cast_float(source_image) + # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely # identify itself with a field common to all models. The conditioning_key is also hybrid. if isinstance(self.sd_model, LatentDepth2ImageDiffusion): - return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) + return self.depth2img_image_conditioning(source_image) if self.sd_model.cond_stage_key == "edit": return self.edit_image_conditioning(source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) @@ -417,7 +418,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x): with devices.autocast(disable=x.dtype == devices.dtype_vae): - x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x) + x = model.decode_first_stage(x) return x @@ -1001,7 +1002,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. - image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None) + image = image.to(shared.device) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 47f70251..aad4a629 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler): scale=info.scale, model_path=info.local_data_path, model=info.model(), - half=not cmd_opts.no_half, + half=not cmd_opts.no_half and not cmd_opts.upcast_sampling, tile=opts.ESRGAN_tile, tile_pad=opts.ESRGAN_tile_overlap, ) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 531790f3..8fc91882 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -171,7 +171,7 @@ class EmbeddingsWithFixes(torch.nn.Module): vecs = [] for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: - emb = embedding.vec.to(devices.dtype_unet) if devices.unet_needs_upcast else embedding.vec + emb = devices.cond_cast_unet(embedding.vec) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index a6ee577c..45cf2b18 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -55,8 +55,14 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module): unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) if version.parse(torch.__version__) <= version.parse("1.13.1"): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) + +first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 +first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) -- cgit v1.2.1 From f9edd578e9e29d160e6d56038bb368dc49895d64 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 28 Jan 2023 00:20:30 -0500 Subject: Remove MPS fix no longer needed for PyTorch The torch.narrow fix was required for nightly PyTorch builds for a while to prevent a hard crash, but newer nightly builds don't have this issue. --- modules/devices.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 0100e4af..be542f8f 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -201,6 +201,3 @@ if has_mps(): cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) - orig_narrow = torch.narrow - torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) - -- cgit v1.2.1 From 4aa7f5b5b996c1e3d97640e746f040a23a124860 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 11:11:47 +0300 Subject: update image parameters regex for #7231 --- modules/generation_parameters_copypaste.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 773c5c0e..1bf35bbb 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -11,7 +11,7 @@ from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image -re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' +re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") -- cgit v1.2.1 From d04e3e921e8ee71442a1f4a1d6e91c05b8238007 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 15:24:29 +0300 Subject: automatically detect v-parameterization for SD2 checkpoints --- modules/sd_hijack.py | 2 ++ modules/sd_models_config.py | 51 ++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f9652d21..03897b2a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -131,6 +131,8 @@ class StableDiffusionModelHijack: m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped m.cond_stage_model = m.cond_stage_model.wrapped + undo_optimizations() + self.apply_circular(False) self.layers = None self.clip = None diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 00217990..91c21700 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -1,7 +1,9 @@ import re import os -from modules import shared, paths +import torch + +from modules import shared, paths, sd_disable_initialization sd_configs_path = shared.sd_configs_path sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") @@ -16,12 +18,51 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml" config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") -re_parametrization_v = re.compile(r'-v\b') +def is_using_v_parameterization_for_sd2(state_dict): + """ + Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. + """ -def guess_model_config_from_state_dict(sd, filename): - fn = os.path.basename(filename) + import ldm.modules.diffusionmodules.openaimodel + from modules import devices + + device = devices.cpu + + with sd_disable_initialization.DisableInitialization(): + unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( + use_checkpoint=True, + use_fp16=False, + image_size=32, + in_channels=4, + out_channels=4, + model_channels=320, + attention_resolutions=[4, 2, 1], + num_res_blocks=2, + channel_mult=[1, 2, 4, 4], + num_head_channels=64, + use_spatial_transformer=True, + use_linear_in_transformer=True, + transformer_depth=1, + context_dim=1024, + legacy=False + ) + unet.eval() + + with torch.no_grad(): + unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k} + unet.load_state_dict(unet_sd, strict=True) + unet.to(device=device, dtype=torch.float) + test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 + x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 + + out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() + + return out < -1 + + +def guess_model_config_from_state_dict(sd, filename): sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) @@ -31,7 +72,7 @@ def guess_model_config_from_state_dict(sd, filename): if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if diffusion_model_input.shape[1] == 9: return config_sd2_inpainting - elif re.search(re_parametrization_v, fn): + elif is_using_v_parameterization_for_sd2(sd): return config_sd2v else: return config_sd2 -- cgit v1.2.1 From f8feeaaedb890de1e36eeb2ad387f0eb3abafd54 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 15:57:56 +0300 Subject: add progressbar to extension update check; do not check for updates for disabled extensions --- modules/ui_extensions.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 66a41865..37d30e1f 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -13,7 +13,7 @@ import shutil import errno from modules import extensions, shared, paths - +from modules.call_queue import wrap_gradio_gpu_call available_extensions = {"extensions": []} @@ -50,12 +50,17 @@ def apply_and_restart(disable_list, update_list): shared.state.need_restart = True -def check_updates(): +def check_updates(id_task, disable_list): check_access() - for ext in extensions.extensions: - if ext.remote is None: - continue + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled] + shared.state.job_count = len(exts) + + for ext in exts: + shared.state.textinfo = ext.name try: ext.check_updates() @@ -63,7 +68,9 @@ def check_updates(): print(f"Error checking updates for {ext.name}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - return extension_table() + shared.state.nextjob() + + return extension_table(), "" def extension_table(): @@ -273,12 +280,13 @@ def create_ui(): with gr.Tabs(elem_id="tabs_extensions") as tabs: with gr.TabItem("Installed"): - with gr.Row(): + with gr.Row(elem_id="extensions_installed_top"): apply = gr.Button(value="Apply and restart UI", variant="primary") check = gr.Button(value="Check for updates") extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) + info = gr.HTML() extensions_table = gr.HTML(lambda: extension_table()) apply.click( @@ -289,10 +297,10 @@ def create_ui(): ) check.click( - fn=check_updates, + fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]), _js="extensions_check", - inputs=[], - outputs=[extensions_table], + inputs=[info, extensions_disabled_list], + outputs=[extensions_table, info], ) with gr.TabItem("Available"): -- cgit v1.2.1 From 5d14f282c2812888275902be4b552681f942dbfd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 16:23:49 +0300 Subject: fixed a bug where after switching to a checkpoint with unknown hash, you'd get empty space instead of checkpoint name in UI fixed a bug where if you update a selected checkpoint on disk and then restart the program, a different checkpoint loads, but the name is shown for the the old one. --- modules/sd_models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index b2d48a51..c45ddf83 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -231,12 +231,10 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): - title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - if checkpoint_info.title != title: - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) -- cgit v1.2.1 From 1421e959600e0e9a2435e48373a551237bbab814 Mon Sep 17 00:00:00 2001 From: Thurion Date: Sat, 28 Jan 2023 14:42:24 +0100 Subject: allow empty mask dir --- modules/img2img.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index fe9447c7..3ecb6146 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -21,8 +21,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): images = shared.listfiles(input_dir) - inpaint_masks = shared.listfiles(inpaint_mask_dir) - is_inpaint_batch = inpaint_mask_dir and len(inpaint_masks) > 0 + is_inpaint_batch = False + if inpaint_mask_dir: + inpaint_masks = shared.listfiles(inpaint_mask_dir) + is_inpaint_batch = len(inpaint_masks) > 0 if is_inpaint_batch: print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") -- cgit v1.2.1 From b7d2af8c7fa48d6eef7517a6fbc63a3507c638d4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 17:18:47 +0300 Subject: add dropdowns in settings for hypernets and loras --- modules/extra_networks_hypernet.py | 8 +++++++- modules/shared.py | 5 +++-- 2 files changed, 10 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index ff279a1f..d3a4d7ad 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -1,4 +1,4 @@ -from modules import extra_networks +from modules import extra_networks, shared, extra_networks from modules.hypernetworks import hypernetwork @@ -7,6 +7,12 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): super().__init__('hypernet') def activate(self, p, params_list): + additional = shared.opts.sd_hypernetwork + + if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: + p.all_prompts = [x + f"" for x in p.all_prompts] + params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) + names = [] multipliers = [] for params in params_list: diff --git a/modules/shared.py b/modules/shared.py index 474fcc42..eb04e811 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -405,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), - "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), })) @@ -431,7 +430,9 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), })) options_templates.update(options_section(('extra_networks', "Extra Networks"), { - "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }), + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), })) options_templates.update(options_section(('ui', "User interface"), { -- cgit v1.2.1 From 591b68e56c53eed391d08ce008423191c573784d Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 28 Jan 2023 10:04:09 -0500 Subject: uses autos new regex, checks len of re_param --- modules/generation_parameters_copypaste.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 13d0874d..53f1a865 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -11,7 +11,7 @@ from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image -re_param_code = r'\s*([\w ]+):\s*(\"[^\"]*\"|[^,]+)' +re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") @@ -242,7 +242,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model done_with_prompt = False *lines, lastline = x.strip().split("\n") - if not re_param.match(lastline): + if len(re_param.findall(lastline)) < 3: lines.append(lastline) lastline = '' -- cgit v1.2.1 From e2c71a4bd41470b9503021db36be2ae65f345d97 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 18:12:53 +0300 Subject: make prevent the browser from using cached version of scripts when they change --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 9f4cfda1..4e082408 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1692,14 +1692,14 @@ def create_ui(): def reload_javascript(): - head = f'\n' + head = f'\n' inline = f"{localization.localization_js(shared.opts.localization)};" if cmd_opts.theme is not None: inline += f"set_theme('{cmd_opts.theme}');" for script in modules.scripts.list_scripts("javascript", ".js"): - head += f'\n' + head += f'\n' head += f'\n' -- cgit v1.2.1 From 1d8e06d542176beade37d2d36cb57460c3c6772f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 22:52:27 +0300 Subject: add checkpoints tab for extra networks UI --- modules/ui.py | 8 ++++++ modules/ui_extra_networks.py | 37 ++++++++++++++++++++++--- modules/ui_extra_networks_checkpoints.py | 38 ++++++++++++++++++++++++++ modules/ui_extra_networks_hypernets.py | 2 +- modules/ui_extra_networks_textual_inversion.py | 2 +- 5 files changed, 81 insertions(+), 6 deletions(-) create mode 100644 modules/ui_extra_networks_checkpoints.py (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 4e082408..f1195692 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1560,6 +1560,14 @@ def create_ui(): outputs=[component, text_settings], ) + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) + button_set_checkpoint.click( + fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[component_dict['sd_model_checkpoint'], dummy_component], + outputs=[component_dict['sd_model_checkpoint'], text_settings], + ) + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] def get_settings_values(): diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index c6ff889a..5730c879 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,4 +1,6 @@ import os.path +import urllib.parse +from pathlib import Path from modules import shared import gradio as gr @@ -8,12 +10,31 @@ import html from modules.generation_parameters_copypaste import image_from_url_text extra_pages = [] +allowed_dirs = set() def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" extra_pages.append(page) + allowed_dirs.clear() + allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], []))) + + +def add_pages_to_demo(app): + def fetch_file(filename: str = ""): + from starlette.responses import FileResponse + + if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]): + raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + + if os.path.splitext(filename)[1].lower() != ".png": + raise ValueError(f"File cannot be fetched: {filename}. Only png.") + + # would profit from returning 304 + return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) + + app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) class ExtraNetworksPage: @@ -26,6 +47,9 @@ class ExtraNetworksPage: def refresh(self): pass + def link_preview(self, filename): + return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename)) + def create_html(self, tabname): view = shared.opts.extra_networks_default_view items_html = '' @@ -54,13 +78,17 @@ class ExtraNetworksPage: def create_html_for_item(self, item, tabname): preview = item.get("preview", None) + onclick = item.get("onclick", None) + if onclick is None: + onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + args = { "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', - "prompt": item["prompt"], + "prompt": item.get("prompt", None), "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], - "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"', + "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', } @@ -143,7 +171,7 @@ def path_is_parent(parent_path, child_path): parent_path = os.path.abspath(parent_path) child_path = os.path.abspath(child_path) - return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path]) + return child_path.startswith(parent_path) def setup_ui(ui, gallery): @@ -173,7 +201,8 @@ def setup_ui(ui, gallery): ui.button_save_preview.click( fn=save_preview, - _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", + _js="function(x, y, z){return [selected_gallery_index(), y, z]}", inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], outputs=[*ui.pages] ) + diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py new file mode 100644 index 00000000..c66cb830 --- /dev/null +++ b/modules/ui_extra_networks_checkpoints.py @@ -0,0 +1,38 @@ +import html +import json +import os +import urllib.parse + +from modules import shared, ui_extra_networks, sd_models + + +class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Checkpoints') + + def refresh(self): + shared.refresh_checkpoints() + + def list_items(self): + for name, checkpoint1 in sd_models.checkpoints_list.items(): + checkpoint: sd_models.CheckpointInfo = checkpoint1 + path, ext = os.path.splitext(checkpoint.filename) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = self.link_preview(file) + break + + yield { + "name": checkpoint.model_name, + "filename": path, + "preview": preview, + "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.ckpt_dir, sd_models.model_path] + diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 65d000cf..8c15f8eb 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -19,7 +19,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): preview = None for file in previews: if os.path.isfile(file): - preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + preview = self.link_preview(file) break yield { diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index dbd23d2d..a9d3064b 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -19,7 +19,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): preview = None if os.path.isfile(preview_file): - preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file)) + preview = self.link_preview(preview_file) yield { "name": embedding.name, -- cgit v1.2.1 From 09a142a05a6da8bdd4f36678a098c2a573db181a Mon Sep 17 00:00:00 2001 From: glop102 Date: Sat, 28 Jan 2023 19:25:52 -0500 Subject: Reduce grid rows if larger than number of images available When a set number of grid rows is specified in settings, then it leads to situations where an entire row in the grid is empty. The most noticable example is the processing preview when the row count is set to 2, where it shows the preview just fine but with a black rectangle under it. --- modules/images.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 0bc3d524..ae3cdaf4 100644 --- a/modules/images.py +++ b/modules/images.py @@ -36,6 +36,8 @@ def image_grid(imgs, batch_size=1, rows=None): else: rows = math.sqrt(len(imgs)) rows = round(rows) + if rows > len(imgs): + rows = len(imgs) cols = math.ceil(len(imgs) / rows) -- cgit v1.2.1 From f6b7768f84a335d351ba8c0d4c34d78e59272339 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 29 Jan 2023 10:20:19 +0300 Subject: support for searching subdirectory names for extra networks --- modules/sd_models.py | 1 + modules/ui_extra_networks.py | 11 +++++++++++ modules/ui_extra_networks_checkpoints.py | 6 +++--- modules/ui_extra_networks_hypernets.py | 1 + modules/ui_extra_networks_textual_inversion.py | 1 + 5 files changed, 17 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index c45ddf83..300387a9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -41,6 +41,7 @@ class CheckpointInfo: name = name[1:] self.name = name + self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 5730c879..29c6e196 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -50,6 +50,16 @@ class ExtraNetworksPage: def link_preview(self, filename): return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename)) + def search_terms_from_path(self, filename, possible_directories=None): + abspath = os.path.abspath(filename) + + for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): + parentdir = os.path.abspath(parentdir) + if abspath.startswith(parentdir): + return abspath[len(parentdir):].replace('\\','/') + + return "" + def create_html(self, tabname): view = shared.opts.extra_networks_default_view items_html = '' @@ -90,6 +100,7 @@ class ExtraNetworksPage: "name": item["name"], "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', + "search_term": item.get("search_term", ""), } return self.card_page.format(**args) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index c66cb830..360579b0 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -14,8 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): shared.refresh_checkpoints() def list_items(self): - for name, checkpoint1 in sd_models.checkpoints_list.items(): - checkpoint: sd_models.CheckpointInfo = checkpoint1 + for name, checkpoint in sd_models.checkpoints_list.items(): path, ext = os.path.splitext(checkpoint.filename) previews = [path + ".png", path + ".preview.png"] @@ -26,9 +25,10 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): break yield { - "name": checkpoint.model_name, + "name": checkpoint.name_for_extra, "filename": path, "preview": preview, + "search_term": self.search_terms_from_path(checkpoint.filename), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": path + ".png", } diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 8c15f8eb..57851088 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -26,6 +26,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, + "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f""), "local_preview": path + ".png", } diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index a9d3064b..bb64eb81 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -25,6 +25,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "name": embedding.name, "filename": embedding.filename, "preview": preview, + "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", } -- cgit v1.2.1 From 659d602dce42608a664642021ea2441da045d189 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sun, 29 Jan 2023 02:32:53 -0500 Subject: only returns ckpt directories if they are not none --- modules/ui_extra_networks_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index c66cb830..5b471671 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -34,5 +34,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): } def allowed_directories_for_previews(self): - return [shared.cmd_opts.ckpt_dir, sd_models.model_path] + return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] -- cgit v1.2.1 From 8d7382ab24756cdcc37e71406832814f4713c55e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 29 Jan 2023 11:34:58 +0300 Subject: add buttons for auto-search in subdirectories for extra tabs --- modules/ui_extra_networks.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 29c6e196..83367968 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,3 +1,4 @@ +import glob import os.path import urllib.parse from pathlib import Path @@ -56,7 +57,7 @@ class ExtraNetworksPage: for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): parentdir = os.path.abspath(parentdir) if abspath.startswith(parentdir): - return abspath[len(parentdir):].replace('\\','/') + return abspath[len(parentdir):].replace('\\', '/') return "" @@ -64,6 +65,27 @@ class ExtraNetworksPage: view = shared.opts.extra_networks_default_view items_html = '' + subdirs = {} + for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: + for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True): + if not os.path.isdir(x): + continue + + subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") + while subdir.startswith("/"): + subdir = subdir[1:] + + subdirs[subdir] = 1 + + if subdirs: + subdirs = {"": 1, **subdirs} + + subdirs_html = "".join([f""" + +""" for subdir in subdirs]) + for item in self.list_items(): items_html += self.create_html_for_item(item, tabname) @@ -72,6 +94,9 @@ class ExtraNetworksPage: items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) res = f""" +
+{subdirs_html} +
{items_html}
-- cgit v1.2.1 From aa6e55e00140da6d73d3d360a5628c1b1316550d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 29 Jan 2023 11:53:05 +0300 Subject: do not display the message for TI unless the list of loaded embeddings changed --- modules/textual_inversion/textual_inversion.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6cf00e65..a1a406c2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -112,6 +112,7 @@ class EmbeddingDatabase: self.skipped_embeddings = {} self.expected_shape = -1 self.embedding_dirs = {} + self.previously_displayed_embeddings = () def add_embedding_dir(self, path): self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) @@ -228,9 +229,12 @@ class EmbeddingDatabase: self.load_from_dir(embdir) embdir.update() - print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") - if len(self.skipped_embeddings) > 0: - print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") + displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) + if self.previously_displayed_embeddings != displayed_embeddings: + self.previously_displayed_embeddings = displayed_embeddings + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") + if len(self.skipped_embeddings) > 0: + print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") def find_embedding_at_position(self, tokens, offset): token = tokens[offset] -- cgit v1.2.1 From 00dab8f10defbbda579a1bc89c8d4e972c58a20d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 29 Jan 2023 11:53:24 +0300 Subject: remove Batch size and Batch pos from textinfo (goodbye) --- modules/processing.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index afab6790..2d295932 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -450,8 +450,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Batch size": (None if p.batch_size < 2 else p.batch_size), - "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), -- cgit v1.2.1