aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
authorlambertae <dengm@mit.edu>2023-07-25 22:35:43 -0400
committerlambertae <dengm@mit.edu>2023-07-25 22:35:43 -0400
commit8de6d3ff77e841a5fd9d5f1b16bdd22737c8d657 (patch)
tree45caa3ef442b5f90857652b6615169141b650382 /modules/sd_samplers_kdiffusion.py
parentf87389029839a27464a18846815339e81787b882 (diff)
fix progress bar & torchHijack
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index ed60670c..7a2427b5 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -79,19 +79,26 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
for key, value in restart_list.items():
temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value
restart_list = temp_list
- for i in trange(len(sigmas) - 1, disable=disable):
- x = heun_step(x, sigmas[i], sigmas[i+1])
+ step_list = []
+ for i in range(len(sigmas) - 1):
+ step_list.append((sigmas[i], sigmas[i + 1]))
if i + 1 in restart_list:
restart_steps, restart_times, restart_max = restart_list[i + 1]
min_idx = i + 1
max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
if max_idx < min_idx:
- sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] # remove the zero at the end
+ sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
while restart_times > 0:
restart_times -= 1
- x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
- for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]):
- x = heun_step(x, old_sigma, new_sigma)
+ step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
+ last_sigma = None
+ for i in trange(len(step_list), disable=disable):
+ if last_sigma is None:
+ last_sigma = step_list[i][0]
+ elif last_sigma < step_list[i][0]:
+ x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (step_list[i][0] ** 2 - last_sigma ** 2) ** 0.5
+ x = heun_step(x, step_list[i][0], step_list[i][1])
+ last_sigma = step_list[i][1]
return x
samplers_data_k_diffusion = [