aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/sd_hijack_unet.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 88c94e54..a6ee577c 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -36,8 +36,11 @@ th = TorchHijackForUnet()
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
- for y in cond.keys():
- cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
+ if isinstance(cond, dict):
+ for y in cond.keys():
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()