aboutsummaryrefslogtreecommitdiff
path: root/modules/scunet_model.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-10-04 15:20:10 +0300
committerGitHub <noreply@github.com>2022-10-04 15:20:10 +0300
commitbc4d457de82f76f8ab9f2bedf933c06deb5d5ba9 (patch)
tree969e0e595bd36987ae9de9ae302085ef555bba15 /modules/scunet_model.py
parentd5bba20a58f43a9f984bb67b4e17f48661f6b818 (diff)
parente9e2a7ec9ac704f133f586eb34176e388c93c87c (diff)
Merge pull request #1616 from brkirch/cpu-cmdline-opt
Add --use-cpu command line option
Diffstat (limited to 'modules/scunet_model.py')
-rw-r--r--modules/scunet_model.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
index 7987ac14..fb64b740 100644
--- a/modules/scunet_model.py
+++ b/modules/scunet_model.py
@@ -8,7 +8,7 @@ import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import shared, modelloader
+from modules import devices, modelloader
from modules.paths import models_path
from modules.scunet_model_arch import SCUNet as net
@@ -51,12 +51,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
if model is None:
return img
- device = shared.device
+ device = devices.device_scunet
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(device)
img = img.to(device)
with torch.no_grad():
@@ -69,7 +69,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
- device = shared.device
+ device = devices.device_scunet
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)