aboutsummaryrefslogtreecommitdiff
path: root/modules/models/diffusion
diff options
context:
space:
mode:
Diffstat (limited to 'modules/models/diffusion')
-rw-r--r--modules/models/diffusion/uni_pc/sampler.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
index 6bb3bb21..bf346ff4 100644
--- a/modules/models/diffusion/uni_pc/sampler.py
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -3,7 +3,8 @@
import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
-from modules import shared
+from modules import shared, devices
+
class UniPCSampler(object):
def __init__(self, model, **kwargs):
@@ -16,8 +17,8 @@ class UniPCSampler(object):
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
- if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
+ if attr.device != devices.device:
+ attr = attr.to(devices.device)
setattr(self, name, attr)
def set_hooks(self, before_sample, after_sample, after_update):