aboutsummaryrefslogtreecommitdiff
path: root/modules/models/diffusion
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-03-11 11:56:05 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-03-11 11:56:05 +0300
commitf261a4a53c153c630a506bc5282e9955c36b3ef2 (patch)
treecc0de680197875662299ce6644e004134c08df14 /modules/models/diffusion
parenta11ce2b96cc933ebb9e10d46603a89457ddcb9df (diff)
use selected device instead of always cuda for UniPC sampler
Diffstat (limited to 'modules/models/diffusion')
-rw-r--r--modules/models/diffusion/uni_pc/sampler.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
index 6bb3bb21..bf346ff4 100644
--- a/modules/models/diffusion/uni_pc/sampler.py
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -3,7 +3,8 @@
import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
-from modules import shared
+from modules import shared, devices
+
class UniPCSampler(object):
def __init__(self, model, **kwargs):
@@ -16,8 +17,8 @@ class UniPCSampler(object):
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
- if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
+ if attr.device != devices.device:
+ attr = attr.to(devices.device)
setattr(self, name, attr)
def set_hooks(self, before_sample, after_sample, after_update):