aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-12 23:53:26 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-07-12 23:53:26 +0300
commit60397a7800d7e01d9a75e0179e3d2c10aa0002a9 (patch)
tree92f6b8eec31e3811e29b377418ec6a5dcb613320
parentda464a3fb39ecc6ea7b22fe87271194480d8501c (diff)
parente5ca9877781bf2ce45edfb9c46ba532668c50de9 (diff)
Merge branch 'dev' into sdxl
-rw-r--r--javascript/edit-order.js2
-rw-r--r--modules/devices.py5
-rw-r--r--modules/mac_specific.py18
3 files changed, 22 insertions, 3 deletions
diff --git a/javascript/edit-order.js b/javascript/edit-order.js
index e6e73937..ed4ef9ac 100644
--- a/javascript/edit-order.js
+++ b/javascript/edit-order.js
@@ -6,11 +6,11 @@ function keyupEditOrder(event) {
let target = event.originalTarget || event.composedPath()[0];
if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
if (!event.altKey) return;
- event.preventDefault();
let isLeft = event.key == "ArrowLeft";
let isRight = event.key == "ArrowRight";
if (!isLeft && !isRight) return;
+ event.preventDefault();
let selectionStart = target.selectionStart;
let selectionEnd = target.selectionEnd;
diff --git a/modules/devices.py b/modules/devices.py
index c5ad950f..57e51da3 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -54,8 +54,9 @@ def torch_gc():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
- elif has_mps() and hasattr(torch.mps, 'empty_cache'):
- torch.mps.empty_cache()
+
+ if has_mps():
+ mac_specific.torch_mps_gc()
def enable_tf32():
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 735847f5..9ceb43ba 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,8 +1,12 @@
+import logging
+
import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
+log = logging.getLogger(__name__)
+
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
@@ -19,9 +23,23 @@ def check_for_mps() -> bool:
return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
has_mps = check_for_mps()
+def torch_mps_gc() -> None:
+ try:
+ from modules.shared import state
+ if state.current_latent is not None:
+ log.debug("`current_latent` is set, skipping MPS garbage collection")
+ return
+ from torch.mps import empty_cache
+ empty_cache()
+ except Exception:
+ log.warning("MPS garbage collection failed", exc_info=True)
+
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':