aboutsummaryrefslogtreecommitdiff
path: root/modules/mac_specific.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-13 15:21:39 +0300
committerGitHub <noreply@github.com>2023-07-13 15:21:39 +0300
commitb7c5b30f14aadffadd2d35cb3ecb3e91af00581d (patch)
tree0e51f517bb6ac010c0e3dc5937d112656ec9ee9a /modules/mac_specific.py
parent14501f56aaf3c97fb2c38633350dc747b9651f43 (diff)
parent262ec8ecdaf10d8fe49d0227e24bd3a1459e87b5 (diff)
Merge branch 'dev' into master
Diffstat (limited to 'modules/mac_specific.py')
-rw-r--r--modules/mac_specific.py39
1 files changed, 31 insertions, 8 deletions
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index d74c6b95..9ceb43ba 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,20 +1,43 @@
+import logging
+
import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
+log = logging.getLogger(__name__)
+
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
+# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
+# use check `getattr` and try it for compatibility.
+# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
+# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool:
- if not getattr(torch, 'has_mps', False):
- return False
+ if version.parse(torch.__version__) <= version.parse("2.0.1"):
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+ else:
+ return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
+has_mps = check_for_mps()
+
+
+def torch_mps_gc() -> None:
try:
- torch.zeros(1).to(torch.device("mps"))
- return True
+ from modules.shared import state
+ if state.current_latent is not None:
+ log.debug("`current_latent` is set, skipping MPS garbage collection")
+ return
+ from torch.mps import empty_cache
+ empty_cache()
except Exception:
- return False
-has_mps = check_for_mps()
+ log.warning("MPS garbage collection failed", exc_info=True)
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784