diff options
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/modules/devices.py b/modules/devices.py index a93a245b..e4430e1a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -31,3 +31,20 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") + + +device = get_optimal_device() +device_codeformer = cpu if has_mps else device + + +def randn(seed, shape): + # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. + if device.type == 'mps': + generator = torch.Generator(device=cpu) + generator.manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=cpu).to(device) + return noise + + torch.manual_seed(seed) + return torch.randn(shape, device=device) + |