aboutsummaryrefslogtreecommitdiff
path: root/modules/modelloader.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r--modules/modelloader.py24
1 files changed, 23 insertions, 1 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index e647f6fa..e9aa514e 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -10,7 +10,7 @@ from modules.upscaler import Upscaler
from modules.paths import script_path, models_path
-def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
+def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@@ -45,6 +45,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
full_path = file
if os.path.isdir(full_path):
continue
+ if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
+ continue
if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file)
if extension not in ext_filter:
@@ -123,6 +125,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass
+builtin_upscaler_classes = []
+forbidden_upscaler_classes = set()
+
+
+def list_builtin_upscalers():
+ load_upscalers()
+
+ builtin_upscaler_classes.clear()
+ builtin_upscaler_classes.extend(Upscaler.__subclasses__())
+
+
+def forbid_loaded_nonbuiltin_upscalers():
+ for cls in Upscaler.__subclasses__():
+ if cls not in builtin_upscaler_classes:
+ forbidden_upscaler_classes.add(cls)
+
+
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -139,6 +158,9 @@ def load_upscalers():
datas = []
commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
+ if cls in forbidden_upscaler_classes:
+ continue
+
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))