aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/ScuNET
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-08 15:10:10 +0300
committerGitHub <noreply@github.com>2023-07-08 15:10:10 +0300
commitec9bbda3da846b4aa2f03b1e0f0952ffbc61f4f6 (patch)
tree5ae676b6ac7242ad082ab7421cf8ae9d50d52818 /extensions-builtin/ScuNET
parentc4c63dd5e4760c56405cef2e71abc5c3604c4578 (diff)
parent18256c5f0174126cb103afece2b39b6b831e034a (diff)
Merge branch 'dev' into img2img-batch-png-info
Diffstat (limited to 'extensions-builtin/ScuNET')
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py33
1 files changed, 14 insertions, 19 deletions
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 45d9297b..ffef26b2 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,17 +1,15 @@
-import os.path
import sys
-import traceback
import PIL.Image
import numpy as np
import torch
from tqdm import tqdm
-from basicsr.utils.download_util import load_file_from_url
-
import modules.upscaler
-from modules import devices, modelloader, script_callbacks
-from scunet_model_arch import SCUNet as net
+from modules import devices, modelloader, script_callbacks, errors
+from scunet_model_arch import SCUNet
+
+from modules.modelloader import load_file_from_url
from modules.shared import opts
@@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers = []
add_model2 = True
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -38,8 +36,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
- print(f"Error loading ScuNET model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
@@ -90,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch.cuda.empty_cache()
- model = self.load_model(selected_file)
- if model is None:
- print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
+ try:
+ model = self.load_model(selected_file)
+ except Exception as e:
+ print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img
device = devices.get_device_for('scunet')
@@ -120,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def load_model(self, path: str):
device = devices.get_device_for('scunet')
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
- if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
- print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
- return None
-
- model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
+ model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for _, v in model.named_parameters():