aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew McGoogan <mlmcgoogan@gmail.com>2022-11-26 23:25:16 +0000
committerMatthew McGoogan <mlmcgoogan@gmail.com>2022-11-26 23:25:16 +0000
commitc67c40f983997594f76b2312f92c3761e8d83715 (patch)
tree8f040a774bbde068b374d71a6fc1467be0195d14
parentb5050ad2071644f7b4c99660dc66a8a95136102f (diff)
torch.cuda.empty_cache() defaults to cuda:0 device unless explicitly set otherwise first. Updating torch_gc() to use the device set by --device-id if specified to avoid OOM edge cases on multi-GPU systems.
-rw-r--r--modules/devices.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 67165bf6..93d82bbc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -44,8 +44,18 @@ def get_optimal_device():
def torch_gc():
if torch.cuda.is_available():
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
+ from modules import shared
+
+ device_id = shared.cmd_opts.device_id
+
+ if device_id is not None:
+ cuda_device = f"cuda:{device_id}"
+ else:
+ cuda_device = "cuda"
+
+ with torch.cuda.device(cuda_device):
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
def enable_tf32():