aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/models/diffusion/uni_pc/sampler.py20
-rw-r--r--modules/models/diffusion/uni_pc/uni_pc.py32
-rw-r--r--modules/sd_samplers_compvis.py20
3 files changed, 41 insertions, 31 deletions
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
index 7cccd8a2..219e9862 100644
--- a/modules/models/diffusion/uni_pc/sampler.py
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -19,9 +19,10 @@ class UniPCSampler(object):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
- def set_hooks(self, before, after):
- self.before_sample = before
- self.after_sample = after
+ def set_hooks(self, before_sample, after_sample, after_update):
+ self.before_sample = before_sample
+ self.after_sample = after_sample
+ self.after_update = after_update
@torch.no_grad()
def sample(self,
@@ -50,9 +51,17 @@ class UniPCSampler(object):
):
if conditioning is not None:
if isinstance(conditioning, dict):
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
@@ -60,6 +69,7 @@ class UniPCSampler(object):
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
+ print(f'Data shape for UniPC sampling is {size}, eta {eta}')
device = self.model.betas.device
if x_T is None:
@@ -79,7 +89,7 @@ class UniPCSampler(object):
guidance_scale=unconditional_guidance_scale,
)
- uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample)
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
return x.to(device), None
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
index ec6b37da..31ee81a6 100644
--- a/modules/models/diffusion/uni_pc/uni_pc.py
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -378,7 +378,8 @@ class UniPC:
condition=None,
unconditional_condition=None,
before_sample=None,
- after_sample=None
+ after_sample=None,
+ after_update=None
):
"""Construct a UniPC.
@@ -394,6 +395,7 @@ class UniPC:
self.unconditional_condition = unconditional_condition
self.before_sample = before_sample
self.after_sample = after_sample
+ self.after_update = after_update
def dynamic_thresholding_fn(self, x0, t=None):
"""
@@ -434,15 +436,6 @@ class UniPC:
noise = self.noise_prediction_fn(x, t)
dims = x.dim()
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
- from pprint import pp
- print("X:")
- pp(x)
- print("sigma_t:")
- pp(sigma_t)
- print("noise:")
- pp(noise)
- print("alpha_t:")
- pp(alpha_t)
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
if self.thresholding:
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
@@ -524,7 +517,7 @@ class UniPC:
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
- print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
+ #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
ns = self.noise_schedule
assert order <= len(model_prev_list)
@@ -568,7 +561,7 @@ class UniPC:
A_p = C_inv_p
if use_corrector:
- print('using corrector')
+ #print('using corrector')
C_inv = torch.linalg.inv(C)
A_c = C_inv
@@ -627,7 +620,7 @@ class UniPC:
return x_t, model_t
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
- print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
+ #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
ns = self.noise_schedule
assert order <= len(model_prev_list)
dims = x.dim()
@@ -695,7 +688,7 @@ class UniPC:
D1s = None
if use_corrector:
- print('using corrector')
+ #print('using corrector')
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], device=b.device)
@@ -755,8 +748,9 @@ class UniPC:
t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
if method == 'multistep':
- assert steps >= order
+ assert steps >= order, "UniPC order must be < sampling steps"
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps")
assert timesteps.shape[0] - 1 == steps
with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0]))
@@ -768,6 +762,8 @@ class UniPC:
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
if model_x is None:
model_x = self.model_fn(x, vec_t)
+ if self.after_update is not None:
+ self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
for step in range(order, steps + 1):
@@ -776,13 +772,15 @@ class UniPC:
step_order = min(order, steps + 1 - step)
else:
step_order = order
- print('this step order:', step_order)
+ #print('this step order:', step_order)
if step == steps:
- print('do not run corrector at the last step')
+ #print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
+ if self.after_update is not None:
+ self.after_update(x, model_x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index 86fa1c5b..946079ae 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler:
return x, ts, cond, unconditional_conditioning
- def after_sample(self, x, ts, cond, uncond, res):
- if self.is_unipc:
- # unipc model_fn returns (pred_x0)
- # p_sample_ddim returns (x_prev, pred_x0)
- res = (None, res[0])
-
+ def update_step(self, last_latent):
if self.mask is not None:
- self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
+ self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
else:
- self.last_latent = res[1]
+ self.last_latent = last_latent
sd_samplers_common.store_latent(self.last_latent)
@@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler:
state.sampling_step = self.step
shared.total_tqdm.update()
+ def after_sample(self, x, ts, cond, uncond, res):
+ if not self.is_unipc:
+ self.update_step(res[1])
+
return x, ts, cond, uncond, res
+ def unipc_after_update(self, x, model_x):
+ self.update_step(x)
+
def initialize(self, p):
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
if self.eta != 0.0:
@@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler:
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
if self.is_unipc:
- self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r))
+ self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None