aboutsummaryrefslogtreecommitdiff
path: root/modules/esrgan_model.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-19 07:59:39 +0300
committerGitHub <noreply@github.com>2023-07-19 07:59:39 +0300
commit0a334b447ff0c41519bb9e280050736913ad9cf8 (patch)
treee27963f76b7357ff0cb7b2c3fdcb720ab64f0e50 /modules/esrgan_model.py
parent6094310704f4b3853bfa5d05d9c1ace58b2deee7 (diff)
parentc2b975485708791b29d44d79ee1a48d3abd838b7 (diff)
Merge branch 'dev' into allow-no-venv-install
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r--modules/esrgan_model.py23
1 files changed, 10 insertions, 13 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 2fced999..02a1727d 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,15 +1,13 @@
-import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = modelloader.load_file_from_url(
url=self.model_url,
model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
- progress=True,
)
else:
filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"Unable to load {self.model_path} from {filename}")
- return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)