aboutsummaryrefslogtreecommitdiff
path: root/modules/models/diffusion/uni_pc/sampler.py
diff options
context:
space:
mode:
authorspace-nuko <24979496+space-nuko@users.noreply.github.com>2023-02-10 04:47:08 -0800
committerspace-nuko <24979496+space-nuko@users.noreply.github.com>2023-02-10 04:47:08 -0800
commit21880eb9e57b884635a07d2360831b4186afddf4 (patch)
treeecc0969bb4e36b1addb157464b6dae86faefe583 /modules/models/diffusion/uni_pc/sampler.py
parent125319988984987801dc4b4ab1e5ed36e9b211c5 (diff)
Fix logspam and live previews
Diffstat (limited to 'modules/models/diffusion/uni_pc/sampler.py')
-rw-r--r--modules/models/diffusion/uni_pc/sampler.py20
1 files changed, 15 insertions, 5 deletions
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
index 7cccd8a2..219e9862 100644
--- a/modules/models/diffusion/uni_pc/sampler.py
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -19,9 +19,10 @@ class UniPCSampler(object):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
- def set_hooks(self, before, after):
- self.before_sample = before
- self.after_sample = after
+ def set_hooks(self, before_sample, after_sample, after_update):
+ self.before_sample = before_sample
+ self.after_sample = after_sample
+ self.after_update = after_update
@torch.no_grad()
def sample(self,
@@ -50,9 +51,17 @@ class UniPCSampler(object):
):
if conditioning is not None:
if isinstance(conditioning, dict):
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
@@ -60,6 +69,7 @@ class UniPCSampler(object):
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
+ print(f'Data shape for UniPC sampling is {size}, eta {eta}')
device = self.model.betas.device
if x_T is None:
@@ -79,7 +89,7 @@ class UniPCSampler(object):
guidance_scale=unconditional_guidance_scale,
)
- uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample)
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
return x.to(device), None