aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py24
1 files changed, 18 insertions, 6 deletions
diff --git a/webui.py b/webui.py
index 32561877..aaec79fd 100644
--- a/webui.py
+++ b/webui.py
@@ -183,13 +183,16 @@ def initialize():
signal.signal(signal.SIGINT, sigint_handler)
-def setup_cors(app):
+def setup_middleware(app):
+ app.middleware_stack = None # reset current middleware to allow modifying user provided list
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins_regex:
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
+ app.build_middleware_stack() # rebuild middleware stack on-the-fly
def create_api(app):
@@ -213,8 +216,7 @@ def api_only():
initialize()
app = FastAPI()
- setup_cors(app)
- app.add_middleware(GZipMiddleware, minimum_size=1000)
+ setup_middleware(app)
api = create_api(app)
modules.script_callbacks.app_started_callback(None, app)
@@ -271,9 +273,7 @@ def webui():
# running its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
- setup_cors(app)
-
- app.add_middleware(GZipMiddleware, minimum_size=1000)
+ setup_middleware(app)
modules.progress.setup_progress_api(app)
@@ -290,24 +290,35 @@ def webui():
wait_on_server(shared.demo)
print('Restarting UI...')
+ startup_timer.reset()
+
sd_samplers.set_samplers()
modules.script_callbacks.script_unloaded_callback()
extensions.list_extensions()
+ startup_timer.record("list extensions")
localization.list_localizations(cmd_opts.localizations_dir)
modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
+ startup_timer.record("load scripts")
+
modules.script_callbacks.model_loaded_callback(shared.sd_model)
+ startup_timer.record("model loaded callback")
+
modelloader.load_upscalers()
+ startup_timer.record("load upscalers")
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
+ startup_timer.record("reload script modules")
modules.sd_models.list_models()
+ startup_timer.record("list SD models")
shared.reload_hypernetworks()
+ startup_timer.record("reload hypernetworks")
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
@@ -316,6 +327,7 @@ def webui():
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+ startup_timer.record("initialize extra networks")
if __name__ == "__main__":