From 42082e8a3239c1c32cd9e2a03a20b610af857b51 Mon Sep 17 00:00:00 2001 From: devdn Date: Tue, 28 Mar 2023 18:18:28 -0400 Subject: performance increase --- modules/sd_samplers_kdiffusion.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index e9f08518..6a54ce32 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -76,7 +76,7 @@ class CFGDenoiser(torch.nn.Module): return denoised - def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -116,6 +116,12 @@ class CFGDenoiser(torch.nn.Module): tensor = denoiser_params.text_cond uncond = denoiser_params.text_uncond + sigma_thresh = s_min_uncond + if(torch.dot(sigma,sigma) < sigma.shape[0] * (sigma_thresh*sigma_thresh) and not is_edit_model): + uncond = torch.zeros([0,0,uncond.shape[2]]) + x_in=x_in[:x_in.shape[0]//2] + sigma_in=sigma_in[:sigma_in.shape[0]//2] + if tensor.shape[1] == uncond.shape[1]: if not is_edit_model: cond_in = torch.cat([tensor, uncond]) @@ -144,7 +150,8 @@ class CFGDenoiser(torch.nn.Module): x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:])) + if uncond.shape[0]: + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:])) denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) cfg_denoised_callback(denoised_params) @@ -157,7 +164,10 @@ class CFGDenoiser(torch.nn.Module): sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) if not is_edit_model: - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + if uncond.shape[0]: + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + else: + denoised = x_out else: denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) @@ -165,7 +175,6 @@ class CFGDenoiser(torch.nn.Module): denoised = self.init_latent * self.mask + self.nmask * denoised self.step += 1 - return denoised @@ -244,6 +253,7 @@ class KDiffusionSampler: self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else opts.eta_ancestral + self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) @@ -326,6 +336,7 @@ class KDiffusionSampler: 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond } samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) @@ -359,7 +370,8 @@ class KDiffusionSampler: 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond }, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples -- cgit v1.2.1 From 44e8e9c36807d4a71c2fc84129ebcf5ba4f77f21 Mon Sep 17 00:00:00 2001 From: devdn Date: Thu, 30 Mar 2023 00:54:28 -0400 Subject: fix live preview & alternate uncond guidance for better quality --- modules/sd_samplers_kdiffusion.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 6a54ce32..17d24df4 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -116,11 +116,13 @@ class CFGDenoiser(torch.nn.Module): tensor = denoiser_params.text_cond uncond = denoiser_params.text_uncond - sigma_thresh = s_min_uncond - if(torch.dot(sigma,sigma) < sigma.shape[0] * (sigma_thresh*sigma_thresh) and not is_edit_model): - uncond = torch.zeros([0,0,uncond.shape[2]]) - x_in=x_in[:x_in.shape[0]//2] - sigma_in=sigma_in[:sigma_in.shape[0]//2] + if self.step % 2 and s_min_uncond > 0 and not is_edit_model: + # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it + sigma_threshold = s_min_uncond + if(torch.dot(sigma,sigma) < sigma.shape[0] * (sigma_threshold*sigma_threshold) ): + uncond = torch.zeros([0,0,uncond.shape[2]]) + x_in=x_in[:x_in.shape[0]//2] + sigma_in=sigma_in[:sigma_in.shape[0]//2] if tensor.shape[1] == uncond.shape[1]: if not is_edit_model: @@ -159,7 +161,7 @@ class CFGDenoiser(torch.nn.Module): devices.test_for_nans(x_out, "unet") if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) + sd_samplers_common.store_latent(x_out[0:x_out.shape[0]-uncond.shape[0]]) elif opts.live_preview_content == "Negative prompt": sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) -- cgit v1.2.1