aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_timesteps_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_timesteps_impl.py')
-rw-r--r--modules/sd_samplers_timesteps_impl.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index d32e3521..a72daafd 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -16,16 +16,17 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
extra_args = {} if extra_args is None else extra_args
- s_in = x.new_ones([x.shape[0]])
+ s_in = x.new_ones((x.shape[0]))
+ s_x = x.new_ones((x.shape[0], 1, 1, 1))
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
index = len(timesteps) - 1 - i
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
- a_t = alphas[index].item() * s_in
- a_prev = alphas_prev[index].item() * s_in
- sigma_t = sigmas[index].item() * s_in
- sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+ a_t = alphas[index].item() * s_x
+ a_prev = alphas_prev[index].item() * s_x
+ sigma_t = sigmas[index].item() * s_x
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
@@ -47,13 +48,14 @@ def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
+ s_x = x.new_ones((x.shape[0], 1, 1, 1))
old_eps = []
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
- a_t = alphas[index].item() * s_in
- a_prev = alphas_prev[index].item() * s_in
- sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+ a_t = alphas[index].item() * s_x
+ a_prev = alphas_prev[index].item() * s_x
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()