aboutsummaryrefslogtreecommitdiff
path: root/modules/npu_specific.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/npu_specific.py')
-rw-r--r--modules/npu_specific.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/modules/npu_specific.py b/modules/npu_specific.py
new file mode 100644
index 00000000..94100691
--- /dev/null
+++ b/modules/npu_specific.py
@@ -0,0 +1,31 @@
+import importlib
+import torch
+
+from modules import shared
+
+
+def check_for_npu():
+ if importlib.util.find_spec("torch_npu") is None:
+ return False
+ import torch_npu
+
+ try:
+ # Will raise a RuntimeError if no NPU is found
+ _ = torch_npu.npu.device_count()
+ return torch.npu.is_available()
+ except RuntimeError:
+ return False
+
+
+def get_npu_device_string():
+ if shared.cmd_opts.device_id is not None:
+ return f"npu:{shared.cmd_opts.device_id}"
+ return "npu:0"
+
+
+def torch_npu_gc():
+ with torch.npu.device(get_npu_device_string()):
+ torch.npu.empty_cache()
+
+
+has_npu = check_for_npu()