aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 620ed1a6..c5ad950f 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -49,10 +49,13 @@ def get_device_for(task):
def torch_gc():
+
if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
+ elif has_mps() and hasattr(torch.mps, 'empty_cache'):
+ torch.mps.empty_cache()
def enable_tf32():