From 7aab389d6fc8ad08729071b1ed9d4de64c4e44db Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 14 Apr 2023 02:22:48 -0400 Subject: Fix for Unet NaNs --- modules/sd_hijack_optimizations.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 372555ff..f10865cd 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + if q.device.type == 'mps': + q, k, v = q.contiguous(), k.contiguous(), v.contiguous() + dtype = q.dtype if shared.opts.upcast_attn: q, k = q.float(), k.float() -- cgit v1.2.1