Diffstat (limited to 'modules/devices.py')
-rw-r--r--  modules/devices.py  114
1 file changed, 109 insertions, 5 deletions
diff --git a/modules/devices.py b/modules/devices.py
index ea1f712f..28c0c54d 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
 from functools import lru_cache

 import torch
-from modules import errors, shared
+from modules import errors, shared, npu_specific

 if sys.platform == "darwin":
     from modules import mac_specific
@@ -23,6 +23,23 @@ def has_mps() -> bool:
         return mac_specific.has_mps


+def cuda_no_autocast(device_id=None) -> bool:
+    if device_id is None:
+        device_id = get_cuda_device_id()
+    return (
+        torch.cuda.get_device_capability(device_id) == (7, 5)
+        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+    )
+
+
+def get_cuda_device_id():
+    return (
+        int(shared.cmd_opts.device_id)
+        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+        else 0
+    ) or torch.cuda.current_device()
+
+
 def get_cuda_device_string():
     if shared.cmd_opts.device_id is not None:
         return f"cuda:{shared.cmd_opts.device_id}"
@@ -40,6 +57,9 @@ def get_optimal_device_name():
     if has_xpu():
         return xpu_specific.get_xpu_device_string()

+    if npu_specific.has_npu:
+        return npu_specific.get_npu_device_string()
+
     return "cpu"
@@ -67,14 +87,23 @@ def torch_gc():
     if has_xpu():
         xpu_specific.torch_xpu_gc()

+    if npu_specific.has_npu:
+        torch_npu_set_device()
+        npu_specific.torch_npu_gc()
+
+
+def torch_npu_set_device():
+    # Workaround for a bug in torch_npu; revert after it is fixed, see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
+    if npu_specific.has_npu:
+        torch.npu.set_device(0)
+

 def enable_tf32():
     if torch.cuda.is_available():

         # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
         # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
-        if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+        if cuda_no_autocast():
             torch.backends.cudnn.benchmark = True

         torch.backends.cuda.matmul.allow_tf32 = True
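torch_gc() now also clears the Ascend NPU cache, and it calls torch_npu_set_device() first because of the torch_npu bug referenced in the comment above: without an explicit set_device(0) the cleanup can land on the wrong device. A rough sketch of that call order, assuming the torch_npu package is installed (it registers a torch.npu namespace mirroring torch.cuda; treating npu_specific.torch_npu_gc() as essentially an empty_cache call is also an assumption here):

import gc
import torch

def npu_gc_sketch():
    # Illustration only: torch.npu exists only when Ascend's torch_npu
    # package is installed; it mirrors the torch.cuda API.
    gc.collect()
    if hasattr(torch, "npu") and torch.npu.is_available():
        torch.npu.set_device(0)  # workaround for the torch_npu issue linked above
        torch.npu.empty_cache()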
@@ -84,6 +113,7 @@ def enable_tf32():
 errors.run(enable_tf32, "Enabling TF32")

 cpu: torch.device = torch.device("cpu")
+fp8: bool = False
 device: torch.device = None
 device_interrogate: torch.device = None
 device_gfpgan: torch.device = None
@@ -92,6 +122,7 @@ device_codeformer: torch.device = None
 dtype: torch.dtype = torch.float16
 dtype_vae: torch.dtype = torch.float16
 dtype_unet: torch.dtype = torch.float16
+dtype_inference: torch.dtype = torch.float16
 unet_needs_upcast = False
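The module-level fp8 flag and the new dtype_inference separate how weights are stored from the precision at which computation should happen: elsewhere in the webui (not shown in this diff) fp8 mode keeps parameters in a torch.float8_* dtype, while activations are still expected in dtype_inference. A tiny illustration of that storage-versus-compute split (float8 dtypes need a reasonably recent PyTorch build; the variable names are invented for the example):

import torch

w8 = torch.randn(4, 4).to(torch.float8_e4m3fn)  # storage dtype: fp8
x = torch.randn(2, 4)                           # activations: float32
y = x @ w8.to(x.dtype)                          # cast up before the matmul
print(w8.dtype, y.dtype)                        # torch.float8_e4m3fn torch.float32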
@@ -104,15 +135,89 @@ def cond_cast_float(input):
 nv_rng = None
+patch_module_list = [
+    torch.nn.Linear,
+    torch.nn.Conv2d,
+    torch.nn.MultiheadAttention,
+    torch.nn.GroupNorm,
+    torch.nn.LayerNorm,
+]
+
+
+def manual_cast_forward(target_dtype):
+    def forward_wrapper(self, *args, **kwargs):
+        if any(
+            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
+            for arg in args
+        ):
+            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+
+        org_dtype = target_dtype
+        for param in self.parameters():
+            if param.dtype != target_dtype:
+                org_dtype = param.dtype
+                break
+
+        if org_dtype != target_dtype:
+            self.to(target_dtype)
+        result = self.org_forward(*args, **kwargs)
+        if org_dtype != target_dtype:
+            self.to(org_dtype)
+
+        if target_dtype != dtype_inference:
+            if isinstance(result, tuple):
+                result = tuple(
+                    i.to(dtype_inference)
+                    if isinstance(i, torch.Tensor)
+                    else i
+                    for i in result
+                )
+            elif isinstance(result, torch.Tensor):
+                result = result.to(dtype_inference)
+        return result
+    return forward_wrapper
+
+
+@contextlib.contextmanager
+def manual_cast(target_dtype):
+    applied = False
+    for module_type in patch_module_list:
+        if hasattr(module_type, "org_forward"):
+            continue
+        applied = True
+        org_forward = module_type.forward
+        if module_type == torch.nn.MultiheadAttention:
+            module_type.forward = manual_cast_forward(torch.float32)
+        else:
+            module_type.forward = manual_cast_forward(target_dtype)
+        module_type.org_forward = org_forward
+    try:
+        yield None
+    finally:
+        if applied:
+            for module_type in patch_module_list:
+                if hasattr(module_type, "org_forward"):
+                    module_type.forward = module_type.org_forward
+                    delattr(module_type, "org_forward")


 def autocast(disable=False):
     if disable:
         return contextlib.nullcontext()

-    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+    if fp8 and device==cpu:
+        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
+    if fp8 and dtype_inference == torch.float32:
+        return manual_cast(dtype)
+
+    if dtype == torch.float32 or dtype_inference == torch.float32:
         return contextlib.nullcontext()

+    if has_xpu() or has_mps() or cuda_no_autocast():
+        return manual_cast(dtype)
+
     return torch.autocast("cuda")
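manual_cast() works by temporarily monkey-patching the forward methods of the layer types in patch_module_list so that tensor arguments (and, if needed, the module's own parameters) are cast to the target dtype, with the original forwards restored when the context exits; autocast() falls back to it when fp8 weights need float32 inference or when the backend (XPU, MPS, or a GTX 16xx card) cannot use torch.autocast. A self-contained sketch of the same patch-and-restore pattern, reduced to torch.nn.Linear only (cast_linear_inputs is an illustrative name, not part of the webui):

import contextlib
import torch

@contextlib.contextmanager
def cast_linear_inputs(target_dtype):
    # Patch Linear.forward so tensor arguments are cast to target_dtype,
    # then restore the original forward when the context exits.
    original_forward = torch.nn.Linear.forward

    def patched_forward(self, *args, **kwargs):
        args = [a.to(target_dtype) if isinstance(a, torch.Tensor) else a for a in args]
        kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return original_forward(self, *args, **kwargs)

    torch.nn.Linear.forward = patched_forward
    try:
        yield
    finally:
        torch.nn.Linear.forward = original_forward

layer = torch.nn.Linear(4, 4).half()
with cast_linear_inputs(torch.float16):
    out = layer(torch.randn(2, 4))  # float32 input is cast to fp16 before the matmul
print(out.dtype)  # torch.float16

The real implementation additionally remembers the parameters' original dtype so a module stored at another precision can be moved to the target dtype for the call and moved back afterwards, casts results back to dtype_inference when the two differ, and always runs torch.nn.MultiheadAttention in float32.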
@@ -164,4 +269,3 @@ def first_time_calculation():
     x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
     conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
     conv2d(x)
-