aboutsummaryrefslogtreecommitdiff
path: root/embeddings
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2022-12-17 03:21:19 -0500
committerbrkirch <brkirch@users.noreply.github.com>2022-12-17 04:22:58 -0500
commit16b4509fa60ec03102b2452b41799dafccd35970 (patch)
tree37efe0fbd67c70902a5d39f8d5249b9ef3e89ed6 /embeddings
parent685f9631b56ff8bd43bce24ff5ce0f9a0e9af490 (diff)
Add numpy fix for MPS on PyTorch 1.12.1
When saving training results with torch.save(), an exception is thrown: "RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead." So for MPS, check if Tensor.requires_grad and detach() if necessary.
Diffstat (limited to 'embeddings')
0 files changed, 0 insertions, 0 deletions