aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorC43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com>2022-09-15 13:31:31 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2022-09-15 14:14:27 +0300
commit7ec6282ec2540cfd1c4cf3e2ec89788b7296f4af (patch)
treed3c9d3677744fd975580eb1b04648c82b173dd41
parentb28cf84c3632df4a6d4c110f7c25d68445b64427 (diff)
pass dtype to torch.zeros as well
-rw-r--r--modules/sd_hijack.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index ec7d14cb..c05ba3b1 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -57,7 +57,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device. dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']