aboutsummaryrefslogtreecommitdiff
path: root/modules/esrgan_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r--modules/esrgan_model.py17
1 files changed, 7 insertions, 10 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index ea91abfe..46ad0da3 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -5,10 +5,8 @@ import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
-import modules.esrgam_model_arch as arch
-from modules import shared, modelloader, images
-from modules.devices import has_mps
-from modules.paths import models_path
+import modules.esrgan_model_arch as arch
+from modules import shared, modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
@@ -73,11 +71,10 @@ def fix_model_layers(crt_model, pretrained_net):
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
self.name = "ESRGAN"
- self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
- self.model_name = "ESRGAN 4x"
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
+ self.model_name = "ESRGAN_4x"
self.scalers = []
self.user_path = dirname
- self.model_path = os.path.join(models_path, self.name)
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
@@ -97,7 +94,7 @@ class UpscalerESRGAN(Upscaler):
model = self.load_model(selected_model)
if model is None:
return img
- model.to(shared.device)
+ model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
@@ -112,7 +109,7 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
+ pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net)
@@ -127,7 +124,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()