path: root/modules/devices.py
Diffstat (limited to 'modules/devices.py')
-rw-r--r--  modules/devices.py  42
1 file changed, 12 insertions, 30 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 57e51da3..c01f0602 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
 from functools import lru_cache
 import torch
-from modules import errors
+from modules import errors, shared
 if sys.platform == "darwin":
     from modules import mac_specific
@@ -17,8 +17,6 @@ def has_mps() -> bool:
 def get_cuda_device_string():
-    from modules import shared
-
     if shared.cmd_opts.device_id is not None:
         return f"cuda:{shared.cmd_opts.device_id}"
@@ -40,8 +38,6 @@ def get_optimal_device():
 def get_device_for(task):
-    from modules import shared
-
     if task in shared.cmd_opts.use_cpu:
         return cpu
@@ -71,14 +67,17 @@ def enable_tf32():
         torch.backends.cudnn.allow_tf32 = True
-
 errors.run(enable_tf32, "Enabling TF32")
-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
 unet_needs_upcast = False
@@ -90,26 +89,10 @@ def cond_cast_float(input):
     return input.float() if unet_needs_upcast else input
-def randn(seed, shape):
-    from modules.shared import opts
-
-    torch.manual_seed(seed)
-    if opts.randn_source == "CPU" or device.type == 'mps':
-        return torch.randn(shape, device=cpu).to(device)
-    return torch.randn(shape, device=device)
-
-
-def randn_without_seed(shape):
-    from modules.shared import opts
-
-    if opts.randn_source == "CPU" or device.type == 'mps':
-        return torch.randn(shape, device=cpu).to(device)
-    return torch.randn(shape, device=device)
+nv_rng = None
 def autocast(disable=False):
-    from modules import shared
-
     if disable:
         return contextlib.nullcontext()
@@ -128,8 +111,6 @@ class NansException(Exception):
 def test_for_nans(x, where):
-    from modules import shared
-
     if shared.cmd_opts.disable_nan_check:
         return
@@ -169,3 +150,4 @@ def first_time_calculation():
     x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
     conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
     conv2d(x)
+
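
Below is a minimal, self-contained sketch (not part of the commit above) of the pattern the diff moves to: one module-level shared/config object consumed by the device helpers, instead of repeated function-local imports, plus type-annotated device globals. The SimpleNamespace stand-in for modules.shared and the __main__ usage are illustrative assumptions, not code from the repository.

# Sketch only: "shared" is a hypothetical stand-in for modules.shared.
from types import SimpleNamespace

import torch

shared = SimpleNamespace(cmd_opts=SimpleNamespace(device_id=None, use_cpu=[]))

# Module-level, annotated device global, mirroring the style introduced in the diff.
cpu: torch.device = torch.device("cpu")


def get_cuda_device_string() -> str:
    # Same idea as the helper in the diff: honor an explicit device id if one was set.
    if shared.cmd_opts.device_id is not None:
        return f"cuda:{shared.cmd_opts.device_id}"
    return "cuda"


if __name__ == "__main__":
    shared.cmd_opts.device_id = 1
    print(get_cuda_device_string())  # prints "cuda:1"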