aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-07-10 21:18:34 +0300
committerAarni Koskela <akx@iki.fi>2023-07-11 12:51:05 +0300
commitb85fc7187d953828340d4e3af34af46d9fc70b9e (patch)
tree5a706bd757e03227c3cd1ae1c5a026eae65107ab
parent7b833291b3ef4690ef158ee3415c2e93948acb2d (diff)
Fix MPS cache cleanup
Importing torch does not import torch.mps so the call failed.
-rw-r--r--modules/devices.py5
-rw-r--r--modules/mac_specific.py14
2 files changed, 17 insertions, 2 deletions
diff --git a/modules/devices.py b/modules/devices.py
index c5ad950f..57e51da3 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -54,8 +54,9 @@ def torch_gc():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
- elif has_mps() and hasattr(torch.mps, 'empty_cache'):
- torch.mps.empty_cache()
+
+ if has_mps():
+ mac_specific.torch_mps_gc()
def enable_tf32():
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 735847f5..2c2f15ca 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,8 +1,12 @@
+import logging
+
import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
+log = logging.getLogger()
+
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
@@ -19,9 +23,19 @@ def check_for_mps() -> bool:
return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
has_mps = check_for_mps()
+def torch_mps_gc() -> None:
+ try:
+ from torch.mps import empty_cache
+ empty_cache()
+ except Exception:
+ log.warning("MPS garbage collection failed", exc_info=True)
+
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':