aboutsummaryrefslogtreecommitdiff
path: root/modules/codeformer_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/codeformer_model.py')
-rw-r--r--modules/codeformer_model.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 336f007d..8fbdea24 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -47,13 +47,11 @@ def setup_codeformer():
def __init__(self):
self.net = None
self.face_helper = None
- if shared.device.type == 'mps': # CodeFormer currently does not support mps backend
- shared.device_codeformer = torch.device('cpu')
def create_models(self):
if self.net is not None and self.face_helper is not None:
- self.net.to(shared.device_codeformer)
+ self.net.to(devices.device_codeformer)
return self.net, self.face_helper
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
@@ -66,7 +64,7 @@ def setup_codeformer():
self.net = net
self.face_helper = face_helper
- self.net.to(shared.device_codeformer)
+ self.net.to(devices.device_codeformer)
return net, face_helper