aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_timesteps.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-14 08:49:02 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-14 08:49:02 +0300
commitaeb76ef174bc8a1904b25ca0b0b5009395f07d96 (patch)
treef10bee54f34d10b6b0d39c311e2ec252c6f8ec8d /modules/sd_samplers_timesteps.py
parent007ecfbb29771aa7cdcf0263ab1811bc75fa5446 (diff)
repair DDIM/PLMS/UniPC batches
Diffstat (limited to 'modules/sd_samplers_timesteps.py')
-rw-r--r--modules/sd_samplers_timesteps.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index 16572c7e..6aed2974 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -51,10 +51,9 @@ class CFGDenoiserTimesteps(CFGDenoiser):
self.alphas = shared.sd_model.alphas_cumprod
def get_pred_x0(self, x_in, x_out, sigma):
- ts = int(sigma.item())
+ ts = sigma.to(dtype=int)
- s_in = x_in.new_ones([x_in.shape[0]])
- a_t = self.alphas[ts].item() * s_in
+ a_t = self.alphas[ts][:, None, None, None]
sqrt_one_minus_at = (1 - a_t).sqrt()
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()