path: root/modules/memmon.py
author    EyeDeck <eyedeck@gmail.com>    2022-09-18 05:20:33 -0400
committer EyeDeck <eyedeck@gmail.com>    2022-09-18 05:20:33 -0400
commit    fabaf4bddb6f5968bffe75e7766f7687813c4d36 (patch)
tree      43f0ff2ff6f29168f5ce48316ba69a220d00fb45    /modules/memmon.py
parent    7e77938230d4fefb6edccdba0b80b61d8416673e (diff)
Add some error handling for VRAM monitor
Diffstat (limited to 'modules/memmon.py')
-rw-r--r--  modules/memmon.py  |  22 +++++++++++++++-------
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/modules/memmon.py b/modules/memmon.py
index f2cac841..9fb9b687 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -22,6 +22,13 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag = threading.Event()
         self.data = defaultdict(int)
 
+        try:
+            torch.cuda.mem_get_info()
+            torch.cuda.memory_stats(self.device)
+        except Exception as e:  # AMD or whatever
+            print(f"Warning: caught exception '{e}', memory monitor disabled")
+            self.disabled = True
+
     def run(self):
         if self.disabled:
             return
@@ -62,13 +69,14 @@
         self.run_flag.set()
 
     def read(self):
-        free, total = torch.cuda.mem_get_info()
-        self.data["total"] = total
-
-        torch_stats = torch.cuda.memory_stats(self.device)
-        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
-        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
-        self.data["system_peak"] = total - self.data["min_free"]
+        if not self.disabled:
+            free, total = torch.cuda.mem_get_info()
+            self.data["total"] = total
+
+            torch_stats = torch.cuda.memory_stats(self.device)
+            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
+            self.data["system_peak"] = total - self.data["min_free"]
 
         return self.data
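
For reference, a minimal standalone sketch of the probe-and-disable pattern this commit introduces (not the actual modules/memmon.py; the class name TinyMemMonitor is made up for illustration): probe torch.cuda.mem_get_info() and torch.cuda.memory_stats() once at construction, and if either raises (e.g. on a torch build without working CUDA memory APIs), set a disabled flag so later read() calls skip the queries instead of crashing.

# Sketch only, under the assumptions stated above.
from collections import defaultdict

import torch


class TinyMemMonitor:  # hypothetical name, for illustration
    def __init__(self, device="cuda"):
        self.device = device
        self.disabled = False
        self.data = defaultdict(int)

        try:
            # One cheap probe at construction time; if either call raises,
            # assume the CUDA memory APIs are unusable on this system.
            torch.cuda.mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def read(self):
        # Guarded like the patched read(): skip the CUDA queries when the
        # probe failed, but still return the (empty) stats dict.
        if not self.disabled:
            free, total = torch.cuda.mem_get_info()
            self.data["total"] = total

            stats = torch.cuda.memory_stats(self.device)
            self.data["active_peak"] = stats["active_bytes.all.peak"]
            self.data["reserved_peak"] = stats["reserved_bytes.all.peak"]

        return self.data


if __name__ == "__main__":
    # On a machine without usable CUDA this prints an empty stats dict
    # instead of raising.
    print(TinyMemMonitor().read())

The trade-off is that a disabled monitor still returns a dict, just with every counter left at zero, so callers do not need their own error handling around read().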