aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py81
1 files changed, 2 insertions, 79 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 00a00b18..ce59dc53 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
-from modules import errors, rng_philox
+from modules import errors
if sys.platform == "darwin":
from modules import mac_specific
@@ -96,84 +96,6 @@ def cond_cast_float(input):
nv_rng = None
-def randn(seed, shape):
- """Generate a tensor with random numbers from a normal distribution using seed.
-
- Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
-
- from modules.shared import opts
-
- manual_seed(seed)
-
- if opts.randn_source == "NV":
- return torch.asarray(nv_rng.randn(shape), device=device)
-
- if opts.randn_source == "CPU" or device.type == 'mps':
- return torch.randn(shape, device=cpu).to(device)
-
- return torch.randn(shape, device=device)
-
-
-def randn_local(seed, shape):
- """Generate a tensor with random numbers from a normal distribution using seed.
-
- Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
-
- from modules.shared import opts
-
- if opts.randn_source == "NV":
- rng = rng_philox.Generator(seed)
- return torch.asarray(rng.randn(shape), device=device)
-
- local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
- local_generator = torch.Generator(local_device).manual_seed(int(seed))
- return torch.randn(shape, device=local_device, generator=local_generator).to(device)
-
-
-def randn_like(x):
- """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
-
- Use either randn() or manual_seed() to initialize the generator."""
-
- from modules.shared import opts
-
- if opts.randn_source == "NV":
- return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
-
- if opts.randn_source == "CPU" or x.device.type == 'mps':
- return torch.randn_like(x, device=cpu).to(x.device)
-
- return torch.randn_like(x)
-
-
-def randn_without_seed(shape):
- """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
-
- Use either randn() or manual_seed() to initialize the generator."""
-
- from modules.shared import opts
-
- if opts.randn_source == "NV":
- return torch.asarray(nv_rng.randn(shape), device=device)
-
- if opts.randn_source == "CPU" or device.type == 'mps':
- return torch.randn(shape, device=cpu).to(device)
-
- return torch.randn(shape, device=device)
-
-
-def manual_seed(seed):
- """Set up a global random number generator using the specified seed."""
- from modules.shared import opts
-
- if opts.randn_source == "NV":
- global nv_rng
- nv_rng = rng_philox.Generator(seed)
- return
-
- torch.manual_seed(seed)
-
-
def autocast(disable=False):
from modules import shared
@@ -236,3 +158,4 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
+