aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/bsrgan_model.py6
-rw-r--r--modules/devices.py2
-rw-r--r--modules/shared.py6
3 files changed, 7 insertions, 7 deletions
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
index e62c6657..3bd80791 100644
--- a/modules/bsrgan_model.py
+++ b/modules/bsrgan_model.py
@@ -8,7 +8,7 @@ import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import shared, modelloader
+from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path
@@ -44,13 +44,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
model = self.load_model(selected_file)
if model is None:
return img
- model.to(shared.device)
+ model.to(devices.device_bsrgan)
torch.cuda.empty_cache()
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(devices.device_bsrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/devices.py b/modules/devices.py
index b5a0cd29..b7899632 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -32,7 +32,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_gfpgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16
def randn(seed, shape):
diff --git a/modules/shared.py b/modules/shared.py
index 7899ab8d..95b98a06 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -46,7 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
-parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU for specified modules", default=[])
+parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU for specified modules", default=[])
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -65,8 +65,8 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print
cmd_opts = parser.parse_args()
-devices.device, devices.device_gfpgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
+devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
device = devices.device