aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/sd_samplers.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 4c123d3b..1a1b8919 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -454,6 +454,9 @@ class KDiffusionSampler:
else:
sigmas = self.model_wrap.get_sigmas(steps)
+ if self.funcname in ['sample_dpm_2_ancestral', 'sample_dpm_2']:
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
@@ -494,6 +497,9 @@ class KDiffusionSampler:
x = x * sigmas[0]
+ if self.funcname in ['sample_dpm_2_ancestral', 'sample_dpm_2']:
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
extra_params_kwargs = self.initialize(p)
if 'sigma_min' in inspect.signature(self.func).parameters:
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()