aboutsummaryrefslogtreecommitdiff
path: root/modules/models/diffusion/uni_pc/uni_pc.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/models/diffusion/uni_pc/uni_pc.py')
-rw-r--r--modules/models/diffusion/uni_pc/uni_pc.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
index df63d1bc..eb5f4e76 100644
--- a/modules/models/diffusion/uni_pc/uni_pc.py
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -1,6 +1,7 @@
import torch
import torch.nn.functional as F
import math
+from tqdm.auto import trange
class NoiseScheduleVP:
@@ -719,7 +720,7 @@ class UniPC:
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = (
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
)
if x_t is None:
@@ -750,7 +751,7 @@ class UniPC:
if method == 'multistep':
assert steps >= order, "UniPC order must be < sampling steps"
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
- print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
+ #print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
assert timesteps.shape[0] - 1 == steps
with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0]))
@@ -766,7 +767,7 @@ class UniPC:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
- for step in range(order, steps + 1):
+ for step in trange(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)