aboutsummaryrefslogtreecommitdiff
path: root/modules/npu_specific.py
blob: d8aebf9c2141e0adddbfbd904eacf92ba357e61c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import importlib
import importlib.util

import torch

from modules import shared


def check_for_npu() -> bool:
    """Return True when a usable Ascend NPU device is present.

    Detection is two-step: first check whether the ``torch_npu`` package is
    installed at all, then probe the runtime for actual hardware, which can
    raise ``RuntimeError`` even when the package imports fine.
    """
    if importlib.util.find_spec("torch_npu") is None:
        return False
    import torch_npu

    try:
        # All three runtime probes below can raise a RuntimeError when the
        # package is installed but no NPU device exists, so keep them
        # (including set_device) inside the try block — this function is
        # called at module import time and must never crash.
        torch_npu.npu.set_device(0)
        _ = torch.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False


def get_npu_device_string():
    """Build the torch device string for the target NPU.

    Honours the ``--device-id`` command-line option when given; otherwise
    falls back to the first NPU.
    """
    device_id = shared.cmd_opts.device_id
    if device_id is None:
        return "npu:0"
    return f"npu:{device_id}"


def torch_npu_gc():
    """Free cached memory on the configured NPU device.

    NOTE: explicitly resetting to device 0 first works around a bug in
    torch_npu; revert once it is fixed, see
    https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
    """
    torch.npu.set_device(0)
    target = get_npu_device_string()
    with torch.npu.device(target):
        torch.npu.empty_cache()


# Probed once at module import time so callers can cheaply branch on NPU
# support without re-running hardware detection.
has_npu = check_for_npu()