aboutsummaryrefslogtreecommitdiff
path: root/modules/npu_specific.py
blob: 9410069110f3cfb9aac3113ceaa141d5c7408e13 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import importlib
import torch

from modules import shared


def check_for_npu():
    """Return whether torch_npu is installed and reports a usable NPU.

    Returns:
        False if the ``torch_npu`` package cannot be located, or if querying
        the NPU device count raises ``RuntimeError``; otherwise the result of
        ``torch.npu.is_available()``.
    """
    # ``import importlib`` on its own does not guarantee that the
    # ``importlib.util`` submodule is loaded, so ``importlib.util`` may be an
    # AttributeError depending on what else was imported first. Import the
    # submodule explicitly so ``find_spec`` is always reachable.
    import importlib.util

    if importlib.util.find_spec("torch_npu") is None:
        return False
    # Importing torch_npu is what makes ``torch.npu`` usable below
    # (presumably it registers the namespace on torch — TODO confirm).
    import torch_npu

    try:
        # Will raise a RuntimeError if no NPU is found
        _ = torch_npu.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False


def get_npu_device_string():
    """Build the torch device string for the selected NPU.

    Returns:
        ``"npu:<device_id>"`` when ``shared.cmd_opts.device_id`` is set,
        otherwise ``"npu:0"``.
    """
    device_id = shared.cmd_opts.device_id
    index = 0 if device_id is None else device_id
    return f"npu:{index}"


def torch_npu_gc():
    """Call ``torch.npu.empty_cache()`` with the configured NPU made current.

    The device is chosen by ``get_npu_device_string()`` and activated through
    a ``torch.npu.device`` context for the duration of the call.
    """
    target_device = get_npu_device_string()
    with torch.npu.device(target_device):
        torch.npu.empty_cache()


# Evaluated once when this module is imported; other code can read this flag
# instead of re-probing torch_npu.
has_npu: bool = check_for_npu()