aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/extensions.py132
-rw-r--r--modules/scripts.py169
2 files changed, 148 insertions, 153 deletions
diff --git a/modules/extensions.py b/modules/extensions.py
index f3988d02..1899cd52 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
import configparser
-import functools
import os
import threading
import re
@@ -8,7 +9,6 @@ from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
-extensions = []
os.makedirs(extensions_dir, exist_ok=True)
@@ -22,13 +22,56 @@ def active():
return [x for x in extensions if x.enabled]
+class ExtensionMetadata:
+ filename = "metadata.ini"
+ config: configparser.ConfigParser
+ canonical_name: str
+ requires: list
+
+ def __init__(self, path, canonical_name):
+ self.config = configparser.ConfigParser()
+
+ filepath = os.path.join(path, self.filename)
+ if os.path.isfile(filepath):
+ try:
+ self.config.read(filepath)
+ except Exception:
+ errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
+
+ self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
+ self.canonical_name = canonical_name.lower().strip()
+
+ self.requires = self.get_script_requirements("Requires", "Extension")
+
+ def get_script_requirements(self, field, section, extra_section=None):
+ """reads a list of requirements from the config; field is the name of the field in the ini file,
+ like Requires or Before, and section is the name of the [section] in the ini file; additionally,
+ reads more requirements from [extra_section] if specified."""
+
+ x = self.config.get(section, field, fallback='')
+
+ if extra_section:
+ x = x + ', ' + self.config.get(extra_section, field, fallback='')
+
+ return self.parse_list(x.lower())
+
+ def parse_list(self, text):
+ """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
+
+ if not text:
+ return []
+
+ # both "," and " " are accepted as separator
+ return [x for x in re.split(r"[,\s]+", text.strip()) if x]
+
+
class Extension:
lock = threading.Lock()
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
+ metadata: ExtensionMetadata
- def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None):
+ def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
self.name = name
- self.canonical_name = canonical_name or name.lower()
self.path = path
self.enabled = enabled
self.status = ''
@@ -40,18 +83,8 @@ class Extension:
self.branch = None
self.remote = None
self.have_info_from_repo = False
-
- @functools.cached_property
- def metadata(self):
- if os.path.isfile(os.path.join(self.path, "metadata.ini")):
- try:
- config = configparser.ConfigParser()
- config.read(os.path.join(self.path, "metadata.ini"))
- return config
- except Exception:
- errors.report(f"Error reading metadata.ini for extension {self.canonical_name}.",
- exc_info=True)
- return None
+ self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
+ self.canonical_name = metadata.canonical_name
def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields}
@@ -162,7 +195,7 @@ def list_extensions():
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
- extension_dependency_map = {}
+ loaded_extensions = {}
# scan through extensions directory and load metadata
for dirname in [extensions_builtin_dir, extensions_dir]:
@@ -175,55 +208,30 @@ def list_extensions():
continue
canonical_name = extension_dirname
- requires = None
+ metadata = ExtensionMetadata(path, canonical_name)
- if os.path.isfile(os.path.join(path, "metadata.ini")):
- try:
- config = configparser.ConfigParser()
- config.read(os.path.join(path, "metadata.ini"))
- canonical_name = config.get("Extension", "Name", fallback=canonical_name)
- requires = config.get("Extension", "Requires", fallback=None)
- except Exception:
- errors.report(f"Error reading metadata.ini for extension {extension_dirname}. "
- f"Will load regardless.", exc_info=True)
+ # check for duplicated canonical names
+ already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
+ if already_loaded_extension is not None:
+ errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
+ continue
- canonical_name = canonical_name.lower().strip()
+ is_builtin = dirname == extensions_builtin_dir
+ extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
+ extensions.append(extension)
+ loaded_extensions[canonical_name] = extension
- # check for duplicated canonical names
- if canonical_name in extension_dependency_map:
- errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions "
- f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". "
- f"The current loading extension will be discarded.", exc_info=False)
+ # check for requirements
+ for extension in extensions:
+ for req in extension.metadata.requires:
+ required_extension = loaded_extensions.get(req)
+ if required_extension is None:
+ errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
continue
- # both "," and " " are accepted as separator
- requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
+ if not extension.enabled:
+ errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
+ continue
- extension_dependency_map[canonical_name] = {
- "dirname": extension_dirname,
- "path": path,
- "requires": requires,
- }
- # check for requirements
- for (_, extension_data) in extension_dependency_map.items():
- dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires']
- requirement_met = True
- for req in requires:
- if req not in extension_dependency_map:
- errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. "
- f"The current loading extension will be discarded.", exc_info=False)
- requirement_met = False
- break
- dep_dirname = extension_dependency_map[req]['dirname']
- if dep_dirname in shared.opts.disabled_extensions:
- errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. "
- f"The current loading extension will be discarded.", exc_info=False)
- requirement_met = False
- break
-
- is_builtin = dirname == extensions_builtin_dir
- extension = Extension(name=dirname, path=path,
- enabled=dirname not in shared.opts.disabled_extensions and requirement_met,
- is_builtin=is_builtin)
- extensions.append(extension)
+extensions: list[Extension] = []
diff --git a/modules/scripts.py b/modules/scripts.py
index b1f4504a..b0689a23 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -2,7 +2,6 @@ import os
import re
import sys
import inspect
-from graphlib import TopologicalSorter, CycleError
from collections import namedtuple
from dataclasses import dataclass
@@ -312,27 +311,57 @@ scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
+def topological_sort(dependencies):
+ """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
+ Ignores errors relating to missing dependeencies or circular dependencies
+ """
+
+ visited = {}
+ result = []
+
+ def inner(name):
+ visited[name] = True
+
+ for dep in dependencies.get(name, []):
+ if dep in dependencies and dep not in visited:
+ inner(dep)
+
+ result.append(name)
+
+ for depname in dependencies:
+ if depname not in visited:
+ inner(depname)
+
+ return result
+
+
+@dataclass
+class ScriptWithDependencies:
+ script_canonical_name: str
+ file: ScriptFile
+ requires: list
+ load_before: list
+ load_after: list
+
def list_scripts(scriptdirname, extension, *, include_extensions=True):
- scripts_list = []
- script_dependency_map = {}
+ scripts = {}
- # build script dependency map
+ loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
+ loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}
+ # build script dependency map
root_script_basedir = os.path.join(paths.script_path, scriptdirname)
if os.path.exists(root_script_basedir):
for filename in sorted(os.listdir(root_script_basedir)):
if not os.path.isfile(os.path.join(root_script_basedir, filename)):
continue
- script_dependency_map[filename] = {
- "extension": None,
- "extension_dirname": None,
- "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
- "requires": [],
- "load_before": [],
- "load_after": [],
- }
+ if os.path.splitext(filename)[1].lower() != extension:
+ continue
+
+ script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
+ scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])
if include_extensions:
for ext in extensions.active():
@@ -341,96 +370,54 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True):
if not os.path.isfile(extension_script.path):
continue
- script_canonical_name = ext.canonical_name + "/" + extension_script.filename
- if ext.is_builtin:
- script_canonical_name = "builtin/" + script_canonical_name
+ script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
relative_path = scriptdirname + "/" + extension_script.filename
- requires = ''
- load_before = ''
- load_after = ''
-
- if ext.metadata is not None:
- requires = ext.metadata.get(relative_path, "Requires", fallback='')
- load_before = ext.metadata.get(relative_path, "Before", fallback='')
- load_after = ext.metadata.get(relative_path, "After", fallback='')
-
- # propagate directory level metadata
- requires = requires + ',' + ext.metadata.get(scriptdirname, "Requires", fallback='')
- load_before = load_before + ',' + ext.metadata.get(scriptdirname, "Before", fallback='')
- load_after = load_after + ',' + ext.metadata.get(scriptdirname, "After", fallback='')
-
- requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
- load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else []
- load_before = list(filter(None, re.split(r"[,\s]+", load_before.lower()))) if load_before else []
-
- script_dependency_map[script_canonical_name] = {
- "extension": ext.canonical_name,
- "extension_dirname": ext.name,
- "script_file": extension_script,
- "requires": requires,
- "load_before": load_before,
- "load_after": load_after,
- }
+ script = ScriptWithDependencies(
+ script_canonical_name=script_canonical_name,
+ file=extension_script,
+ requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
+ load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
+ load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
+ )
- # resolve dependencies
+ scripts[script_canonical_name] = script
+ loaded_extensions_scripts[ext.canonical_name].append(script)
- loaded_extensions = set()
- for ext in extensions.active():
- loaded_extensions.add(ext.canonical_name)
-
- for script_canonical_name, script_data in script_dependency_map.items():
+ for script_canonical_name, script in scripts.items():
# load before requires inverse dependency
# in this case, append the script name into the load_after list of the specified script
- for load_before_script in script_data['load_before']:
+ for load_before in script.load_before:
# if this requires an individual script to be loaded before
- if load_before_script in script_dependency_map:
- script_dependency_map[load_before_script]['load_after'].append(script_canonical_name)
- elif load_before_script in loaded_extensions:
- for _, script_data2 in script_dependency_map.items():
- if script_data2['extension'] == load_before_script:
- script_data2['load_after'].append(script_canonical_name)
- break
-
- # resolve extension name in load_after lists
- for load_after_script in list(script_data['load_after']):
- if load_after_script not in script_dependency_map and load_after_script in loaded_extensions:
- script_data['load_after'].remove(load_after_script)
- for script_canonical_name2, script_data2 in script_dependency_map.items():
- if script_data2['extension'] == load_after_script:
- script_data['load_after'].append(script_canonical_name2)
- break
-
- # build the DAG
- sorter = TopologicalSorter()
- for script_canonical_name, script_data in script_dependency_map.items():
- requirement_met = True
- for required_script in script_data['requires']:
- # if this requires an individual script to be loaded
- if required_script not in script_dependency_map and required_script not in loaded_extensions:
- errors.report(f"Script \"{script_canonical_name}\" "
- f"requires \"{required_script}\" to "
- f"be loaded, but it is not. Skipping.",
- exc_info=False)
- requirement_met = False
- break
- if not requirement_met:
- continue
+ other_script = scripts.get(load_before)
+ if other_script:
+ other_script.load_after.append(script_canonical_name)
- sorter.add(script_canonical_name, *script_data['load_after'])
+ # if this requires an extension
+ other_extension_scripts = loaded_extensions_scripts.get(load_before)
+ if other_extension_scripts:
+ for other_script in other_extension_scripts:
+ other_script.load_after.append(script_canonical_name)
- # sort the scripts
- try:
- ordered_script = sorter.static_order()
- except CycleError:
- errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
- ordered_script = script_dependency_map.keys()
+ # if After mentions an extension, remove it and instead add all of its scripts
+ for load_after in list(script.load_after):
+ if load_after not in scripts and load_after in loaded_extensions_scripts:
+ script.load_after.remove(load_after)
+
+ for other_script in loaded_extensions_scripts.get(load_after, []):
+ script.load_after.append(other_script.script_canonical_name)
+
+ dependencies = {}
+
+ for script_canonical_name, script in scripts.items():
+ for required_script in script.requires:
+ if required_script not in scripts and required_script not in loaded_extensions:
+ errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)
- for script_canonical_name in ordered_script:
- script_data = script_dependency_map[script_canonical_name]
- scripts_list.append(script_data['script_file'])
+ dependencies[script_canonical_name] = script.load_after
- scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+ ordered_scripts = topological_sort(dependencies)
+ scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]
return scripts_list