aboutsummaryrefslogtreecommitdiff
path: root/modules/modelloader.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r--modules/modelloader.py66
1 files changed, 22 insertions, 44 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index e351d808..2a479bcb 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,10 +1,8 @@
-import glob
import os
import shutil
import importlib
from urllib.parse import urlparse
-from basicsr.utils.download_util import load_file_from_url
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@@ -23,9 +21,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
"""
output = []
- if ext_filter is None:
- ext_filter = []
-
try:
places = []
@@ -40,25 +35,18 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
places.append(model_path)
for place in places:
- if os.path.exists(place):
- for file in glob.iglob(place + '**/**', recursive=True):
- full_path = file
- if os.path.isdir(full_path):
- continue
- if os.path.islink(full_path) and not os.path.exists(full_path):
- print(f"Skipping broken symlink: {full_path}")
- continue
- if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
- continue
- if len(ext_filter) != 0:
- model_name, extension = os.path.splitext(file)
- if extension not in ext_filter:
- continue
- if file not in output:
- output.append(full_path)
+ for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
+ if os.path.islink(full_path) and not os.path.exists(full_path):
+ print(f"Skipping broken symlink: {full_path}")
+ continue
+ if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
+ continue
+ if full_path not in output:
+ output.append(full_path)
if model_url is not None and len(output) == 0:
if download_name is not None:
+ from basicsr.utils.download_util import load_file_from_url
dl = load_file_from_url(model_url, model_path, True, download_name)
output.append(dl)
else:
@@ -119,32 +107,15 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
print(f"Moving {file} from {src_path} to {dest_path}.")
try:
shutil.move(fullpath, dest_path)
- except:
+ except Exception:
pass
if len(os.listdir(src_path)) == 0:
print(f"Removing empty folder: {src_path}")
shutil.rmtree(src_path, True)
- except:
+ except Exception:
pass
-builtin_upscaler_classes = []
-forbidden_upscaler_classes = set()
-
-
-def list_builtin_upscalers():
- load_upscalers()
-
- builtin_upscaler_classes.clear()
- builtin_upscaler_classes.extend(Upscaler.__subclasses__())
-
-
-def forbid_loaded_nonbuiltin_upscalers():
- for cls in Upscaler.__subclasses__():
- if cls not in builtin_upscaler_classes:
- forbidden_upscaler_classes.add(cls)
-
-
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -155,15 +126,22 @@ def load_upscalers():
full_model = f"modules.{model_name}_model"
try:
importlib.import_module(full_model)
- except:
+ except Exception:
pass
datas = []
commandline_options = vars(shared.cmd_opts)
- for cls in Upscaler.__subclasses__():
- if cls in forbidden_upscaler_classes:
- continue
+ # some of upscaler classes will not go away after reloading their modules, and we'll end
+ # up with two copies of those classes. The newest copy will always be the last in the list,
+ # so we go from end to beginning and ignore duplicates
+ used_classes = {}
+ for cls in reversed(Upscaler.__subclasses__()):
+ classname = str(cls)
+ if classname not in used_classes:
+ used_classes[classname] = cls
+
+ for cls in reversed(used_classes.values()):
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))