aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 7888d864..1bb25adf 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -89,11 +89,12 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
restart_steps, restart_times, restart_max = restart_list[i + 1]
min_idx = i + 1
max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
- sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx], sigmas[max_idx], device=sigmas.device)[:-1] # remove the zero at the end
- for times in range(restart_times):
- x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
- for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]):
- x = heun_step(x, old_sigma, new_sigma)
+ if max_idx < min_idx:
+ sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx], sigmas[max_idx], device=sigmas.device)[:-1] # remove the zero at the end
+ for times in range(restart_times):
+ x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
+ for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]):
+ x = heun_step(x, old_sigma, new_sigma)
return x
samplers_data_k_diffusion = [