From a97147bc8a43ade7c18bb755f0cfac111fc1a619 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:10:02 +0100 Subject: Add support for DAT upscaler models --- modules/dat_model.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 modules/dat_model.py (limited to 'modules/dat_model.py') diff --git a/modules/dat_model.py b/modules/dat_model.py new file mode 100644 index 00000000..8637351c --- /dev/null +++ b/modules/dat_model.py @@ -0,0 +1,81 @@ +import os +import sys + +from modules import modelloader, devices +from modules.shared import cmd_opts, opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model + +from icecream import ic + +class UpscalerDAT(Upscaler): + def __init__(self, user_path): + self.name = "DAT" + self.user_path = user_path + self.scalers = [] + super().__init__() + + for file in self.find_models(ext_filter=[".pt", ".pth"]): + name = modelloader.friendly_name(file) + scaler_data = UpscalerData(name, file, upscaler=self, scale=None) + self.scalers.append(scaler_data) + + for model in get_dat_models(self): + if model.name in opts.dat_enabled_models: + self.scalers.append(model) + + def do_upscale(self, img, selected_model): + try: + info = self.load_model(selected_model) + except Exception as e: + errors.report(f"Unable to load DAT model {path}", exc_info=True) + return img + + model_descriptor = modelloader.load_spandrel_model( + info.local_data_path, + device=self.device, + prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + expected_architecture="DAT", + ) + return upscale_with_model( + model_descriptor, + img, + tile_size=opts.DAT_tile, + tile_overlap=opts.DAT_tile_overlap, + ) + + def load_model(self, path): + for scaler in self.scalers: + if scaler.data_path == path: + if scaler.local_data_path.startswith("http"): + scaler.local_data_path = modelloader.load_file_from_url( + scaler.data_path, + model_dir=self.model_download_path, + ) + if not os.path.exists(scaler.local_data_path): + raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}") + return scaler + raise ValueError(f"Unable to find model info: {path}") + + +def get_dat_models(scaler): + return [ + UpscalerData( + name="DAT x2", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth", + scale=2, + upscaler=scaler, + ), + UpscalerData( + name="DAT x3", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth", + scale=3, + upscaler=scaler, + ), + UpscalerData( + name="DAT x4", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth", + scale=4, + upscaler=scaler, + ), + ] -- cgit v1.2.1 From 2e7efe47b6c78a59f9f3d64c1d30f54751a31a56 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:39:14 +0100 Subject: Minor cleanup --- modules/dat_model.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) (limited to 'modules/dat_model.py') diff --git a/modules/dat_model.py b/modules/dat_model.py index 8637351c..495d5f49 100644 --- a/modules/dat_model.py +++ b/modules/dat_model.py @@ -1,12 +1,10 @@ import os -import sys -from modules import modelloader, devices +from modules import modelloader, errors from modules.shared import cmd_opts, opts from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model -from icecream import ic class UpscalerDAT(Upscaler): def __init__(self, user_path): @@ -24,13 +22,13 @@ class UpscalerDAT(Upscaler): if model.name in opts.dat_enabled_models: self.scalers.append(model) - def do_upscale(self, img, selected_model): + def do_upscale(self, img, path): try: - info = self.load_model(selected_model) - except Exception as e: + info = self.load_model(path) + except Exception: errors.report(f"Unable to load DAT model {path}", exc_info=True) return img - + model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, -- cgit v1.2.1