aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/modules/devices.py b/modules/devices.py
index c5ad950f..57e51da3 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -54,8 +54,9 @@ def torch_gc():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
- elif has_mps() and hasattr(torch.mps, 'empty_cache'):
- torch.mps.empty_cache()
+
+ if has_mps():
+ mac_specific.torch_mps_gc()
def enable_tf32():