about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--  javascript/edit-order.js   |  2 +-
-rw-r--r--  modules/devices.py        |  5 +++--
-rw-r--r--  modules/mac_specific.py   | 18 ++++++++++++++++++
3 files changed, 22 insertions(+), 3 deletions(-)
diff --git a/javascript/edit-order.js b/javascript/edit-order.js
index e6e73937..ed4ef9ac 100644
--- a/javascript/edit-order.js
+++ b/javascript/edit-order.js
@@ -6,11 +6,11 @@ function keyupEditOrder(event) {
let target = event.originalTarget || event.composedPath()[0];
if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
if (!event.altKey) return;
- event.preventDefault();
let isLeft = event.key == "ArrowLeft";
let isRight = event.key == "ArrowRight";
if (!isLeft && !isRight) return;
+ event.preventDefault();
let selectionStart = target.selectionStart;
let selectionEnd = target.selectionEnd;
diff --git a/modules/devices.py b/modules/devices.py
index c5ad950f..57e51da3 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -54,8 +54,9 @@ def torch_gc():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
- elif has_mps() and hasattr(torch.mps, 'empty_cache'):
- torch.mps.empty_cache()
+
+ if has_mps():
+ mac_specific.torch_mps_gc()
def enable_tf32():
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 735847f5..9ceb43ba 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,8 +1,12 @@
+import logging
+
import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
+log = logging.getLogger(__name__)
+
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
@@ -19,9 +23,23 @@ def check_for_mps() -> bool:
return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
has_mps = check_for_mps()
+def torch_mps_gc() -> None:
+ try:
+ from modules.shared import state
+ if state.current_latent is not None:
+ log.debug("`current_latent` is set, skipping MPS garbage collection")
+ return
+ from torch.mps import empty_cache
+ empty_cache()
+ except Exception:
+ log.warning("MPS garbage collection failed", exc_info=True)
+
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':