aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index e9e41818..55f0d3a3 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -8,6 +8,7 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -160,7 +161,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
@@ -180,6 +181,11 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ if after_cfg_callback_params.output_altered:
+ denoised = after_cfg_callback_params.x
+
self.step += 1
return denoised