aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-25 20:11:01 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-01-25 20:11:01 +0300
commit15e89ef0f6f22f823c19592a401b9e4ee477258c (patch)
tree861ad985646cfca1f0bc5ffce68fe5fa12fea509
parent789d47f832a5c921dbbdd0a657dff9bca7f78d94 (diff)
fix for unet hijack breaking the train tab
-rw-r--r--modules/sd_hijack_unet.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 88c94e54..a6ee577c 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -36,8 +36,11 @@ th = TorchHijackForUnet()
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
- for y in cond.keys():
- cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
+ if isinstance(cond, dict):
+ for y in cond.keys():
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()