diff options
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 109 |
1 files changed, 81 insertions, 28 deletions
diff --git a/modules/devices.py b/modules/devices.py index 57e51da3..ff279ac5 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,11 +3,19 @@ import contextlib from functools import lru_cache import torch -from modules import errors +from modules import errors, shared +from modules import torch_utils if sys.platform == "darwin": from modules import mac_specific +if shared.cmd_opts.use_ipex: + from modules import xpu_specific + + +def has_xpu() -> bool: + return shared.cmd_opts.use_ipex and xpu_specific.has_xpu + def has_mps() -> bool: if sys.platform != "darwin": @@ -16,9 +24,24 @@ def has_mps() -> bool: return mac_specific.has_mps -def get_cuda_device_string(): - from modules import shared +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + +def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -32,6 +55,9 @@ def get_optimal_device_name(): if has_mps(): return "mps" + if has_xpu(): + return xpu_specific.get_xpu_device_string() + return "cpu" @@ -40,9 +66,7 @@ def get_optimal_device(): def get_device_for(task): - from modules import shared - - if task in shared.cmd_opts.use_cpu: + if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu: return cpu return get_optimal_device() @@ -58,27 +82,34 @@ def torch_gc(): if has_mps(): mac_specific.torch_mps_gc() + if has_xpu(): + xpu_specific.torch_xpu_gc() + def enable_tf32(): if torch.cuda.is_available(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - errors.run(enable_tf32, "Enabling TF32") -cpu = torch.device("cpu") -device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None -dtype = torch.float16 -dtype_vae = torch.float16 -dtype_unet = torch.float16 +cpu: torch.device = torch.device("cpu") +fp8: bool = False +device: torch.device = None +device_interrogate: torch.device = None +device_gfpgan: torch.device = None +device_esrgan: torch.device = None +device_codeformer: torch.device = None +dtype: torch.dtype = torch.float16 +dtype_vae: torch.dtype = torch.float16 +dtype_unet: torch.dtype = torch.float16 unet_needs_upcast = False @@ -90,29 +121,52 @@ def cond_cast_float(input): return input.float() if unet_needs_upcast else input -def randn(seed, shape): - from modules.shared import opts +nv_rng = None +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] - torch.manual_seed(seed) - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) - return torch.randn(shape, device=device) +def manual_cast_forward(self, *args, **kwargs): + org_dtype = torch_utils.get_param(self).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result -def randn_without_seed(shape): - from modules.shared import opts - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) - return torch.randn(shape, device=device) +@contextlib.contextmanager +def manual_cast(): + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward def autocast(disable=False): - from modules import shared - if disable: return contextlib.nullcontext() + if fp8 and device==cpu: + return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_cast() + + if has_mps() and shared.cmd_opts.precision != "full": + return manual_cast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() @@ -128,8 +182,6 @@ class NansException(Exception): def test_for_nans(x, where): - from modules import shared - if shared.cmd_opts.disable_nan_check: return @@ -169,3 +221,4 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) + |