aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorIvan <ivan.demian2009@gmail.com>2023-01-25 22:34:19 +0500
committerGitHub <noreply@github.com>2023-01-25 22:34:19 +0500
commitdc0f05c57cb588b918102c7e64dcfe2b06cc0e90 (patch)
tree192a43b4375b89283016777aef3eb0a28f90cae5 /modules
parent57096823fadbc18b33d9b89d2d3a02d5ebba29f4 (diff)
parent15e89ef0f6f22f823c19592a401b9e4ee477258c (diff)
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_hijack_unet.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 88c94e54..a6ee577c 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -36,8 +36,11 @@ th = TorchHijackForUnet()
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
- for y in cond.keys():
- cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
+ if isinstance(cond, dict):
+ for y in cond.keys():
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()