aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/ScuNET/scripts/scunet_model.py
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-12-25 14:43:51 +0200
committerAarni Koskela <akx@iki.fi>2023-12-30 16:24:01 +0200
commitb0f59342346b1c8b405f97c0e0bb01c6ae05c601 (patch)
tree8f77ec512bf8c3352d03898cf9bf1c26df02c1a0 /extensions-builtin/ScuNET/scripts/scunet_model.py
parente472383acbb9e07dca311abe5fb16ee2675e410a (diff)
Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR)
Diffstat (limited to 'extensions-builtin/ScuNET/scripts/scunet_model.py')
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py13
1 files changed, 2 insertions, 11 deletions
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 167d2f64..18cf8e1a 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -7,9 +7,7 @@ from tqdm import tqdm
import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet
-from modules.modelloader import load_file_from_url
from modules.shared import opts
@@ -120,17 +118,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
device = devices.get_device_for('scunet')
if path.startswith("http"):
# TODO: this doesn't use `path` at all?
- filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
+ filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
- model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
- model.load_state_dict(torch.load(filename), strict=True)
- model.eval()
- for _, v in model.named_parameters():
- v.requires_grad = False
- model = model.to(device)
-
- return model
+ return modelloader.load_spandrel_model(filename, device=device)
def on_ui_settings():