-rw-r--r--  .github/workflows/on_pull_request.yaml | 2
-rw-r--r--  .github/workflows/run_tests.yaml | 10
-rw-r--r--  .gitignore | 1
-rw-r--r--  CHANGELOG.md | 167
-rw-r--r--  README.md | 8
-rw-r--r--  configs/sd_xl_inpaint.yaml | 98
-rw-r--r--  extensions-builtin/Lora/lyco_helpers.py | 47
-rw-r--r--  extensions-builtin/Lora/network.py | 2
-rw-r--r--  extensions-builtin/Lora/network_full.py | 4
-rw-r--r--  extensions-builtin/Lora/network_glora.py | 10
-rw-r--r--  extensions-builtin/Lora/network_hada.py | 12
-rw-r--r--  extensions-builtin/Lora/network_ia3.py | 2
-rw-r--r--  extensions-builtin/Lora/network_lokr.py | 18
-rw-r--r--  extensions-builtin/Lora/network_lora.py | 6
-rw-r--r--  extensions-builtin/Lora/network_norm.py | 4
-rw-r--r--  extensions-builtin/Lora/network_oft.py | 82
-rw-r--r--  extensions-builtin/Lora/networks.py | 42
-rw-r--r--  extensions-builtin/Lora/scripts/lora_script.py | 2
-rw-r--r--  extensions-builtin/Lora/ui_edit_user_metadata.py | 9
-rw-r--r--  extensions-builtin/Lora/ui_extra_networks_lora.py | 12
-rw-r--r--  extensions-builtin/ScuNET/scripts/scunet_model.py | 68
-rw-r--r--  extensions-builtin/ScuNET/scunet_model_arch.py | 268
-rw-r--r--  extensions-builtin/SwinIR/scripts/swinir_model.py | 156
-rw-r--r--  extensions-builtin/SwinIR/swinir_model_arch.py | 867
-rw-r--r--  extensions-builtin/SwinIR/swinir_model_arch_v2.py | 1017
-rw-r--r--  extensions-builtin/extra-options-section/scripts/extra_options_section.py | 20
-rw-r--r--  extensions-builtin/hypertile/hypertile.py | 351
-rw-r--r--  extensions-builtin/hypertile/scripts/hypertile_script.py | 109
-rw-r--r--  extensions-builtin/hypertile/scripts/hypertile_xyz.py | 51
-rw-r--r--  extensions-builtin/mobile/javascript/mobile.js | 2
-rw-r--r--  extensions-builtin/soft-inpainting/scripts/soft_inpainting.py | 747
-rw-r--r--  javascript/edit-attention.js | 2
-rw-r--r--  javascript/extraNetworks.js | 101
-rw-r--r--  javascript/imageviewer.js | 2
-rw-r--r--  javascript/inputAccordion.js | 81
-rw-r--r--  javascript/notification.js | 6
-rw-r--r--  javascript/settings.js | 25
-rw-r--r--  javascript/ui.js | 41
-rw-r--r--  modules/api/api.py | 153
-rw-r--r--  modules/api/models.py | 7
-rw-r--r--  modules/cache.py | 2
-rw-r--r--  modules/call_queue.py | 1
-rw-r--r--  modules/cmd_args.py | 3
-rw-r--r--  modules/codeformer/codeformer_arch.py | 276
-rw-r--r--  modules/codeformer/vqgan_arch.py | 435
-rw-r--r--  modules/codeformer_model.py | 158
-rw-r--r--  modules/devices.py | 76
-rw-r--r--  modules/errors.py | 22
-rw-r--r--  modules/esrgan_model.py | 199
-rw-r--r--  modules/esrgan_model_arch.py | 465
-rw-r--r--  modules/extensions.py | 96
-rw-r--r--  modules/face_restoration_utils.py | 180
-rw-r--r--  modules/gfpgan_model.py | 151
-rw-r--r--  modules/gradio_extensons.py | 10
-rw-r--r--  modules/hat_model.py | 43
-rw-r--r--  modules/images.py | 14
-rw-r--r--  modules/img2img.py | 28
-rw-r--r--  modules/import_hook.py | 11
-rw-r--r--  modules/infotext.py (renamed from modules/generation_parameters_copypaste.py) | 111
-rw-r--r--  modules/infotext_versions.py | 39
-rw-r--r--  modules/initialize.py | 3
-rw-r--r--  modules/initialize_util.py | 2
-rw-r--r--  modules/interrogate.py | 4
-rw-r--r--  modules/launch_utils.py | 51
-rw-r--r--  modules/logging_config.py | 27
-rw-r--r--  modules/mac_specific.py | 15
-rw-r--r--  modules/modelloader.py | 84
-rw-r--r--  modules/models/diffusion/ddpm_edit.py | 7
-rw-r--r--  modules/options.py | 81
-rw-r--r--  modules/paths.py | 1
-rw-r--r--  modules/paths_internal.py | 1
-rw-r--r--  modules/postprocessing.py | 96
-rw-r--r--  modules/processing.py | 218
-rw-r--r--  modules/processing_scripts/refiner.py | 7
-rw-r--r--  modules/processing_scripts/seed.py | 13
-rw-r--r--  modules/progress.py | 22
-rw-r--r--  modules/prompt_parser.py | 2
-rw-r--r--  modules/realesrgan_model.py | 158
-rw-r--r--  modules/rng.py | 2
-rw-r--r--  modules/scripts.py | 212
-rw-r--r--  modules/scripts_postprocessing.py | 86
-rw-r--r--  modules/sd_disable_initialization.py | 2
-rw-r--r--  modules/sd_hijack.py | 28
-rw-r--r--  modules/sd_models.py | 77
-rw-r--r--  modules/sd_models_config.py | 6
-rw-r--r--  modules/sd_models_types.py | 5
-rw-r--r--  modules/sd_models_xl.py | 11
-rw-r--r--  modules/sd_samplers_cfg_denoiser.py | 21
-rw-r--r--  modules/sd_samplers_extra.py | 2
-rw-r--r--  modules/sd_samplers_timesteps.py | 3
-rw-r--r--  modules/sd_samplers_timesteps_impl.py | 4
-rw-r--r--  modules/sd_unet.py | 14
-rw-r--r--  modules/shared_gradio_themes.py | 4
-rw-r--r--  modules/shared_items.py | 18
-rw-r--r--  modules/shared_options.py | 162
-rw-r--r--  modules/shared_state.py | 7
-rw-r--r--  modules/styles.py | 138
-rw-r--r--  modules/sysinfo.py | 20
-rw-r--r--  modules/textual_inversion/autocrop.py | 239
-rw-r--r--  modules/textual_inversion/preprocess.py | 232
-rw-r--r--  modules/textual_inversion/textual_inversion.py | 10
-rw-r--r--  modules/textual_inversion/ui.py | 7
-rw-r--r--  modules/torch_utils.py | 17
-rw-r--r--  modules/txt2img.py | 2
-rw-r--r--  modules/ui.py | 454
-rw-r--r--  modules/ui_common.py | 19
-rw-r--r--  modules/ui_extensions.py | 17
-rw-r--r--  modules/ui_extra_networks.py | 42
-rw-r--r--  modules/ui_extra_networks_checkpoints.py | 10
-rw-r--r--  modules/ui_extra_networks_hypernets.py | 13
-rw-r--r--  modules/ui_extra_networks_textual_inversion.py | 10
-rw-r--r--  modules/ui_extra_networks_user_metadata.py | 6
-rw-r--r--  modules/ui_gradio_extensions.py | 11
-rw-r--r--  modules/ui_loadsave.py | 4
-rw-r--r--  modules/ui_postprocessing.py | 20
-rw-r--r--  modules/ui_prompt_styles.py | 4
-rw-r--r--  modules/ui_toprow.py | 149
-rw-r--r--  modules/upscaler.py | 9
-rw-r--r--  modules/upscaler_utils.py | 140
-rw-r--r--  modules/util.py | 12
-rw-r--r--  modules/xlmr.py | 5
-rw-r--r--  modules/xlmr_m18.py | 4
-rw-r--r--  modules/xpu_specific.py | 124
-rw-r--r--  pyproject.toml | 1
-rw-r--r--  requirements.txt | 5
-rw-r--r--  requirements_versions.txt | 9
-rw-r--r--  script.js | 44
-rw-r--r--  scripts/loopback.py | 6
-rw-r--r--  scripts/postprocessing_caption.py | 30
-rw-r--r--  scripts/postprocessing_codeformer.py | 16
-rw-r--r--  scripts/postprocessing_create_flipped_copies.py | 32
-rw-r--r--  scripts/postprocessing_focal_crop.py | 54
-rw-r--r--  scripts/postprocessing_gfpgan.py | 13
-rw-r--r--  scripts/postprocessing_split_oversized.py | 71
-rw-r--r--  scripts/postprocessing_upscale.py | 12
-rw-r--r--  scripts/processing_autosized_crop.py | 64
-rw-r--r--  scripts/prompts_from_file.py | 17
-rw-r--r--  scripts/xyz_grid.py | 28
-rw-r--r--  style.css | 65
-rw-r--r--  test/conftest.py | 15
-rw-r--r--  test/test_face_restorers.py | 29
-rw-r--r--  test/test_files/two-faces.jpg | bin 0 -> 14768 bytes
-rw-r--r--  test/test_outputs/.gitkeep | 0
-rw-r--r--  test/test_torch_utils.py | 19
-rw-r--r--  webui-macos-env.sh | 2
-rw-r--r--  webui.py | 2
-rwxr-xr-x  webui.sh | 38
147 files changed, 5397 insertions, 5281 deletions
diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml
index 78e608ee..9e44c806 100644
--- a/.github/workflows/on_pull_request.yaml
+++ b/.github/workflows/on_pull_request.yaml
@@ -20,7 +20,7 @@ jobs:
# not to have GHA download an (at the time of writing) 4 GB cache
# of PyTorch and other dependencies.
- name: Install Ruff
- run: pip install ruff==0.0.272
+ run: pip install ruff==0.1.6
- name: Run Ruff
run: ruff .
lint-js:
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index 3dafaf8d..f42e4758 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -20,6 +20,12 @@ jobs:
cache-dependency-path: |
**/requirements*txt
launch.py
+ - name: Cache models
+ id: cache-models
+ uses: actions/cache@v3
+ with:
+ path: models
+ key: "2023-12-30"
- name: Install test dependencies
run: pip install wait-for-it -r requirements-test.txt
env:
@@ -33,6 +39,8 @@ jobs:
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
PYTHONUNBUFFERED: "1"
+ - name: Print installed packages
+ run: pip freeze
- name: Start test server
run: >
python -m coverage run
@@ -49,7 +57,7 @@ jobs:
2>&1 | tee output.txt &
- name: Run tests
run: |
- wait-for-it --service 127.0.0.1:7860 -t 600
+ wait-for-it --service 127.0.0.1:7860 -t 20
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
- name: Kill test server
if: always()
diff --git a/.gitignore b/.gitignore
index 09734267..6790e9ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -37,3 +37,4 @@ notification.mp3
/node_modules
/package-lock.json
/.coverage*
+/test/test_outputs
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1cd3572c..67429bbf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,170 @@
+## 1.7.0
+
+### Features:
+* settings tab rework: add search field, add categories, split UI settings page into many
+* add altdiffusion-m18 support ([#13364](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13364))
+* support inference with LyCORIS GLora networks ([#13610](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13610))
+* add lora-embedding bundle system ([#13568](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13568))
+* option to move prompt from top row into generation parameters
+* add support for SSD-1B ([#13865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13865))
+* support inference with OFT networks ([#13692](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13692))
+* script metadata and DAG sorting mechanism ([#13944](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13944))
+* support HyperTile optimization ([#13948](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13948))
+* add support for SD 2.1 Turbo ([#14170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14170))
+* remove Train->Preprocessing tab and put all its functionality into Extras tab
+* initial IPEX support for Intel Arc GPU ([#14171](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14171))
+
+### Minor:
+* allow reading model hash from images in img2img batch mode ([#12767](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12767))
+* add option to align with sgm repo's sampling implementation ([#12818](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818))
+* extra field for lora metadata viewer: `ss_output_name` ([#12838](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12838))
+* add action in settings page to calculate all SD checkpoint hashes ([#12909](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12909))
+* add button to copy prompt to style editor ([#12975](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12975))
+* add --skip-load-model-at-start option ([#13253](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13253))
+* write infotext to gif images
+* read infotext from gif images ([#13068](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13068))
+* allow configuring the initial state of InputAccordion in ui-config.json ([#13189](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13189))
+* allow editing whitespace delimiters for ctrl+up/ctrl+down prompt editing ([#13444](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13444))
+* prevent accidentally closing popup dialogs ([#13480](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13480))
+* added option to play notification sound or not ([#13631](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13631))
+* show the preview image in the full screen image viewer if available ([#13459](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13459))
+* support for webui.settings.bat ([#13638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13638))
+* add an option to not print stack traces on ctrl+c
+* start/restart generation by Ctrl (Alt) + Enter ([#13644](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13644))
+* update prompts_from_file script to allow concatenating entries with the general prompt ([#13733](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13733))
+* added a visible checkbox to input accordion
+* added an option to hide all txt2img/img2img parameters in an accordion ([#13826](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13826))
+* added 'Path' sorting option for Extra network cards ([#13968](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13968))
+* enable prompt hotkeys in style editor ([#13931](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13931))
+* option to show batch img2img results in UI ([#14009](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14009))
+* infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page
+* add FP32 fallback support on sd_vae_approx ([#14046](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046))
+* support XYZ scripts / split hires path from unet ([#14126](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14126))
+* allow use of multiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125))
+
+### Extensions and API:
+* update gradio to 3.41.2
+* support installed extensions list api ([#12774](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12774))
+* update pnginfo API to return dict with parsed values
+* add noisy latent to `ExtraNoiseParams` for callback ([#12856](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12856))
+* show extension datetime in UTC ([#12864](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12864), [#12865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12865), [#13281](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13281))
+* add an option to choose how to combine hires fix and refiner
+* include program version in info response. ([#13135](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13135))
+* sd_unet support for SDXL
+* patch DDPM.register_betas so that users can put given_betas in model yaml ([#13276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13276))
+* xyz_grid: add prepare ([#13266](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13266))
+* allow multiple localization files with same language in extensions ([#13077](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13077))
+* add onEdit function for js and rework token-counter.js to use it
+* fix the key error exception when processing override_settings keys ([#13567](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13567))
+* ability for extensions to return custom data via api in response.images ([#13463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13463))
+* call state.jobnext() before postproces*() ([#13762](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13762))
+* add option to set notification sound volume ([#13884](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13884))
+* update Ruff to 0.1.6 ([#14059](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14059))
+* add Block component creation callback ([#14119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14119))
+* catch uncaught exception with ui creation scripts ([#14120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14120))
+* use extension name for determining whether an extension is installed in the index ([#14063](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14063))
+* update is_installed() from launch_utils.py to fix reinstalling already installed packages ([#14192](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14192))
+
+### Bug Fixes:
+* fix pix2pix producing bad results
+* fix defaults settings page breaking when any of main UI tabs are hidden
+* fix error that causes some extra networks to be disabled if both <lora:> and <lyco:> are present in the prompt
+* fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working
+* prevent duplicate resize handler ([#12795](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12795))
+* small typo: vae resolve bug ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12797))
+* hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12792))
+* don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12780))
+* fix style editing dialog breaking if it's opened in both img2img and txt2img tabs
+* hide --gradio-auth and --api-auth values from /internal/sysinfo report
+* add missing infotext for RNG in options ([#12819](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12819))
+* fix notification not playing when built-in webui tab is inactive ([#12834](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12834))
+* honor `--skip-install` for extension installers ([#12832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832))
+* don't print blank stdout in extension installers ([#12833](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12833), [#12855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12855))
+* get progressbar to display correctly in extensions tab
+* keep order in list of checkpoints when loading model that doesn't have a checksum
+* fix inpainting models in txt2img creating black pictures
+* fix generation params regex ([#12876](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12876))
+* fix batch img2img output dir with script ([#12926](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12926))
+* fix #13080 - Hypernetwork/TI preview generation ([#13084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13084))
+* fix bug with sigma min/max overrides. ([#12995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12995))
+* more accurate check for enabling cuDNN benchmark on 16XX cards ([#12924](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12924))
+* don't use multicond parser for negative prompt counter ([#13118](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13118))
+* fix data-sort-name containing spaces ([#13412](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13412))
+* update card on correct tab when editing metadata ([#13411](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13411))
+* fix viewing/editing metadata when filename contains an apostrophe ([#13395](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13395))
+* fix: --sd_model in "Prompts from file or textbox" script is not working ([#13302](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13302))
+* better Support for Portable Git ([#13231](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13231))
+* fix issues when webui_dir is not work_dir ([#13210](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13210))
+* fix: lora-bias-backup don't reset cache ([#13178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13178))
+* account for customizable extra network separators when removing extra network text from the prompt ([#12877](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12877))
+* re fix batch img2img output dir with script ([#13170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13170))
+* fix `--ckpt-dir` path separator and option use `short name` for checkpoint dropdown ([#13139](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13139))
+* consolidated allowed preview formats, fix extra network `.gif` not working as preview ([#13121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13121))
+* fix venv_dir=- environment variable not working as expected on linux ([#13469](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13469))
+* repair unload sd checkpoint button
+* edit-attention fixes ([#13533](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13533))
+* fix bug when using --gfpgan-models-path ([#13718](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13718))
+* properly apply sort order for extra network cards when selected from dropdown
+* fixes generation restart not working for some users when 'Ctrl+Enter' is pressed ([#13962](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13962))
+* thread safe extra network list_items ([#13014](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13014))
+* fix not able to exit metadata popup when pop up is too big ([#14156](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14156))
+* fix auto focal point crop for opencv >= 4.8 ([#14121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14121))
+* make 'use-cpu all' actually apply to 'all' ([#14131](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14131))
+* extras tab batch: actually use original filename
+* make webui not crash when running with --disable-all-extensions option
+
+### Other:
+* non-local condition ([#12814](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12814))
+* fix minor typos ([#12827](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12827))
+* remove xformers Python version check ([#12842](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12842))
+* style: file-metadata word-break ([#12837](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12837))
+* revert SGM noise multiplier change for img2img because it breaks hires fix
+* do not change quicksettings dropdown option when value returned is `None` ([#12854](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12854))
+* [RC 1.6.0 - zoom is partly hidden] Update style.css ([#12839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12839))
+* chore: change extension time format ([#12851](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12851))
+* WEBUI.SH - Use torch 2.1.0 release candidate for Navi 3 ([#12929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12929))
+* add Fallback at images.read_info_from_image if exif data was invalid ([#13028](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13028))
+* update cmd arg description ([#12986](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12986))
+* fix: update shared.opts.data when add_option ([#12957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12957), [#13213](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13213))
+* restore missing tooltips ([#12976](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12976))
+* use default dropdown padding on mobile ([#12880](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12880))
+* put enable console prompts option into settings from commandline args ([#13119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13119))
+* fix some deprecated types ([#12846](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12846))
+* bump to torchsde==0.2.6 ([#13418](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13418))
+* update dragdrop.js ([#13372](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13372))
+* use orderdict as lru cache:opt/bug ([#13313](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13313))
+* XYZ if not include sub grids do not save sub grid ([#13282](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13282))
+* initialize state.time_start before state.job_count ([#13229](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13229))
+* fix fieldname regex ([#13458](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13458))
+* change denoising_strength default to None. ([#13466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13466))
+* fix regression ([#13475](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13475))
+* fix IndexError ([#13630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13630))
+* fix: checkpoints_loaded:{checkpoint:state_dict}, model.load_state_dict issue in dict value empty ([#13535](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13535))
+* update bug_report.yml ([#12991](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12991))
+* requirements_versions httpx==0.24.1 ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839))
+* fix parenthesis auto selection ([#13829](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13829))
+* fix #13796 ([#13797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13797))
+* corrected a typo in `modules/cmd_args.py` ([#13855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13855))
+* feat: fix randn found element of type float at pos 2 ([#14004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14004))
+* adds tqdm handler to logging_config.py for progress bar integration ([#13996](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13996))
+* hotfix: call shared.state.end() after postprocessing done ([#13977](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13977))
+* fix dependency address patch 1 ([#13929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13929))
+* save sysinfo as .json ([#14035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14035))
+* move exception_records related methods to errors.py ([#14084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14084))
+* compatibility ([#13936](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13936))
+* json.dump(ensure_ascii=False) ([#14108](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14108))
+* dir buttons start with / so only the correct dir will be shown and no… ([#13957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13957))
+* alternate implementation for unet forward replacement that does not depend on hijack being applied
+* re-add `keyedit_delimiters_whitespace` setting lost as part of commit e294e46 ([#14178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14178))
+* fix `save_samples` being checked early when saving masked composite ([#14177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14177))
+* slight optimization for mask and mask_composite ([#14181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14181))
+* add import_hook hack to work around basicsr/torchvision incompatibility ([#14186](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14186))
+
+## 1.6.1
+
+### Bug Fixes:
+ * fix an error causing the webui to fail to start ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839))
+
## 1.6.0
### Features:
diff --git a/README.md b/README.md
index c7a4e363..9f9f33b1 100644
--- a/README.md
+++ b/README.md
@@ -91,6 +91,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
- Now with a license!
- Reorder elements in the UI from settings screen
+- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
@@ -120,7 +121,9 @@ Alternatively, use online services (like Google Colab):
# Debian-based:
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
# Red Hat-based:
-sudo dnf install wget git python3
+sudo dnf install wget git python3 gperftools-libs libglvnd-glx
+# openSUSE-based:
+sudo zypper install wget git python3 libtcmalloc4 libglvnd
# Arch-based:
sudo pacman -S wget git python3
```
@@ -146,7 +149,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
-- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
+- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
- CodeFormer - https://github.com/sczhou/CodeFormer
@@ -173,5 +176,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- LyCORIS - KohakuBlueleaf
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
+- Hypertile - tfernd - https://github.com/tfernd/HyperTile
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml
new file mode 100644
index 00000000..3bad3721
--- /dev/null
+++ b/configs/sd_xl_inpaint.yaml
@@ -0,0 +1,98 @@
+model:
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ scale_factor: 0.13025
+ disable_first_stage_autocast: True
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ adm_in_channels: 2816
+ num_classes: sequential
+ use_checkpoint: True
+ in_channels: 9
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [4, 2]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
+ context_dim: 2048
+ spatial_transformer_attn_type: softmax-xformers
+ legacy: False
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ # crossattn cond
+ - is_trainable: False
+ input_key: txt
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ layer: hidden
+ layer_idx: 11
+ # crossattn and vector cond
+ - is_trainable: False
+ input_key: txt
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
+ params:
+ arch: ViT-bigG-14
+ version: laion2b_s39b_b160k
+ freeze: True
+ layer: penultimate
+ always_return_pooled: True
+ legacy: False
+ # vector cond
+ - is_trainable: False
+ input_key: original_size_as_tuple
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: crop_coords_top_left
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: target_size_as_tuple
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: vanilla-xformers
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
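For orientation, the in_channels: 9 in the UNet section above reflects the usual inpainting conditioning layout: 4 noised-latent channels, 1 mask channel, and 4 masked-image latent channels concatenated along the channel dimension. A minimal shape-check sketch (the channel ordering here is an assumption for illustration, not something specified by this config):

    import torch

    b, h, w = 1, 128, 128                          # latent-space spatial size
    noised_latent = torch.randn(b, 4, h, w)        # 4 channels
    mask = torch.randn(b, 1, h, w)                 # 1 channel
    masked_image_latent = torch.randn(b, 4, h, w)  # 4 channels

    unet_input = torch.cat([noised_latent, mask, masked_image_latent], dim=1)
    print(unet_input.shape)                        # torch.Size([1, 9, 128, 128])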
diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py
index 279b34bc..1679a0ce 100644
--- a/extensions-builtin/Lora/lyco_helpers.py
+++ b/extensions-builtin/Lora/lyco_helpers.py
@@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
+
+
+# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
+def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
+ '''
+ Return a tuple of two factors of the input dimension, decomposed using the number closest to factor;
+ the second value is greater than or equal to the first.
+
+ In LoRA with Kronecker product, the first value is used for the weight scale and the
+ second value is used for the weight itself.
+
+ Because the Kronecker product is non-commutative (A⊗B ≠ B⊗A), the meaning of the two matrices differs slightly.
+
+ examples)
+ factor
+ -1 2 4 8 16 ...
+ 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
+ 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
+ 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
+ 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
+ 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
+ 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
+ '''
+
+ if factor > 0 and (dimension % factor) == 0:
+ m = factor
+ n = dimension // factor
+ if m > n:
+ n, m = m, n
+ return m, n
+ if factor < 0:
+ factor = dimension
+ m, n = 1, dimension
+ length = m + n
+ while m<n:
+ new_m = m + 1
+ while dimension%new_m != 0:
+ new_m += 1
+ new_n = dimension // new_m
+ if new_m + new_n > length or new_m>factor:
+ break
+ else:
+ m, n = new_m, new_n
+ if m > n:
+ n, m = m, n
+ return m, n
+
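As a quick sanity check of the factorization helper added above, a minimal usage sketch; it assumes the extensions-builtin/Lora directory is importable (the webui adds it to sys.path when loading the built-in Lora extension), and the path handling is illustrative only:

    import sys
    sys.path.append("extensions-builtin/Lora")  # assumes the repo root is the working directory

    from lyco_helpers import factorization

    # These match the factor=-1 column of the docstring examples above.
    print(factorization(127))       # (1, 127)
    print(factorization(128))       # (8, 16)
    print(factorization(250))       # (10, 25)
    print(factorization(1024))      # (32, 32)

    # With an explicit factor, the split is pinned to that divisor when possible.
    print(factorization(1024, 16))  # (16, 64)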
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 6021fd8d..a62e5eff 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -137,7 +137,7 @@ class NetworkModule:
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
- updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py
index bf6930e9..f221c95f 100644
--- a/extensions-builtin/Lora/network_full.py
+++ b/extensions-builtin/Lora/network_full.py
@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.weight.shape
- updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = self.weight.to(orig_weight.device)
if self.ex_bias is not None:
- ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ ex_bias = self.ex_bias.to(orig_weight.device)
else:
ex_bias = None
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py
index 492d4870..efe5c681 100644
--- a/extensions-builtin/Lora/network_glora.py
+++ b/extensions-builtin/Lora/network_glora.py
@@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
self.w2b = weights.w["b2.weight"]
def calc_updown(self, orig_weight):
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
- updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+ updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py
index 5fcb0695..d95a0fd1 100644
--- a/extensions-builtin/Lora/network_hada.py
+++ b/extensions-builtin/Lora/network_hada.py
@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
self.t2 = weights.w.get("hada_t2")
def calc_updown(self, orig_weight):
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
if self.t1 is not None:
output_shape = [w1a.size(1), w1b.size(1)]
- t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+ t1 = self.t1.to(orig_weight.device)
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
output_shape += t1.shape[2:]
else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
if self.t2 is not None:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
else:
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py
index 7edc4249..96faeaf3 100644
--- a/extensions-builtin/Lora/network_ia3.py
+++ b/extensions-builtin/Lora/network_ia3.py
@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
self.on_input = weights.w["on_input"].item()
def calc_updown(self, orig_weight):
- w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+ w = self.w.to(orig_weight.device)
output_shape = [w.size(0), orig_weight.size(1)]
if self.on_input:
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py
index 340acdab..fcdaeafd 100644
--- a/extensions-builtin/Lora/network_lokr.py
+++ b/extensions-builtin/Lora/network_lokr.py
@@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
def calc_updown(self, orig_weight):
if self.w1 is not None:
- w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1 = self.w1.to(orig_weight.device)
else:
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
w1 = w1a @ w1b
if self.w2 is not None:
- w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2 = self.w2.to(orig_weight.device)
elif self.t2 is None:
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = w2a @ w2b
else:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index 26c0a72c..4cc40295 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
return module
def calc_updown(self, orig_weight):
- up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
- down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ up = self.up_model.weight.to(orig_weight.device)
+ down = self.down_model.weight.to(orig_weight.device)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
- mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ mid = self.mid_model.weight.to(orig_weight.device)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py
index ce450158..d25afcbb 100644
--- a/extensions-builtin/Lora/network_norm.py
+++ b/extensions-builtin/Lora/network_norm.py
@@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.w_norm.shape
- updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = self.w_norm.to(orig_weight.device)
if self.b_norm is not None:
- ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+ ex_bias = self.b_norm.to(orig_weight.device)
else:
ex_bias = None
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
new file mode 100644
index 00000000..fa647020
--- /dev/null
+++ b/extensions-builtin/Lora/network_oft.py
@@ -0,0 +1,82 @@
+import torch
+import network
+from lyco_helpers import factorization
+from einops import rearrange
+
+
+class ModuleTypeOFT(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
+ return NetworkModuleOFT(net, weights)
+
+ return None
+
+# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
+class NetworkModuleOFT(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+
+ super().__init__(net, weights)
+
+ self.lin_module = None
+ self.org_module: list[torch.Module] = [self.sd_module]
+
+ self.scale = 1.0
+
+ # kohya-ss
+ if "oft_blocks" in weights.w.keys():
+ self.is_kohya = True
+ self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
+ self.alpha = weights.w["alpha"] # alpha is constraint
+ self.dim = self.oft_blocks.shape[0] # lora dim
+ # LyCORIS
+ elif "oft_diag" in weights.w.keys():
+ self.is_kohya = False
+ self.oft_blocks = weights.w["oft_diag"]
+ # self.alpha is unused
+ self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
+
+ is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
+ is_conv = type(self.sd_module) in [torch.nn.Conv2d]
+ is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
+
+ if is_linear:
+ self.out_dim = self.sd_module.out_features
+ elif is_conv:
+ self.out_dim = self.sd_module.out_channels
+ elif is_other_linear:
+ self.out_dim = self.sd_module.embed_dim
+
+ if self.is_kohya:
+ self.constraint = self.alpha * self.out_dim
+ self.num_blocks = self.dim
+ self.block_size = self.out_dim // self.dim
+ else:
+ self.constraint = None
+ self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
+
+ def calc_updown(self, orig_weight):
+ oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+ eye = torch.eye(self.block_size, device=self.oft_blocks.device)
+
+ if self.is_kohya:
+ block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
+ norm_Q = torch.norm(block_Q.flatten())
+ new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
+ block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
+ oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
+
+ R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ # This errors out for MultiheadAttention, might need to be handled up-stream
+ merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+ merged_weight = torch.einsum(
+ 'k n m, k n ... -> k m ...',
+ R,
+ merged_weight
+ )
+ merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+
+ updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+ output_shape = orig_weight.shape
+ return self.finalize_updown(updown, orig_weight, output_shape)
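The core of calc_updown above is a Cayley transform: the stored blocks are made skew-symmetric, and (I + Q)(I - Q)^-1 then yields an orthogonal rotation per block, which is applied to the original weight block-wise. A standalone sketch of just that step (names are local to the sketch; the constraint clamping from the kohya-ss branch is omitted):

    import torch

    num_blocks, block_size = 3, 4
    blocks = 0.1 * torch.randn(num_blocks, block_size, block_size)

    Q = blocks - blocks.transpose(1, 2)             # skew-symmetric per block
    eye = torch.eye(block_size)
    R = torch.matmul(eye + Q, (eye - Q).inverse())  # Cayley transform -> orthogonal

    # Each R[k] is orthogonal up to numerical error: R[k] @ R[k].T ~= I
    print(torch.allclose(R @ R.transpose(1, 2), eye.expand_as(R), atol=1e-5))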
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 60d8dec4..72ebd624 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -1,3 +1,4 @@
+import gradio as gr
import logging
import os
import re
@@ -11,6 +12,7 @@ import network_ia3
import network_lokr
import network_full
import network_norm
+import network_oft
import torch
from typing import Union
@@ -28,6 +30,7 @@ module_types = [
network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),
network_glora.ModuleTypeGLora(),
+ network_oft.ModuleTypeOFT(),
]
@@ -157,7 +160,8 @@ def load_network(name, network_on_disk):
bundle_embeddings = {}
for key_network, weight in sd.items():
- key_network_without_network_parts, network_part = key_network.split(".", 1)
+ key_network_without_network_parts, _, network_part = key_network.partition(".")
+
if key_network_without_network_parts == "bundle_emb":
emb_name, vec_name = network_part.split(".", 1)
emb_dict = bundle_embeddings.get(emb_name, {})
@@ -189,6 +193,17 @@ def load_network(name, network_on_disk):
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+ # kohya_ss OFT module
+ elif sd_module is None and "oft_unet" in key_network_without_network_parts:
+ key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+ # KohakuBlueLeaf OFT module
+ if sd_module is None and "oft_diag" in key:
+ key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+ key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
if sd_module is None:
keys_failed_to_match[key_network] = key
continue
@@ -300,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
emb_db.skipped_embeddings[name] = embedding
if failed_to_load_networks:
- sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
+ lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
+ sd_hijack.model_hijack.comments.append(lora_not_found_message)
+ if shared.opts.lora_not_found_warning_console:
+ print(f'\n{lora_not_found_message}\n')
+ if shared.opts.lora_not_found_gradio_warning:
+ gr.Warning(lora_not_found_message)
purge_networks_from_memory()
@@ -375,18 +395,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
if module is not None and hasattr(self, 'weight'):
try:
with torch.no_grad():
- updown, ex_bias = module.calc_updown(self.weight)
+ if getattr(self, 'fp16_weight', None) is None:
+ weight = self.weight
+ bias = self.bias
+ else:
+ weight = self.fp16_weight.clone().to(self.weight.device)
+ bias = getattr(self, 'fp16_bias', None)
+ if bias is not None:
+ bias = bias.clone().to(self.bias.device)
+ updown, ex_bias = module.calc_updown(weight)
- if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ if len(weight.shape) == 4 and weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
- self.weight += updown
+ self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
- self.bias = torch.nn.Parameter(ex_bias)
+ self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
- self.bias += ex_bias
+ self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
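One detail worth calling out from the hunk above: when the target module is an inpainting conv (weight shape [out, 9, kh, kw]) and the LoRA delta was computed for 4 input channels, the delta is zero-padded along the input-channel dimension before being added. A small illustration of that padding call, with hypothetical shapes:

    import torch

    updown = torch.randn(320, 4, 3, 3)   # LoRA delta for a 4-channel-input conv
    padded = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
    print(padded.shape)                  # torch.Size([320, 9, 3, 3])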
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index ef23968c..1518f7e5 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
+ "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
+ "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
}))
diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py
index c7011909..3160aecf 100644
--- a/extensions-builtin/Lora/ui_edit_user_metadata.py
+++ b/extensions-builtin/Lora/ui_edit_user_metadata.py
@@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.slider_preferred_weight = None
self.edit_notes = None
- def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
+ def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
user_metadata = self.get_user_metadata(name)
user_metadata["description"] = desc
user_metadata["sd version"] = sd_version
user_metadata["activation text"] = activation_text
user_metadata["preferred weight"] = preferred_weight
+ user_metadata["negative text"] = negative_text
user_metadata["notes"] = notes
self.write_user_metadata(name, user_metadata)
@@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
user_metadata.get('activation text', ''),
float(user_metadata.get('preferred weight', 0.0)),
+ user_metadata.get('negative text', ''),
gr.update(visible=True if tags else False),
gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
]
@@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.taginfo = gr.HighlightedText(label="Training dataset tags")
self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
-
+ self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
with gr.Row() as row_random_prompt:
with gr.Column(scale=8):
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
@@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.taginfo,
self.edit_activation_text,
self.slider_preferred_weight,
+ self.edit_negative_text,
row_random_prompt,
random_prompt,
]
@@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.select_sd_version,
self.edit_activation_text,
self.slider_preferred_weight,
+ self.edit_negative_text,
self.edit_notes,
]
+
self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 55409a78..e714fac4 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True):
lora_on_disk = networks.available_networks.get(name)
+ if lora_on_disk is None:
+ return
path, ext = os.path.splitext(lora_on_disk.filename)
@@ -43,6 +45,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
if activation_text:
item["prompt"] += " + " + quote_js(" " + activation_text)
+ negative_prompt = item["user_metadata"].get("negative text")
+ item["negative_prompt"] = quote_js("")
+ if negative_prompt:
+ item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
+
sd_version = item["user_metadata"].get("sd version")
if sd_version in network.SdVersion.__members__:
item["sd_version"] = sd_version
@@ -66,9 +73,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
return item
def list_items(self):
- for index, name in enumerate(networks.available_networks):
+ # instantiate a list to protect against concurrent modification
+ names = list(networks.available_networks)
+ for index, name in enumerate(names):
item = self.create_item(name, index)
-
if item is not None:
yield item
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 167d2f64..f799cb76 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -3,14 +3,11 @@ import sys
import PIL.Image
import numpy as np
import torch
-from tqdm import tqdm
import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet
-
-from modules.modelloader import load_file_from_url
from modules.shared import opts
+from modules.upscaler_utils import tiled_upscale_2
class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -42,47 +39,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers.append(scaler_data2)
self.scalers = scalers
- @staticmethod
- @torch.no_grad()
- def tiled_inference(img, model):
- # test the image tile by tile
- h, w = img.shape[2:]
- tile = opts.SCUNET_tile
- tile_overlap = opts.SCUNET_tile_overlap
- if tile == 0:
- return model(img)
-
- device = devices.get_device_for('scunet')
- assert tile % 8 == 0, "tile size should be a multiple of window_size"
- sf = 1
-
- stride = tile - tile_overlap
- h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
- w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
- W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-
- with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
- for h_idx in h_idx_list:
-
- for w_idx in w_idx_list:
-
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
- pbar.update(1)
- output = E.div_(W)
-
- return output
-
def do_upscale(self, img: PIL.Image.Image, selected_file):
devices.torch_gc()
@@ -106,7 +62,16 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
_img[:, :, :h, :w] = torch_img # pad image
torch_img = _img
- torch_output = self.tiled_inference(torch_img, model).squeeze(0)
+ with torch.no_grad():
+ torch_output = tiled_upscale_2(
+ torch_img,
+ model,
+ tile_size=opts.SCUNET_tile,
+ tile_overlap=opts.SCUNET_tile_overlap,
+ scale=1,
+ device=devices.get_device_for('scunet'),
+ desc="ScuNET tiles",
+ ).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
@@ -120,17 +85,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
device = devices.get_device_for('scunet')
if path.startswith("http"):
# TODO: this doesn't use `path` at all?
- filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
+ filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
- model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
- model.load_state_dict(torch.load(filename), strict=True)
- model.eval()
- for _, v in model.named_parameters():
- v.requires_grad = False
- model = model.to(device)
-
- return model
+ return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
def on_ui_settings():
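The hand-rolled SCUNet loading above is replaced by modelloader.load_spandrel_model, which delegates architecture detection to the spandrel library (the SCUNet architecture file is deleted below). A rough sketch of what spandrel-based loading generally looks like, offered as an assumption about the library's public API rather than the exact body of load_spandrel_model:

    import torch
    from spandrel import ModelLoader

    # Path is hypothetical; load_from_file returns a descriptor wrapping the
    # detected architecture and the instantiated torch module.
    descriptor = ModelLoader().load_from_file("models/ScuNET/ScuNET.pth")

    model = descriptor.model.eval().to(torch.device("cpu"))
    for param in model.parameters():
        param.requires_grad = False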
diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py
deleted file mode 100644
index b51a8806..00000000
--- a/extensions-builtin/ScuNET/scunet_model_arch.py
+++ /dev/null
@@ -1,268 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange
-from einops.layers.torch import Rearrange
-from timm.models.layers import trunc_normal_, DropPath
-
-
-class WMSA(nn.Module):
- """ Self-attention module in Swin Transformer
- """
-
- def __init__(self, input_dim, output_dim, head_dim, window_size, type):
- super(WMSA, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.head_dim = head_dim
- self.scale = self.head_dim ** -0.5
- self.n_heads = input_dim // head_dim
- self.window_size = window_size
- self.type = type
- self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
-
- self.relative_position_params = nn.Parameter(
- torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
-
- self.linear = nn.Linear(self.input_dim, self.output_dim)
-
- trunc_normal_(self.relative_position_params, std=.02)
- self.relative_position_params = torch.nn.Parameter(
- self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
- 2).transpose(
- 0, 1))
-
- def generate_mask(self, h, w, p, shift):
- """ generating the mask of SW-MSA
- Args:
- shift: shift parameters in CyclicShift.
- Returns:
- attn_mask: should be (1 1 w p p),
- """
- # supporting square.
- attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
- if self.type == 'W':
- return attn_mask
-
- s = p - shift
- attn_mask[-1, :, :s, :, s:, :] = True
- attn_mask[-1, :, s:, :, :s, :] = True
- attn_mask[:, -1, :, :s, :, s:] = True
- attn_mask[:, -1, :, s:, :, :s] = True
- attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
- return attn_mask
-
- def forward(self, x):
- """ Forward pass of Window Multi-head Self-attention module.
- Args:
- x: input tensor with shape of [b h w c];
- attn_mask: attention mask, fill -inf where the value is True;
- Returns:
- output: tensor shape [b h w c]
- """
- if self.type != 'W':
- x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
-
- x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
- h_windows = x.size(1)
- w_windows = x.size(2)
- # square validation
- # assert h_windows == w_windows
-
- x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
- qkv = self.embedding_layer(x)
- q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
- sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
- # Adding learnable relative embedding
- sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
- # Using Attn Mask to distinguish different subwindows.
- if self.type != 'W':
- attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
- sim = sim.masked_fill_(attn_mask, float("-inf"))
-
- probs = nn.functional.softmax(sim, dim=-1)
- output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
- output = rearrange(output, 'h b w p c -> b w p (h c)')
- output = self.linear(output)
- output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
-
- if self.type != 'W':
- output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
-
- return output
-
- def relative_embedding(self):
- cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
- relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
- # negative is allowed
- return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
-
-
-class Block(nn.Module):
- def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
- """ SwinTransformer Block
- """
- super(Block, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- assert type in ['W', 'SW']
- self.type = type
- if input_resolution <= window_size:
- self.type = 'W'
-
- self.ln1 = nn.LayerNorm(input_dim)
- self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.ln2 = nn.LayerNorm(input_dim)
- self.mlp = nn.Sequential(
- nn.Linear(input_dim, 4 * input_dim),
- nn.GELU(),
- nn.Linear(4 * input_dim, output_dim),
- )
-
- def forward(self, x):
- x = x + self.drop_path(self.msa(self.ln1(x)))
- x = x + self.drop_path(self.mlp(self.ln2(x)))
- return x
-
-
-class ConvTransBlock(nn.Module):
- def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
- """ SwinTransformer and Conv Block
- """
- super(ConvTransBlock, self).__init__()
- self.conv_dim = conv_dim
- self.trans_dim = trans_dim
- self.head_dim = head_dim
- self.window_size = window_size
- self.drop_path = drop_path
- self.type = type
- self.input_resolution = input_resolution
-
- assert self.type in ['W', 'SW']
- if self.input_resolution <= self.window_size:
- self.type = 'W'
-
- self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
- self.type, self.input_resolution)
- self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
- self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
-
- self.conv_block = nn.Sequential(
- nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
- nn.ReLU(True),
- nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
- )
-
- def forward(self, x):
- conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
- conv_x = self.conv_block(conv_x) + conv_x
- trans_x = Rearrange('b c h w -> b h w c')(trans_x)
- trans_x = self.trans_block(trans_x)
- trans_x = Rearrange('b h w c -> b c h w')(trans_x)
- res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
- x = x + res
-
- return x
-
-
-class SCUNet(nn.Module):
- # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
- def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
- super(SCUNet, self).__init__()
- if config is None:
- config = [2, 2, 2, 2, 2, 2, 2]
- self.config = config
- self.dim = dim
- self.head_dim = 32
- self.window_size = 8
-
- # drop path rate for each layer
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
-
- self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
-
- begin = 0
- self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution)
- for i in range(config[0])] + \
- [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
-
- begin += config[0]
- self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 2)
- for i in range(config[1])] + \
- [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
-
- begin += config[1]
- self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 4)
- for i in range(config[2])] + \
- [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
-
- begin += config[2]
- self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 8)
- for i in range(config[3])]
-
- begin += config[3]
- self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 4)
- for i in range(config[4])]
-
- begin += config[4]
- self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 2)
- for i in range(config[5])]
-
- begin += config[5]
- self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution)
- for i in range(config[6])]
-
- self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
-
- self.m_head = nn.Sequential(*self.m_head)
- self.m_down1 = nn.Sequential(*self.m_down1)
- self.m_down2 = nn.Sequential(*self.m_down2)
- self.m_down3 = nn.Sequential(*self.m_down3)
- self.m_body = nn.Sequential(*self.m_body)
- self.m_up3 = nn.Sequential(*self.m_up3)
- self.m_up2 = nn.Sequential(*self.m_up2)
- self.m_up1 = nn.Sequential(*self.m_up1)
- self.m_tail = nn.Sequential(*self.m_tail)
- # self.apply(self._init_weights)
-
- def forward(self, x0):
-
- h, w = x0.size()[-2:]
- paddingBottom = int(np.ceil(h / 64) * 64 - h)
- paddingRight = int(np.ceil(w / 64) * 64 - w)
- x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
-
- x1 = self.m_head(x0)
- x2 = self.m_down1(x1)
- x3 = self.m_down2(x2)
- x4 = self.m_down3(x3)
- x = self.m_body(x4)
- x = self.m_up3(x + x4)
- x = self.m_up2(x + x3)
- x = self.m_up1(x + x2)
- x = self.m_tail(x + x1)
-
- x = x[..., :h, :w]
-
- return x
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index ae0d0e6a..8a555c79 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,20 +1,18 @@
+import logging
import sys
-import platform
import numpy as np
import torch
from PIL import Image
-from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import opts, state
-from swinir_model_arch import SwinIR
-from swinir_model_arch_v2 import Swin2SR
+from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import tiled_upscale_2
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
-device_swinir = devices.get_device_for('swinir')
+logger = logging.getLogger(__name__)
class UpscalerSwinIR(Upscaler):
@@ -37,26 +35,29 @@ class UpscalerSwinIR(Upscaler):
scalers.append(model_data)
self.scalers = scalers
- def do_upscale(self, img, model_file):
- use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
- and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
+ def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
current_config = (model_file, opts.SWIN_tile)
- if use_compile and self._cached_model_config == current_config:
+ device = self._get_device()
+
+ if self._cached_model_config == current_config:
model = self._cached_model
else:
- self._cached_model = None
try:
model = self.load_model(model_file)
except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img
- model = model.to(device_swinir, dtype=devices.dtype)
- if use_compile:
- model = torch.compile(model)
- self._cached_model = model
- self._cached_model_config = current_config
- img = upscale(img, model)
+ self._cached_model = model
+ self._cached_model_config = current_config
+
+ img = upscale(
+ img,
+ model,
+ tile=opts.SWIN_tile,
+ tile_overlap=opts.SWIN_tile_overlap,
+ device=device,
+ )
devices.torch_gc()
return img
@@ -69,69 +70,55 @@ class UpscalerSwinIR(Upscaler):
)
else:
filename = path
- if filename.endswith(".v2.pth"):
- model = Swin2SR(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6],
- embed_dim=180,
- num_heads=[6, 6, 6, 6, 6, 6],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="1conv",
- )
- params = None
- else:
- model = SwinIR(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="3conv",
- )
- params = "params_ema"
- pretrained_model = torch.load(filename)
- if params is not None:
- model.load_state_dict(pretrained_model[params], strict=True)
- else:
- model.load_state_dict(pretrained_model, strict=True)
- return model
+ model_descriptor = modelloader.load_spandrel_model(
+ filename,
+ device=self._get_device(),
+ dtype=devices.dtype,
+ expected_architecture="SwinIR",
+ )
+ if getattr(opts, 'SWIN_torch_compile', False):
+ try:
+ model_descriptor.model.compile()
+ except Exception:
+ logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
+ return model_descriptor
+
+ def _get_device(self):
+ return devices.get_device_for('swinir')
def upscale(
- img,
- model,
- tile=None,
- tile_overlap=None,
- window_size=8,
- scale=4,
+ img,
+ model,
+ *,
+ tile: int,
+ tile_overlap: int,
+ window_size=8,
+ scale=4,
+ device,
):
- tile = tile or opts.SWIN_tile
- tile_overlap = tile_overlap or opts.SWIN_tile_overlap
-
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
+ img = img.unsqueeze(0).to(device, dtype=devices.dtype)
with torch.no_grad(), devices.autocast():
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
- output = inference(img, model, tile, tile_overlap, window_size, scale)
+ output = tiled_upscale_2(
+ img,
+ model,
+ tile_size=tile,
+ tile_overlap=tile_overlap,
+ scale=scale,
+ device=device,
+ desc="SwinIR tiles",
+ )
output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
@@ -142,51 +129,12 @@ def upscale(
return Image.fromarray(output, "RGB")
-def inference(img, model, tile, tile_overlap, window_size, scale):
- # test the image tile by tile
- b, c, h, w = img.size()
- tile = min(tile, h, w)
- assert tile % window_size == 0, "tile size should be a multiple of window_size"
- sf = scale
-
- stride = tile - tile_overlap
- h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
- w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
- W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
-
- with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
- for h_idx in h_idx_list:
- if state.interrupted or state.skipped:
- break
-
- for w_idx in w_idx_list:
- if state.interrupted or state.skipped:
- break
-
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
- pbar.update(1)
- output = E.div_(W)
-
- return output
-
-
def on_ui_settings():
import gradio as gr
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
- if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
- shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
+ shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
script_callbacks.on_ui_settings(on_ui_settings)
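
The pre- and post-processing kept in upscale() above pads the input to a multiple of window_size by mirroring its bottom and right edges, then crops the model output back to h_old * scale by w_old * scale. A small self-contained check of that arithmetic (the sizes and the repeat_interleave stand-in for the model are made up for illustration):

import torch

window_size, scale = 8, 4
img = torch.rand(1, 3, 30, 45)                     # H and W are not multiples of 8
_, _, h_old, w_old = img.shape
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, :h_old + h_pad, :]   # mirror-pad height
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, :w_old + w_pad]   # mirror-pad width
assert img.shape[-2] % window_size == 0 and img.shape[-1] % window_size == 0

output = img.repeat_interleave(scale, dim=-2).repeat_interleave(scale, dim=-1)  # stand-in for the model
output = output[..., :h_old * scale, :w_old * scale]   # crop the padding back off, as in upscale()
print(output.shape)                                     # torch.Size([1, 3, 120, 180])
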
diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py
deleted file mode 100644
index 93b93274..00000000
--- a/extensions-builtin/SwinIR/swinir_model_arch.py
+++ /dev/null
@@ -1,867 +0,0 @@
-# -----------------------------------------------------------------------------------
-# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
-# Originally Written by Ze Liu, Modified by Jingyun Liang.
-# -----------------------------------------------------------------------------------
-
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
-
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
-
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
-
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
-
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- # assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are a multiple of the window size)
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
-
- # FFN
- x = shortcut + self.drop_path(x)
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
-
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.dim
- flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- return flops
-
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- flops = 0
- H, W = self.img_size
- if self.norm is not None:
- flops += H * W * self.embed_dim
- return flops
-
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-class SwinIR(nn.Module):
- r""" SwinIR
- A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
- window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(SwinIR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- if self.upscale == 4:
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- if self.upscale == 4:
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
-
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for layer in self.layers:
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = SwinIR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
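
Of the architecture code removed above, the Upsample helper is the piece most worth a quick sanity check: for power-of-two scales it stacks log2(scale) rounds of Conv2d(num_feat, 4 * num_feat) followed by PixelShuffle(2), each doubling the spatial size. A tiny standalone reconstruction of that behaviour (not part of the diff, values chosen for illustration):

import math
import torch
import torch.nn as nn

scale, num_feat = 4, 64
stages = []
for _ in range(int(math.log(scale, 2))):               # two stages for scale = 4
    stages += [nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)]
upsample = nn.Sequential(*stages)

x = torch.randn(1, num_feat, 16, 16)
print(upsample(x).shape)                                # torch.Size([1, 64, 64, 64])
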
diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
deleted file mode 100644
index dad22cca..00000000
--- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py
+++ /dev/null
@@ -1,1017 +0,0 @@
-# -----------------------------------------------------------------------------------
-# Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
-# Written by Conde and Choi et al.
-# -----------------------------------------------------------------------------------
-
-import math
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both shifted and non-shifted windows.
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=(0, 0)):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.pretrained_window_size = pretrained_window_size
- self.num_heads = num_heads
-
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
-
- # mlp to generate continuous relative position bias
- self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
- nn.ReLU(inplace=True),
- nn.Linear(512, num_heads, bias=False))
-
- # get relative_coords_table
- relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
- relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
- relative_coords_table = torch.stack(
- torch.meshgrid([relative_coords_h,
- relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
- if pretrained_window_size[0] > 0:
- relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
- else:
- relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
- relative_coords_table *= 8 # normalize to -8, 8
- relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
- torch.abs(relative_coords_table) + 1.0) / np.log2(8)
-
- self.register_buffer("relative_coords_table", relative_coords_table)
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=False)
- if qkv_bias:
- self.q_bias = nn.Parameter(torch.zeros(dim))
- self.v_bias = nn.Parameter(torch.zeros(dim))
- else:
- self.q_bias = None
- self.v_bias = None
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv_bias = None
- if self.q_bias is not None:
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
- qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- # cosine attention
- attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
- logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
- attn = attn * logit_scale
-
- relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
- relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, ' \
- f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- pretrained_window_size (int): Window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
- pretrained_window_size=to_2tuple(pretrained_window_size))
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- #assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are multiples of the window size)
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
- x = shortcut + self.drop_path(self.norm1(x))
-
- # FFN
- x = x + self.drop_path(self.norm2(self.mlp(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(2 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.reduction(x)
- x = self.norm(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- flops += H * W * self.dim // 2
- return flops
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- pretrained_window_size (int): Local window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- pretrained_window_size=0):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer,
- pretrained_window_size=pretrained_window_size)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
- def _init_respostnorm(self):
- for blk in self.blocks:
- nn.init.constant_(blk.norm1.bias, 0)
- nn.init.constant_(blk.norm1.weight, 0)
- nn.init.constant_(blk.norm2.bias, 0)
- nn.init.constant_(blk.norm2.weight, 0)
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- B, C, H, W = x.shape
- # FIXME look at relaxing size constraints
- # assert H == self.img_size[0] and W == self.img_size[1],
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-class Upsample_hf(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-
-class Swin2SR(nn.Module):
- r""" Swin2SR
- A PyTorch impl of: `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
- window_size=7, mlp_ratio=4., qkv_bias=True,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(Swin2SR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
-
- if self.upsampler == 'pixelshuffle_hf':
- self.layers_hf = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers_hf.append(layer)
-
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffle_aux':
- self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_after_aux = nn.Sequential(
- nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffle_hf':
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.upsample_hf = Upsample_hf(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- self.conv_before_upsample_hf = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
- assert self.upscale == 4, 'only support x4 now.'
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
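The check_image_size method above reflect-pads the input on the right and bottom so both dimensions become multiples of the window size; a quick numeric sketch of the same padding arithmetic (values chosen only for illustration):

    h, w, window_size = 251, 340, 8                              # hypothetical input size
    mod_pad_h = (window_size - h % window_size) % window_size    # 5 -> pads 251 up to 256
    mod_pad_w = (window_size - w % window_size) % window_size    # 4 -> pads 340 up to 344
    print(mod_pad_h, mod_pad_w)                                  # 5 4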
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward_features_hf(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers_hf:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffle_aux':
- bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
- bicubic = self.conv_bicubic(bicubic)
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- aux = self.conv_aux(x) # b, 3, LR_H, LR_W
- x = self.conv_after_aux(aux)
- x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
- x = self.conv_last(x)
- aux = aux / self.img_range + self.mean
- elif self.upsampler == 'pixelshuffle_hf':
- # for classical SR with HF
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x_before = self.conv_before_upsample(x)
- x_out = self.conv_last(self.upsample(x_before))
-
- x_hf = self.conv_first_hf(x_before)
- x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
- x_hf = self.conv_before_upsample_hf(x_hf)
- x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
- x = x_out + x_hf
- x_hf = x_hf / self.img_range + self.mean
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
- if self.upsampler == "pixelshuffle_aux":
- return x[:, :, :H*self.upscale, :W*self.upscale], aux
-
- elif self.upsampler == "pixelshuffle_hf":
- x_out = x_out / self.img_range + self.mean
- return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-
- else:
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for layer in self.layers:
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = Swin2SR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
index 983f87ff..8aa901fd 100644
--- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py
+++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
@@ -1,7 +1,7 @@
import math
import gradio as gr
-from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
+from modules import scripts, shared, ui_components, ui_settings, infotext
from modules.ui_components import FormColumn
@@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script):
self.setting_names = []
self.infotext_fields = []
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
+ elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
- mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
+ mapping = {k: v for v, k in infotext.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:
- with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
+ with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
@@ -64,11 +65,14 @@ class ExtraOptionsSection(scripts.Script):
p.override_settings[name] = value
-shared.options_templates.update(shared.options_section(('ui', "User interface"), {
- "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
- "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
- "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
- "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
+shared.options_templates.update(shared.options_section(('settings_in_ui', "Settings in UI", "ui"), {
+ "settings_in_ui": shared.OptionHTML("""
+This page allows you to add some settings to the main interface of the txt2img and img2img tabs.
+"""),
+ "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
+ "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
+ "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(),
+ "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
}))
diff --git a/extensions-builtin/hypertile/hypertile.py b/extensions-builtin/hypertile/hypertile.py
new file mode 100644
index 00000000..0f40e2d3
--- /dev/null
+++ b/extensions-builtin/hypertile/hypertile.py
@@ -0,0 +1,351 @@
+"""
+Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
+Warning: the patch works well only if the input image's width and height are multiples of 128
+Original author: @tfernd Github: https://github.com/tfernd/HyperTile
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable
+
+from functools import wraps, cache
+
+import math
+import torch.nn as nn
+import random
+
+from einops import rearrange
+
+
+@dataclass
+class HypertileParams:
+ depth = 0
+ layer_name = ""
+ tile_size: int = 0
+ swap_size: int = 0
+ aspect_ratio: float = 1.0
+ forward = None
+ enabled = False
+
+
+
+# TODO add SD-XL layers
+DEPTH_LAYERS = {
+ 0: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.1.1.transformer_blocks.0.attn1",
+ "input_blocks.2.1.transformer_blocks.0.attn1",
+ "output_blocks.9.1.transformer_blocks.0.attn1",
+ "output_blocks.10.1.transformer_blocks.0.attn1",
+ "output_blocks.11.1.transformer_blocks.0.attn1",
+ # SD 1.5 VAE
+ "decoder.mid_block.attentions.0",
+ "decoder.mid.attn_1",
+ ],
+ 1: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.4.1.transformer_blocks.0.attn1",
+ "input_blocks.5.1.transformer_blocks.0.attn1",
+ "output_blocks.6.1.transformer_blocks.0.attn1",
+ "output_blocks.7.1.transformer_blocks.0.attn1",
+ "output_blocks.8.1.transformer_blocks.0.attn1",
+ ],
+ 2: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.7.1.transformer_blocks.0.attn1",
+ "input_blocks.8.1.transformer_blocks.0.attn1",
+ "output_blocks.3.1.transformer_blocks.0.attn1",
+ "output_blocks.4.1.transformer_blocks.0.attn1",
+ "output_blocks.5.1.transformer_blocks.0.attn1",
+ ],
+ 3: [
+ # SD 1.5 U-Net (diffusers)
+ "mid_block.attentions.0.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "middle_block.1.transformer_blocks.0.attn1",
+ ],
+}
+# SD-XL layers, thanks to GitHub @gel-crabs for the help
+DEPTH_LAYERS_XL = {
+ 0: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.4.1.transformer_blocks.0.attn1",
+ "input_blocks.5.1.transformer_blocks.0.attn1",
+ "output_blocks.3.1.transformer_blocks.0.attn1",
+ "output_blocks.4.1.transformer_blocks.0.attn1",
+ "output_blocks.5.1.transformer_blocks.0.attn1",
+ # SD 1.5 VAE
+ "decoder.mid_block.attentions.0",
+ "decoder.mid.attn_1",
+ ],
+ 1: [
+ # SD 1.5 U-Net (diffusers)
+ #"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
+ #"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
+ #"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
+ #"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
+ #"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.4.1.transformer_blocks.1.attn1",
+ "input_blocks.5.1.transformer_blocks.1.attn1",
+ "output_blocks.3.1.transformer_blocks.1.attn1",
+ "output_blocks.4.1.transformer_blocks.1.attn1",
+ "output_blocks.5.1.transformer_blocks.1.attn1",
+ "input_blocks.7.1.transformer_blocks.0.attn1",
+ "input_blocks.8.1.transformer_blocks.0.attn1",
+ "output_blocks.0.1.transformer_blocks.0.attn1",
+ "output_blocks.1.1.transformer_blocks.0.attn1",
+ "output_blocks.2.1.transformer_blocks.0.attn1",
+ "input_blocks.7.1.transformer_blocks.1.attn1",
+ "input_blocks.8.1.transformer_blocks.1.attn1",
+ "output_blocks.0.1.transformer_blocks.1.attn1",
+ "output_blocks.1.1.transformer_blocks.1.attn1",
+ "output_blocks.2.1.transformer_blocks.1.attn1",
+ "input_blocks.7.1.transformer_blocks.2.attn1",
+ "input_blocks.8.1.transformer_blocks.2.attn1",
+ "output_blocks.0.1.transformer_blocks.2.attn1",
+ "output_blocks.1.1.transformer_blocks.2.attn1",
+ "output_blocks.2.1.transformer_blocks.2.attn1",
+ "input_blocks.7.1.transformer_blocks.3.attn1",
+ "input_blocks.8.1.transformer_blocks.3.attn1",
+ "output_blocks.0.1.transformer_blocks.3.attn1",
+ "output_blocks.1.1.transformer_blocks.3.attn1",
+ "output_blocks.2.1.transformer_blocks.3.attn1",
+ "input_blocks.7.1.transformer_blocks.4.attn1",
+ "input_blocks.8.1.transformer_blocks.4.attn1",
+ "output_blocks.0.1.transformer_blocks.4.attn1",
+ "output_blocks.1.1.transformer_blocks.4.attn1",
+ "output_blocks.2.1.transformer_blocks.4.attn1",
+ "input_blocks.7.1.transformer_blocks.5.attn1",
+ "input_blocks.8.1.transformer_blocks.5.attn1",
+ "output_blocks.0.1.transformer_blocks.5.attn1",
+ "output_blocks.1.1.transformer_blocks.5.attn1",
+ "output_blocks.2.1.transformer_blocks.5.attn1",
+ "input_blocks.7.1.transformer_blocks.6.attn1",
+ "input_blocks.8.1.transformer_blocks.6.attn1",
+ "output_blocks.0.1.transformer_blocks.6.attn1",
+ "output_blocks.1.1.transformer_blocks.6.attn1",
+ "output_blocks.2.1.transformer_blocks.6.attn1",
+ "input_blocks.7.1.transformer_blocks.7.attn1",
+ "input_blocks.8.1.transformer_blocks.7.attn1",
+ "output_blocks.0.1.transformer_blocks.7.attn1",
+ "output_blocks.1.1.transformer_blocks.7.attn1",
+ "output_blocks.2.1.transformer_blocks.7.attn1",
+ "input_blocks.7.1.transformer_blocks.8.attn1",
+ "input_blocks.8.1.transformer_blocks.8.attn1",
+ "output_blocks.0.1.transformer_blocks.8.attn1",
+ "output_blocks.1.1.transformer_blocks.8.attn1",
+ "output_blocks.2.1.transformer_blocks.8.attn1",
+ "input_blocks.7.1.transformer_blocks.9.attn1",
+ "input_blocks.8.1.transformer_blocks.9.attn1",
+ "output_blocks.0.1.transformer_blocks.9.attn1",
+ "output_blocks.1.1.transformer_blocks.9.attn1",
+ "output_blocks.2.1.transformer_blocks.9.attn1",
+ ],
+ 2: [
+ # SD 1.5 U-Net (diffusers)
+ "mid_block.attentions.0.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "middle_block.1.transformer_blocks.0.attn1",
+ "middle_block.1.transformer_blocks.1.attn1",
+ "middle_block.1.transformer_blocks.2.attn1",
+ "middle_block.1.transformer_blocks.3.attn1",
+ "middle_block.1.transformer_blocks.4.attn1",
+ "middle_block.1.transformer_blocks.5.attn1",
+ "middle_block.1.transformer_blocks.6.attn1",
+ "middle_block.1.transformer_blocks.7.attn1",
+ "middle_block.1.transformer_blocks.8.attn1",
+ "middle_block.1.transformer_blocks.9.attn1",
+ ],
+ 3 : [] # TODO - separate layers for SD-XL
+}
+
+
+RNG_INSTANCE = random.Random()
+
+@cache
+def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]:
+ """
+ Returns divisors x of value such that x * min_value <= value,
+ in big -> small order; the number of divisors returned is limited by max_options
+ """
+ max_options = max(1, max_options) # at least 1 option should be returned
+ min_value = min(min_value, value)
+ divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
+ ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
+ return ns
+
+
+def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
+ """
+ Returns a random divisor x of value such that x * min_value <= value.
+ If max_options is 1, the behavior is deterministic.
+ """
+ ns = get_divisors(value, min_value, max_options=max_options) # get cached divisors
+ idx = RNG_INSTANCE.randint(0, len(ns) - 1)
+
+ return ns[idx]
+
+
+def set_hypertile_seed(seed: int) -> None:
+ RNG_INSTANCE.seed(seed)
+
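A small usage sketch of the divisor helpers above (numbers are illustrative only): with value=96, min_value=16 and max_options=2, the divisors of 96 that are at least 16 are [16, 24], so the returned candidates are [6, 4], and random_divisor picks one of them using the seeded RNG.

    print(get_divisors(96, 16, max_options=2))    # [6, 4] -- big -> small
    set_hypertile_seed(0)
    print(random_divisor(96, 16, max_options=2))  # deterministically one of 6 or 4 for this seed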
+
+@cache
+def largest_tile_size_available(width: int, height: int) -> int:
+ """
+ Calculates the largest tile size available for a given width and height.
+ The tile size is always a power of 2.
+ """
+ gcd = math.gcd(width, height)
+ largest_tile_size_available = 1
+ while gcd % (largest_tile_size_available * 2) == 0:
+ largest_tile_size_available *= 2
+ return largest_tile_size_available
+
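For the tile-size helper above, the result is the largest power of two dividing gcd(width, height); two illustrative sizes:

    print(largest_tile_size_available(1024, 768))  # gcd is 256 -> 256
    print(largest_tile_size_available(1280, 720))  # gcd is 80 -> largest power of 2 dividing it is 16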
+
+def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
+ """
+ Finds h and w such that h*w = hw and their ratio is as close as possible to aspect_ratio.
+ We check all possible divisor pairs of hw and return the one closest to the aspect ratio.
+ """
+ divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
+ pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
+ ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
+ closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
+ closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
+ return closest_pair
+
+
+@cache
+def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
+ """
+ Finds h and w such that h*w = hw and h/w = aspect_ratio
+ """
+ h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
+ # find h and w such that h*w = hw and h/w = aspect_ratio
+ if h * w != hw:
+ w_candidate = hw / h
+ # check if w is an integer
+ if not w_candidate.is_integer():
+ h_candidate = hw / w
+ # check if h is an integer
+ if not h_candidate.is_integer():
+ return iterative_closest_divisors(hw, aspect_ratio)
+ else:
+ h = int(h_candidate)
+ else:
+ w = int(w_candidate)
+ return h, w
+
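Taken together, find_hw_candidates recovers the 2-D latent shape from a flattened token count. For example, for a hypothetical 768x512 generation the self-attention tokens come from a 96x64 latent, and with aspect ratio 1.5 the exact factorization is found on the first try:

    print(find_hw_candidates(96 * 64, 1.5))   # (96, 64)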
+
+def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
+
+ @wraps(params.forward)
+ def wrapper(*args, **kwargs):
+ if not params.enabled:
+ return params.forward(*args, **kwargs)
+
+ latent_tile_size = max(128, params.tile_size) // 8
+ x = args[0]
+
+ # VAE
+ if x.ndim == 4:
+ b, c, h, w = x.shape
+
+ nh = random_divisor(h, latent_tile_size, params.swap_size)
+ nw = random_divisor(w, latent_tile_size, params.swap_size)
+
+ if nh * nw > 1:
+ x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
+
+ out = params.forward(x, *args[1:], **kwargs)
+
+ if nh * nw > 1:
+ out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
+
+ # U-Net
+ else:
+ hw: int = x.size(1)
+ h, w = find_hw_candidates(hw, params.aspect_ratio)
+ assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
+
+ factor = 2 ** params.depth if scale_depth else 1
+ nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
+ nw = random_divisor(w, latent_tile_size * factor, params.swap_size)
+
+ if nh * nw > 1:
+ x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
+
+ out = params.forward(x, *args[1:], **kwargs)
+
+ if nh * nw > 1:
+ out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
+ out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
+
+ return out
+
+ return wrapper
+
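The U-Net branch of the wrapper above relies on two einops rearranges: one regroups the flattened token axis into nh*nw spatial tiles folded into the batch dimension, and the inverse stitches the tiles back. A standalone shape check of that round trip, with assumed tensor sizes (independent of the wrapper itself):

    import torch
    from einops import rearrange

    b, c, h, w, nh, nw = 2, 320, 96, 64, 2, 2           # assumed sizes for illustration
    x = torch.randn(b, h * w, c)                        # tokens as seen by attn1
    tiles = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
    print(tiles.shape)                                  # torch.Size([8, 1536, 320])
    back = rearrange(tiles, "(b nh nw) (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
    print(torch.equal(back, x))                         # True -- the tiling is lossless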
+
+def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
+ hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
+ if hypertile_layers is None:
+ if not enable:
+ return
+
+ hypertile_layers = {}
+ layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS
+
+ for depth in range(4):
+ for layer_name, module in model.named_modules():
+ if any(layer_name.endswith(try_name) for try_name in layers[depth]):
+ params = HypertileParams()
+ module.__webui_hypertile_params = params
+ params.forward = module.forward
+ params.depth = depth
+ params.layer_name = layer_name
+ module.forward = self_attn_forward(params)
+
+ hypertile_layers[layer_name] = 1
+
+ model.__webui_hypertile_layers = hypertile_layers
+
+ aspect_ratio = width / height
+ tile_size = min(largest_tile_size_available(width, height), tile_size_max)
+
+ for layer_name, module in model.named_modules():
+ if layer_name in hypertile_layers:
+ params = module.__webui_hypertile_params
+
+ params.tile_size = tile_size
+ params.swap_size = swap_size
+ params.aspect_ratio = aspect_ratio
+ params.enabled = enable and params.depth <= max_depth
diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py
new file mode 100644
index 00000000..395d584b
--- /dev/null
+++ b/extensions-builtin/hypertile/scripts/hypertile_script.py
@@ -0,0 +1,109 @@
+import hypertile
+from modules import scripts, script_callbacks, shared
+from scripts.hypertile_xyz import add_axis_options
+
+
+class ScriptHypertile(scripts.Script):
+ name = "Hypertile"
+
+ def title(self):
+ return self.name
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible
+
+ def process(self, p, *args):
+ hypertile.set_hypertile_seed(p.all_seeds[0])
+
+ configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
+
+ self.add_infotext(p)
+
+ def before_hr(self, p, *args):
+
+ enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet
+
+ # exclusive hypertile seed for the second pass
+ if enable:
+ hypertile.set_hypertile_seed(p.all_seeds[0])
+
+ configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable)
+
+ if enable and not shared.opts.hypertile_enable_unet:
+ p.extra_generation_params["Hypertile U-Net second pass"] = True
+
+ self.add_infotext(p, add_unet_params=True)
+
+ def add_infotext(self, p, add_unet_params=False):
+ def option(name):
+ value = getattr(shared.opts, name)
+ default_value = shared.opts.get_default(name)
+ return None if value == default_value else value
+
+ if shared.opts.hypertile_enable_unet:
+ p.extra_generation_params["Hypertile U-Net"] = True
+
+ if shared.opts.hypertile_enable_unet or add_unet_params:
+ p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet')
+ p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet')
+ p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet')
+
+ if shared.opts.hypertile_enable_vae:
+ p.extra_generation_params["Hypertile VAE"] = True
+ p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae')
+ p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae')
+ p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae')
+
+
+def configure_hypertile(width, height, enable_unet=True):
+ hypertile.hypertile_hook_model(
+ shared.sd_model.first_stage_model,
+ width,
+ height,
+ swap_size=shared.opts.hypertile_swap_size_vae,
+ max_depth=shared.opts.hypertile_max_depth_vae,
+ tile_size_max=shared.opts.hypertile_max_tile_vae,
+ enable=shared.opts.hypertile_enable_vae,
+ )
+
+ hypertile.hypertile_hook_model(
+ shared.sd_model.model,
+ width,
+ height,
+ swap_size=shared.opts.hypertile_swap_size_unet,
+ max_depth=shared.opts.hypertile_max_depth_unet,
+ tile_size_max=shared.opts.hypertile_max_tile_unet,
+ enable=enable_unet,
+ is_sdxl=shared.sd_model.is_sdxl
+ )
+
+
+def on_ui_settings():
+ import gradio as gr
+
+ options = {
+ "hypertile_explanation": shared.OptionHTML("""
+ <a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
+ which can reduce computation time by a factor of 1 to 4. The larger the generated image, the greater the
+ benefit.
+ """),
+
+ "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"),
+ "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"),
+ "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"),
+ "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"),
+ "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"),
+
+ "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"),
+ "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"),
+ "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"),
+ "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"),
+ }
+
+ for name, opt in options.items():
+ opt.section = ('hypertile', "Hypertile")
+ shared.opts.add_option(name, opt)
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
+script_callbacks.on_before_ui(add_axis_options)
diff --git a/extensions-builtin/hypertile/scripts/hypertile_xyz.py b/extensions-builtin/hypertile/scripts/hypertile_xyz.py
new file mode 100644
index 00000000..9e96ae3c
--- /dev/null
+++ b/extensions-builtin/hypertile/scripts/hypertile_xyz.py
@@ -0,0 +1,51 @@
+from modules import scripts
+from modules.shared import opts
+
+xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module
+
+def int_applier(value_name:str, min_range:int = -1, max_range:int = -1):
+ """
+ Returns a function that applies the given value to the given value_name in opts.data.
+ """
+ def validate(value_name:str, value:str):
+ value = int(value)
+ # validate value
+ if not min_range == -1:
+ assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
+ if not max_range == -1:
+ assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
+ def apply_int(p, x, xs):
+ validate(value_name, x)
+ opts.data[value_name] = int(x)
+ return apply_int
+
+def bool_applier(value_name:str):
+ """
+ Returns a function that applies the given value to the given value_name in opts.data.
+ """
+ def validate(value_name:str, value:str):
+ assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false"
+ def apply_bool(p, x, xs):
+ validate(value_name, x)
+ value_boolean = x.lower() == "true"
+ opts.data[value_name] = value_boolean
+ return apply_bool
+
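Both appliers are closures that validate the axis value as a string and then write the converted value into opts.data. A minimal usage sketch of how the xyz grid is expected to invoke them (illustrative only; assumes the webui's shared.opts has already been initialized):

    apply_tile = int_applier("hypertile_max_tile_unet", 0, 512)
    apply_tile(None, "256", None)     # validates 0 <= 256 <= 512, then opts.data["hypertile_max_tile_unet"] = 256
    apply_vae = bool_applier("hypertile_enable_vae")
    apply_vae(None, "true", None)     # opts.data["hypertile_enable_vae"] = True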
+def add_axis_options():
+ extra_axis_options = [
+ xyz_grid.AxisOption("[Hypertile] Unet First pass Enabled", str, bool_applier("hypertile_enable_unet"), choices=xyz_grid.boolean_choice(reverse=True)),
+ xyz_grid.AxisOption("[Hypertile] Unet Second pass Enabled", str, bool_applier("hypertile_enable_unet_secondpass"), choices=xyz_grid.boolean_choice(reverse=True)),
+ xyz_grid.AxisOption("[Hypertile] Unet Max Depth", int, int_applier("hypertile_max_depth_unet", 0, 3), choices=lambda: [str(x) for x in range(4)]),
+ xyz_grid.AxisOption("[Hypertile] Unet Max Tile Size", int, int_applier("hypertile_max_tile_unet", 0, 512)),
+ xyz_grid.AxisOption("[Hypertile] Unet Swap Size", int, int_applier("hypertile_swap_size_unet", 0, 64)),
+ xyz_grid.AxisOption("[Hypertile] VAE Enabled", str, bool_applier("hypertile_enable_vae"), choices=xyz_grid.boolean_choice(reverse=True)),
+ xyz_grid.AxisOption("[Hypertile] VAE Max Depth", int, int_applier("hypertile_max_depth_vae", 0, 3), choices=lambda: [str(x) for x in range(4)]),
+ xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)),
+ xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)),
+ ]
+ set_a = {opt.label for opt in xyz_grid.axis_options}
+ set_b = {opt.label for opt in extra_axis_options}
+ if set_a.intersection(set_b):
+ return
+
+ xyz_grid.axis_options.extend(extra_axis_options)
diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js
index 652f07ac..bff1aced 100644
--- a/extensions-builtin/mobile/javascript/mobile.js
+++ b/extensions-builtin/mobile/javascript/mobile.js
@@ -12,6 +12,8 @@ function isMobile() {
}
function reportWindowSize() {
+ if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout
+
var currentlyMobile = isMobile();
if (currentlyMobile == isSetupForMobile) return;
isSetupForMobile = currentlyMobile;
diff --git a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py
new file mode 100644
index 00000000..d9024344
--- /dev/null
+++ b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py
@@ -0,0 +1,747 @@
+import numpy as np
+import gradio as gr
+import math
+from modules.ui_components import InputAccordion
+import modules.scripts as scripts
+
+
+class SoftInpaintingSettings:
+ def __init__(self,
+ mask_blend_power,
+ mask_blend_scale,
+ inpaint_detail_preservation,
+ composite_mask_influence,
+ composite_difference_threshold,
+ composite_difference_contrast):
+ self.mask_blend_power = mask_blend_power
+ self.mask_blend_scale = mask_blend_scale
+ self.inpaint_detail_preservation = inpaint_detail_preservation
+ self.composite_mask_influence = composite_mask_influence
+ self.composite_difference_threshold = composite_difference_threshold
+ self.composite_difference_contrast = composite_difference_contrast
+
+ def add_generation_params(self, dest):
+ dest[enabled_gen_param_label] = True
+ dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
+ dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
+ dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
+ dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence
+ dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold
+ dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast
+
+
+# ------------------- Methods -------------------
+
+def processing_uses_inpainting(p):
+ # TODO: Figure out a better way to determine if inpainting is being used by p
+ if getattr(p, "image_mask", None) is not None:
+ return True
+
+ if getattr(p, "mask", None) is not None:
+ return True
+
+ if getattr(p, "nmask", None) is not None:
+ return True
+
+ return False
+
+
+def latent_blend(settings, a, b, t):
+ """
+ Interpolates two latent image representations according to the parameter t,
+ where the interpolated vectors' magnitudes are also interpolated separately.
+ The "detail_preservation" factor biases the magnitude interpolation towards
+ the larger of the two magnitudes.
+ """
+ import torch
+
+ # NOTE: We use inplace operations wherever possible.
+
+ # [4][w][h] to [1][4][w][h]
+ t2 = t.unsqueeze(0)
+ # [4][w][h] to [1][1][w][h] - the [4] seem redundant.
+ t3 = t[0].unsqueeze(0).unsqueeze(0)
+
+ one_minus_t2 = 1 - t2
+ one_minus_t3 = 1 - t3
+
+ # Linearly interpolate the image vectors.
+ a_scaled = a * one_minus_t2
+ b_scaled = b * t2
+ image_interp = a_scaled
+ image_interp.add_(b_scaled)
+ result_type = image_interp.dtype
+ del a_scaled, b_scaled, t2, one_minus_t2
+
+ # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
+ # 64-bit operations are used here to allow large exponents.
+ current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
+
+ # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
+ a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
+ settings.inpaint_detail_preservation) * one_minus_t3
+ b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
+ settings.inpaint_detail_preservation) * t3
+ desired_magnitude = a_magnitude
+ desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
+ del a_magnitude, b_magnitude, t3, one_minus_t3
+
+ # Change the linearly interpolated image vectors' magnitudes to the value we want.
+ # This is the last 64-bit operation.
+ image_interp_scaling_factor = desired_magnitude
+ image_interp_scaling_factor.div_(current_magnitude)
+ image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
+ image_interp_scaled = image_interp
+ image_interp_scaled.mul_(image_interp_scaling_factor)
+ del current_magnitude
+ del desired_magnitude
+ del image_interp
+ del image_interp_scaling_factor
+ del result_type
+
+ return image_interp_scaled
+
+
+def get_modified_nmask(settings, nmask, sigma):
+ """
+    Converts a negative mask representing the transparency of the original latent vectors being overlaid
+ to a mask that is scaled according to the denoising strength for this step.
+
+ Where:
+ 0 = fully opaque, infinite density, fully masked
+ 1 = fully transparent, zero density, fully unmasked
+
+    We raise this transparency to a power, which allows one to simulate N blending operations,
+    where N can be any positive real value. Using this, one can control the balance of influence between
+    the denoiser and the original latents according to the sigma value.
+
+ NOTE: "mask" is not used
+ """
+ import torch
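+    # i.e. modified_nmask = nmask ** (sigma ** mask_blend_power * mask_blend_scale):
+    # mask_blend_power shifts where along the sigma schedule the exponent grows,
+    # while mask_blend_scale scales the exponent uniformly at every step.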
+ return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)
+
+
+def apply_adaptive_masks(
+ settings: SoftInpaintingSettings,
+ nmask,
+ latent_orig,
+ latent_processed,
+ overlay_images,
+ width, height,
+ paste_to):
+ import torch
+ import modules.processing as proc
+ import modules.images as images
+ from PIL import Image, ImageOps, ImageFilter
+
+ # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
+ latent_mask = nmask[0].float()
+ # convert the original mask into a form we use to scale distances for thresholding
+ mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
+ mask_scalar = (0.5 * (1 - settings.composite_mask_influence)
+ + mask_scalar * settings.composite_mask_influence)
+ mask_scalar = mask_scalar / (1.00001 - mask_scalar)
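+    # mask_scalar is now mapped from the 0-1 range onto a 0-to-very-large range; it scales the
+    # difference threshold below, and the 1.00001 offset avoids division by zero at exactly 1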
+ mask_scalar = mask_scalar.cpu().numpy()
+
+ latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
+
+ kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
+
+ masks_for_overlay = []
+
+ for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
+ converted_mask = distance_map.float().cpu().numpy()
+ converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.9, percentile_max=1, min_width=1)
+ converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.25, percentile_max=0.75, min_width=1)
+
+ # The distance at which opacity of original decreases to 50%
+ half_weighted_distance = settings.composite_difference_threshold * mask_scalar
+ converted_mask = converted_mask / half_weighted_distance
+
+ converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
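+        # at exactly the half-weighted distance this evaluates to 0.5 (i.e. 50% original opacity);
+        # the contrast exponent controls how sharply opacity falls off around that point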
+ converted_mask = smootherstep(converted_mask)
+ converted_mask = 1 - converted_mask
+ converted_mask = 255. * converted_mask
+ converted_mask = converted_mask.astype(np.uint8)
+ converted_mask = Image.fromarray(converted_mask)
+ converted_mask = images.resize_image(2, converted_mask, width, height)
+ converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+ # Remove aliasing artifacts using a gaussian blur.
+ converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+ # Expand the mask to fit the whole image if needed.
+ if paste_to is not None:
+ converted_mask = proc.uncrop(converted_mask,
+ (overlay_image.width, overlay_image.height),
+ paste_to)
+
+ masks_for_overlay.append(converted_mask)
+
+ image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+ image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(converted_mask.convert('L')))
+
+ overlay_images[i] = image_masked.convert('RGBA')
+
+ return masks_for_overlay
+
+
+def apply_masks(
+ settings,
+ nmask,
+ overlay_images,
+ width, height,
+ paste_to):
+ import torch
+ import modules.processing as proc
+ import modules.images as images
+ from PIL import Image, ImageOps, ImageFilter
+
+ converted_mask = nmask[0].float()
+ converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2)
+ converted_mask = 255. * converted_mask
+ converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
+ converted_mask = Image.fromarray(converted_mask)
+ converted_mask = images.resize_image(2, converted_mask, width, height)
+ converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+ # Remove aliasing artifacts using a gaussian blur.
+ converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+ # Expand the mask to fit the whole image if needed.
+ if paste_to is not None:
+ converted_mask = proc.uncrop(converted_mask,
+ (width, height),
+ paste_to)
+
+ masks_for_overlay = []
+
+ for i, overlay_image in enumerate(overlay_images):
+        masks_for_overlay.append(converted_mask)
+
+ image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+ image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(converted_mask.convert('L')))
+
+ overlay_images[i] = image_masked.convert('RGBA')
+
+ return masks_for_overlay
+
+
+def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
+ """
+    Generalized convolution filter capable of applying
+ weighted mean, median, maximum, and minimum filters
+ parametrically using an arbitrary kernel.
+
+ Args:
+ img (nparray):
+ The image, a 2-D array of floats, to which the filter is being applied.
+ kernel (nparray):
+ The kernel, a 2-D array of floats.
+ kernel_center (nparray):
+ The kernel center coordinate, a 1-D array with two elements.
+ percentile_min (float):
+ The lower bound of the histogram window used by the filter,
+ from 0 to 1.
+ percentile_max (float):
+ The upper bound of the histogram window used by the filter,
+ from 0 to 1.
+ min_width (float):
+ The minimum size of the histogram window bounds, in weight units.
+ Must be greater than 0.
+
+ Returns:
+ (nparray): A filtered copy of the input image "img", a 2-D array of floats.
+ """
+
+ # Converts an index tuple into a vector.
+ def vec(x):
+ return np.array(x)
+
+ kernel_min = -kernel_center
+ kernel_max = vec(kernel.shape) - kernel_center
+
+ def weighted_histogram_filter_single(idx):
+ idx = vec(idx)
+ min_index = np.maximum(0, idx + kernel_min)
+ max_index = np.minimum(vec(img.shape), idx + kernel_max)
+ window_shape = max_index - min_index
+
+ class WeightedElement:
+ """
+ An element of the histogram, its weight
+ and bounds.
+ """
+
+ def __init__(self, value, weight):
+ self.value: float = value
+ self.weight: float = weight
+ self.window_min: float = 0.0
+ self.window_max: float = 1.0
+
+ # Collect the values in the image as WeightedElements,
+ # weighted by their corresponding kernel values.
+ values = []
+ for window_tup in np.ndindex(tuple(window_shape)):
+ window_index = vec(window_tup)
+ image_index = window_index + min_index
+ centered_kernel_index = image_index - idx
+ kernel_index = centered_kernel_index + kernel_center
+ element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
+ values.append(element)
+
+ def sort_key(x: WeightedElement):
+ return x.value
+
+ values.sort(key=sort_key)
+
+ # Calculate the height of the stack (sum)
+ # and each sample's range they occupy in the stack
+ sum = 0
+ for i in range(len(values)):
+ values[i].window_min = sum
+ sum += values[i].weight
+ values[i].window_max = sum
+
+ # Calculate what range of this stack ("window")
+ # we want to get the weighted average across.
+ window_min = sum * percentile_min
+ window_max = sum * percentile_max
+ window_width = window_max - window_min
+
+ # Ensure the window is within the stack and at least a certain size.
+ if window_width < min_width:
+ window_center = (window_min + window_max) / 2
+ window_min = window_center - min_width / 2
+ window_max = window_center + min_width / 2
+
+ if window_max > sum:
+ window_max = sum
+ window_min = sum - min_width
+
+ if window_min < 0:
+ window_min = 0
+ window_max = min_width
+
+ value = 0
+ value_weight = 0
+
+ # Get the weighted average of all the samples
+ # that overlap with the window, weighted
+ # by the size of their overlap.
+ for i in range(len(values)):
+ if window_min >= values[i].window_max:
+ continue
+ if window_max <= values[i].window_min:
+ break
+
+ s = max(window_min, values[i].window_min)
+ e = min(window_max, values[i].window_max)
+ w = e - s
+
+ value += values[i].value * w
+ value_weight += w
+
+ return value / value_weight if value_weight != 0 else 0
+
+ img_out = img.copy()
+
+ # Apply the kernel operation over each pixel.
+ for index in np.ndindex(img.shape):
+ img_out[index] = weighted_histogram_filter_single(index)
+
+ return img_out
+
+
+def smoothstep(x):
+ """
+ The smoothstep function, input should be clamped to 0-1 range.
+ Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
+ """
+ return x * x * (3 - 2 * x)
+
+
+def smootherstep(x):
+ """
+ The smootherstep function, input should be clamped to 0-1 range.
+ Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
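+    Compared to smoothstep, its second derivative is also zero at x = 0 and x = 1.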
+ """
+ return x * x * x * (x * (6 * x - 15) + 10)
+
+
+def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
+ """
+ Creates a Gaussian kernel with thresholded edges.
+
+ Args:
+ stddev_radius (float):
+ Standard deviation of the gaussian kernel, in pixels.
+ max_radius (int):
+ The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2.
+ The kernel is thresholded so that any values one pixel beyond this radius
+            are weighted at 0.
+
+ Returns:
+        (nparray, int): A kernel array (shape: (N, N)) and its center offset along each axis (an integer index).
+ """
+
+    # Evaluates a 0-1 normalized gaussian function for a given squared distance from the mean.
+ def gaussian(sqr_mag):
+ return math.exp(-sqr_mag / (stddev_radius * stddev_radius))
+
+ # Helper function for converting a tuple to an array.
+ def vec(x):
+ return np.array(x)
+
+ """
+ Since a gaussian is unbounded, we need to limit ourselves
+ to a finite range.
+ We taper the ends off at the end of that range so they equal zero
+ while preserving the maximum value of 1 at the mean.
+ """
+ zero_radius = max_radius + 1.0
+ gauss_zero = gaussian(zero_radius * zero_radius)
+ gauss_kernel_scale = 1 / (1 - gauss_zero)
+
+ def gaussian_kernel_func(coordinate):
+ x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0
+ x = gaussian(x)
+ x -= gauss_zero
+ x *= gauss_kernel_scale
+ x = max(0.0, x)
+ return x
+
+ size = max_radius * 2 + 1
+ kernel_center = max_radius
+ kernel = np.zeros((size, size))
+
+ for index in np.ndindex(kernel.shape):
+ kernel[index] = gaussian_kernel_func(vec(index) - kernel_center)
+
+ return kernel, kernel_center
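+
+# Note: with the arguments used by apply_adaptive_masks above (stddev_radius=1.5, max_radius=2),
+# this produces a 5x5 kernel whose centre value is exactly 1.0 and whose edges taper towards 0,
+# together with the centre index 2.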
+
+
+# ------------------- Constants -------------------
+
+
+default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)
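+# positionally: mask_blend_power=1, mask_blend_scale=0.5, inpaint_detail_preservation=4,
+# composite_mask_influence=0, composite_difference_threshold=0.5, composite_difference_contrast=2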
+
+enabled_ui_label = "Soft inpainting"
+enabled_gen_param_label = "Soft inpainting enabled"
+enabled_el_id = "soft_inpainting_enabled"
+
+ui_labels = SoftInpaintingSettings(
+ "Schedule bias",
+ "Preservation strength",
+ "Transition contrast boost",
+ "Mask influence",
+ "Difference threshold",
+ "Difference contrast")
+
+ui_info = SoftInpaintingSettings(
+ "Shifts when preservation of original content occurs during denoising.",
+ "How strongly partially masked content should be preserved.",
+ "Amplifies the contrast that may be lost in partially masked regions.",
+ "How strongly the original mask should bias the difference threshold.",
+ "How much an image region can change before the original pixels are not blended in anymore.",
+ "How sharp the transition should be between blended and not blended.")
+
+gen_param_labels = SoftInpaintingSettings(
+ "Soft inpainting schedule bias",
+ "Soft inpainting preservation strength",
+ "Soft inpainting transition contrast boost",
+ "Soft inpainting mask influence",
+ "Soft inpainting difference threshold",
+ "Soft inpainting difference contrast")
+
+el_ids = SoftInpaintingSettings(
+ "mask_blend_power",
+ "mask_blend_scale",
+ "inpaint_detail_preservation",
+ "composite_mask_influence",
+ "composite_difference_threshold",
+ "composite_difference_contrast")
+
+
+# ------------------- Script -------------------
+
+
+class Script(scripts.Script):
+ def __init__(self):
+ self.section = "inpaint"
+ self.masks_for_overlay = None
+ self.overlay_images = None
+
+ def title(self):
+ return "Soft Inpainting"
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible if is_img2img else False
+
+ def ui(self, is_img2img):
+ if not is_img2img:
+ return
+
+ with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
+ with gr.Group():
+ gr.Markdown(
+ """
+ Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.
+ **High _Mask blur_** values are recommended!
+ """)
+
+ power = \
+ gr.Slider(label=ui_labels.mask_blend_power,
+ info=ui_info.mask_blend_power,
+ minimum=0,
+ maximum=8,
+ step=0.1,
+ value=default.mask_blend_power,
+ elem_id=el_ids.mask_blend_power)
+ scale = \
+ gr.Slider(label=ui_labels.mask_blend_scale,
+ info=ui_info.mask_blend_scale,
+ minimum=0,
+ maximum=8,
+ step=0.05,
+ value=default.mask_blend_scale,
+ elem_id=el_ids.mask_blend_scale)
+ detail = \
+ gr.Slider(label=ui_labels.inpaint_detail_preservation,
+ info=ui_info.inpaint_detail_preservation,
+ minimum=1,
+ maximum=32,
+ step=0.5,
+ value=default.inpaint_detail_preservation,
+ elem_id=el_ids.inpaint_detail_preservation)
+
+ gr.Markdown(
+ """
+ ### Pixel Composite Settings
+ """)
+
+ mask_inf = \
+ gr.Slider(label=ui_labels.composite_mask_influence,
+ info=ui_info.composite_mask_influence,
+ minimum=0,
+ maximum=1,
+ step=0.05,
+ value=default.composite_mask_influence,
+ elem_id=el_ids.composite_mask_influence)
+
+ dif_thresh = \
+ gr.Slider(label=ui_labels.composite_difference_threshold,
+ info=ui_info.composite_difference_threshold,
+ minimum=0,
+ maximum=8,
+ step=0.25,
+ value=default.composite_difference_threshold,
+ elem_id=el_ids.composite_difference_threshold)
+
+ dif_contr = \
+ gr.Slider(label=ui_labels.composite_difference_contrast,
+ info=ui_info.composite_difference_contrast,
+ minimum=0,
+ maximum=8,
+ step=0.25,
+ value=default.composite_difference_contrast,
+ elem_id=el_ids.composite_difference_contrast)
+
+ with gr.Accordion("Help", open=False):
+ gr.Markdown(
+ f"""
+ ### {ui_labels.mask_blend_power}
+
+ The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas).
+ This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step.
+ This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation.
+
+ - **Below 1**: Stronger preservation near the end (with low sigma)
+ - **1**: Balanced (proportional to sigma)
+ - **Above 1**: Stronger preservation in the beginning (with high sigma)
+ """)
+ gr.Markdown(
+ f"""
+ ### {ui_labels.mask_blend_scale}
+
+ Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content.
+ This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength.
+
+ - **Low values**: Favors generated content.
+ - **High values**: Favors original content.
+ """)
+ gr.Markdown(
+ f"""
+ ### {ui_labels.inpaint_detail_preservation}
+
+ This parameter controls how the original latent vectors and denoised latent vectors are interpolated.
+ With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors.
+ This can prevent the loss of contrast that occurs with linear interpolation.
+
+ - **Low values**: Softer blending, details may fade.
+ - **High values**: Stronger contrast, may over-saturate colors.
+ """)
+
+ gr.Markdown(
+ """
+ ## Pixel Composite Settings
+
+ Masks are generated based on how much a part of the image changed after denoising.
+ These masks are used to blend the original and final images together.
+ If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.
+ """)
+
+ gr.Markdown(
+ f"""
+ ### {ui_labels.composite_mask_influence}
+
+ This parameter controls how much the mask should bias this sensitivity to difference.
+
+ - **0**: Ignore the mask, only consider differences in image content.
+ - **1**: Follow the mask closely despite image content changes.
+ """)
+
+ gr.Markdown(
+ f"""
+ ### {ui_labels.composite_difference_threshold}
+
+ This value represents the difference at which the original pixels will have less than 50% opacity.
+
+                    - **Low values**: Two image patches must be almost the same in order to retain original pixels.
+                    - **High values**: Two image patches can be very different and still retain original pixels.
+ """)
+
+ gr.Markdown(
+ f"""
+ ### {ui_labels.composite_difference_contrast}
+
+ This value represents the contrast between the opacity of the original and inpainted content.
+
+ - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting.
+ - **High values**: Ghosting will be less common, but transitions may be very sudden.
+ """)
+
+ self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
+ (power, gen_param_labels.mask_blend_power),
+ (scale, gen_param_labels.mask_blend_scale),
+ (detail, gen_param_labels.inpaint_detail_preservation),
+ (mask_inf, gen_param_labels.composite_mask_influence),
+ (dif_thresh, gen_param_labels.composite_difference_threshold),
+ (dif_contr, gen_param_labels.composite_difference_contrast)]
+
+ self.paste_field_names = []
+ for _, field_name in self.infotext_fields:
+ self.paste_field_names.append(field_name)
+
+ return [soft_inpainting_enabled,
+ power,
+ scale,
+ detail,
+ mask_inf,
+ dif_thresh,
+ dif_contr]
+
+ def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
+ if not enabled:
+ return
+
+ if not processing_uses_inpainting(p):
+ return
+
+ # Shut off the rounding it normally does.
+ p.mask_round = False
+
+ settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
+
+ # p.extra_generation_params["Mask rounding"] = False
+ settings.add_generation_params(p.extra_generation_params)
+
+ def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf,
+ dif_thresh, dif_contr):
+ if not enabled:
+ return
+
+ if not processing_uses_inpainting(p):
+ return
+
+ if mba.is_final_blend:
+ mba.blended_latent = mba.current_latent
+ return
+
+ settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
+
+ # todo: Why is sigma 2D? Both values are the same.
+ mba.blended_latent = latent_blend(settings,
+ mba.init_latent,
+ mba.current_latent,
+ get_modified_nmask(settings, mba.nmask, mba.sigma[0]))
+
+ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf,
+ dif_thresh, dif_contr):
+ if not enabled:
+ return
+
+ if not processing_uses_inpainting(p):
+ return
+
+ nmask = getattr(p, "nmask", None)
+ if nmask is None:
+ return
+
+ from modules import images
+ from modules.shared import opts
+
+ settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
+
+ # since the original code puts holes in the existing overlay images,
+ # we have to rebuild them.
+ self.overlay_images = []
+ for img in p.init_images:
+
+ image = images.flatten(img, opts.img2img_background_color)
+
+ if p.paste_to is None and p.resize_mode != 3:
+ image = images.resize_image(p.resize_mode, image, p.width, p.height)
+
+ self.overlay_images.append(image.convert('RGBA'))
+
+ if len(p.init_images) == 1:
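+            # a single init image serves the whole batch, so replicate its overlay for each output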
+ self.overlay_images = self.overlay_images * p.batch_size
+
+ if getattr(ps.samples, 'already_decoded', False):
+ self.masks_for_overlay = apply_masks(settings=settings,
+ nmask=nmask,
+ overlay_images=self.overlay_images,
+ width=p.width,
+ height=p.height,
+ paste_to=p.paste_to)
+ else:
+ self.masks_for_overlay = apply_adaptive_masks(settings=settings,
+ nmask=nmask,
+ latent_orig=p.init_latent,
+ latent_processed=ps.samples,
+ overlay_images=self.overlay_images,
+ width=p.width,
+ height=p.height,
+ paste_to=p.paste_to)
+
+ def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale,
+ detail_preservation, mask_inf, dif_thresh, dif_contr):
+ if not enabled:
+ return
+
+ if not processing_uses_inpainting(p):
+ return
+
+ if self.masks_for_overlay is None:
+ return
+
+ if self.overlay_images is None:
+ return
+
+ ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index]
+ ppmo.overlay_image = self.overlay_images[ppmo.index]
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index 04464100..688c2f11 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -28,7 +28,7 @@ function keyupEditAttention(event) {
if (afterParen == -1) return false;
let afterOpeningParen = after.indexOf(OPEN);
- if (afterOpeningParen != -1 && afterOpeningParen < beforeParen) return false;
+ if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false;
// Set the selection to the text between the parenthesis
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index ac26718f..f1ad19a6 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -26,8 +26,9 @@ function setupExtraNetworksForTab(tabname) {
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
+ var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container');
+ var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt');
- sort.dataset.sortkey = 'sortDefault';
tabs.appendChild(searchDiv);
tabs.appendChild(sort);
tabs.appendChild(sortOrder);
@@ -49,20 +50,23 @@ function setupExtraNetworksForTab(tabname) {
elem.style.display = visible ? "" : "none";
});
+
+ applySort();
};
var applySort = function() {
+ var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
+
var reverse = sortOrder.classList.contains("sortReverse");
- var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim();
- sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : "";
- var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : "";
- if (!sortKey || sortKeyStore == sort.dataset.sortkey) {
+ var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
+ sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
+ var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;
+
+ if (sortKeyStore == sort.dataset.sortkey) {
return;
}
-
sort.dataset.sortkey = sortKeyStore;
- var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
cards.forEach(function(card) {
card.originalParentElement = card.parentElement;
});
@@ -88,15 +92,13 @@ function setupExtraNetworksForTab(tabname) {
};
search.addEventListener("input", applyFilter);
- applyFilter();
- ["change", "blur", "click"].forEach(function(evt) {
- sort.querySelector("input").addEventListener(evt, applySort);
- });
sortOrder.addEventListener("click", function() {
sortOrder.classList.toggle("sortReverse");
applySort();
});
+ applyFilter();
+ extraNetworksApplySort[tabname] = applySort;
extraNetworksApplyFilter[tabname] = applyFilter;
var showDirsUpdate = function() {
@@ -109,11 +111,51 @@ function setupExtraNetworksForTab(tabname) {
showDirsUpdate();
}
+function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) {
+ if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for compact prompt layout
+
+ var promptContainer = gradioApp().getElementById(tabname + '_prompt_container');
+ var prompt = gradioApp().getElementById(tabname + '_prompt_row');
+ var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row');
+ var elem = id ? gradioApp().getElementById(id) : null;
+
+ if (showNegativePrompt && elem) {
+ elem.insertBefore(negPrompt, elem.firstChild);
+ } else {
+ promptContainer.insertBefore(negPrompt, promptContainer.firstChild);
+ }
+
+ if (showPrompt && elem) {
+ elem.insertBefore(prompt, elem.firstChild);
+ } else {
+ promptContainer.insertBefore(prompt, promptContainer.firstChild);
+ }
+
+ if (elem) {
+ elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt);
+ }
+}
+
+
+function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
+ extraNetworksMovePromptToTab(tabname, '', false, false);
+}
+
+function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab
+ extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt);
+
+}
+
function applyExtraNetworkFilter(tabname) {
setTimeout(extraNetworksApplyFilter[tabname], 1);
}
+function applyExtraNetworkSort(tabname) {
+ setTimeout(extraNetworksApplySort[tabname], 1);
+}
+
var extraNetworksApplyFilter = {};
+var extraNetworksApplySort = {};
var activePromptTextarea = {};
function setupExtraNetworks() {
@@ -143,8 +185,10 @@ onUiLoaded(setupExtraNetworks);
var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;
-function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
- var m = text.match(re_extranet);
+var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/;
+var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g;
+function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) {
+ var m = text.match(isNeg ? re_extranet_neg : re_extranet);
var replaced = false;
var newTextareaText;
if (m) {
@@ -152,8 +196,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
var extraTextAfterNet = m[2];
var partToSearch = m[1];
var foundAtPosition = -1;
- newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) {
- m = found.match(re_extranet);
+ newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) {
+ m = found.match(isNeg ? re_extranet_neg : re_extranet);
if (m[1] == partToSearch) {
replaced = true;
foundAtPosition = pos;
@@ -163,7 +207,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
});
if (foundAtPosition >= 0) {
- if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
+ if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
}
if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
@@ -188,14 +232,23 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
return false;
}
-function cardClicked(tabname, textToAdd, allowNegativePrompt) {
- var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
+function updatePromptArea(text, textArea, isNeg) {
- if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) {
- textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd;
+ if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) {
+ textArea.value = textArea.value + opts.extra_networks_add_text_separator + text;
}
- updateInput(textarea);
+ updateInput(textArea);
+}
+
+function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) {
+ if (textToAddNegative.length > 0) {
+ updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"));
+ updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true);
+ } else {
+ var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
+ updatePromptArea(textToAdd, textarea);
+ }
}
function saveCardPreview(event, tabname, filename) {
@@ -350,3 +403,9 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
}
});
}
+
+window.addEventListener("keydown", function(event) {
+ if (event.key == "Escape") {
+ closePopup();
+ }
+});
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index e4dae91b..625c5d14 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -34,7 +34,7 @@ function updateOnBackgroundChange() {
if (modalImage && modalImage.offsetParent) {
let currentButton = selected_gallery_button();
let preview = gradioApp().querySelectorAll('.livePreview > img');
- if (preview.length > 0) {
+ if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
// show preview image if available
modalImage.src = preview[preview.length - 1].src;
} else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
diff --git a/javascript/inputAccordion.js b/javascript/inputAccordion.js
index f2839852..7570309a 100644
--- a/javascript/inputAccordion.js
+++ b/javascript/inputAccordion.js
@@ -1,37 +1,68 @@
-var observerAccordionOpen = new MutationObserver(function(mutations) {
- mutations.forEach(function(mutationRecord) {
- var elem = mutationRecord.target;
- var open = elem.classList.contains('open');
+function inputAccordionChecked(id, checked) {
+ var accordion = gradioApp().getElementById(id);
+ accordion.visibleCheckbox.checked = checked;
+ accordion.onVisibleCheckboxChange();
+}
- var accordion = elem.parentNode;
- accordion.classList.toggle('input-accordion-open', open);
+function setupAccordion(accordion) {
+ var labelWrap = accordion.querySelector('.label-wrap');
+ var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
+ var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
+ var span = labelWrap.querySelector('span');
+ var linked = true;
- var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
- checkbox.checked = open;
- updateInput(checkbox);
+ var isOpen = function() {
+ return labelWrap.classList.contains('open');
+ };
- var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
- if (extra) {
- extra.style.display = open ? "" : "none";
- }
+ var observerAccordionOpen = new MutationObserver(function(mutations) {
+ mutations.forEach(function(mutationRecord) {
+ accordion.classList.toggle('input-accordion-open', isOpen());
+
+ if (linked) {
+ accordion.visibleCheckbox.checked = isOpen();
+ accordion.onVisibleCheckboxChange();
+ }
+ });
});
-});
+ observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
-function inputAccordionChecked(id, checked) {
- var label = gradioApp().querySelector('#' + id + " .label-wrap");
- if (label.classList.contains('open') != checked) {
- label.click();
+ if (extra) {
+ labelWrap.insertBefore(extra, labelWrap.lastElementChild);
}
+
+ accordion.onChecked = function(checked) {
+ if (isOpen() != checked) {
+ labelWrap.click();
+ }
+ };
+
+ var visibleCheckbox = document.createElement('INPUT');
+ visibleCheckbox.type = 'checkbox';
+ visibleCheckbox.checked = isOpen();
+ visibleCheckbox.id = accordion.id + "-visible-checkbox";
+ visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox";
+ span.insertBefore(visibleCheckbox, span.firstChild);
+
+ accordion.visibleCheckbox = visibleCheckbox;
+ accordion.onVisibleCheckboxChange = function() {
+ if (linked && isOpen() != visibleCheckbox.checked) {
+ labelWrap.click();
+ }
+
+ gradioCheckbox.checked = visibleCheckbox.checked;
+ updateInput(gradioCheckbox);
+ };
+
+ visibleCheckbox.addEventListener('click', function(event) {
+ linked = false;
+ event.stopPropagation();
+ });
+ visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange);
}
onUiLoaded(function() {
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
- var labelWrap = accordion.querySelector('.label-wrap');
- observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
-
- var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
- if (extra) {
- labelWrap.insertBefore(extra, labelWrap.lastElementChild);
- }
+ setupAccordion(accordion);
}
});
diff --git a/javascript/notification.js b/javascript/notification.js
index 6d799561..3ee972ae 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -26,7 +26,11 @@ onAfterUiUpdate(function() {
lastHeadImg = headImg;
// play notification sound if available
- gradioApp().querySelector('#audio_notification audio')?.play();
+ const notificationAudio = gradioApp().querySelector('#audio_notification audio');
+ if (notificationAudio) {
+ notificationAudio.volume = opts.notification_volume / 100.0 || 1.0;
+ notificationAudio.play();
+ }
if (document.hasFocus()) return;
diff --git a/javascript/settings.js b/javascript/settings.js
index 4e79ec00..e6009290 100644
--- a/javascript/settings.js
+++ b/javascript/settings.js
@@ -44,3 +44,28 @@ onUiLoaded(function() {
buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
});
+
+
+onOptionsChanged(function() {
+ if (gradioApp().querySelector('#settings .settings-category')) return;
+
+ var sectionMap = {};
+ gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) {
+ sectionMap[x.textContent.trim()] = x;
+ });
+
+ opts._categories.forEach(function(x) {
+ var section = x[0];
+ var category = x[1];
+
+ var span = document.createElement('SPAN');
+ span.textContent = category;
+ span.className = 'settings-category';
+
+ var sectionElem = sectionMap[section];
+ if (!sectionElem) return;
+
+ sectionElem.parentElement.insertBefore(span, sectionElem);
+ });
+});
+
diff --git a/javascript/ui.js b/javascript/ui.js
index 2e262602..18c9f891 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -170,6 +170,23 @@ function submit_img2img() {
return res;
}
+function submit_extras() {
+ showSubmitButtons('extras', false);
+
+ var id = randomId();
+
+ requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() {
+ showSubmitButtons('extras', true);
+ });
+
+ var res = create_submit_args(arguments);
+
+ res[0] = id;
+
+ console.log(res);
+ return res;
+}
+
function restoreProgressTxt2img() {
showRestoreProgressButton("txt2img", false);
var id = localGet("txt2img_task_id");
@@ -198,9 +215,33 @@ function restoreProgressImg2img() {
}
+/**
+ * Configure the width and height elements on `tabname` to accept
+ * pasting of resolutions in the form of "width x height".
+ */
+function setupResolutionPasting(tabname) {
+ var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`);
+ var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`);
+ for (const el of [width, height]) {
+ el.addEventListener('paste', function(event) {
+ var pasteData = event.clipboardData.getData('text/plain');
+ var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/);
+ if (parsed) {
+ width.value = parsed[1];
+ height.value = parsed[2];
+ updateInput(width);
+ updateInput(height);
+ event.preventDefault();
+ }
+ });
+ }
+}
+
onUiLoaded(function() {
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
+ setupResolutionPasting('txt2img');
+ setupResolutionPasting('img2img');
});
diff --git a/modules/api/api.py b/modules/api/api.py
index 09083874..0e2807de 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -17,12 +17,11 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
-from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin, Image
from modules.sd_models_config import find_checkpoint_config_near_filename
@@ -32,7 +31,7 @@ from typing import Any
import piexif
import piexif.helper
from contextlib import closing
-
+from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
def script_name_to_index(name, scripts):
try:
@@ -235,7 +234,6 @@ class Api:
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
- self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
@@ -253,6 +251,24 @@ class Api:
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
+ txt2img_script_runner = scripts.scripts_txt2img
+ img2img_script_runner = scripts.scripts_img2img
+
+ if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
+ ui.create_ui()
+
+ if not txt2img_script_runner.scripts:
+ txt2img_script_runner.initialize_scripts(False)
+ if not self.default_script_arg_txt2img:
+ self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)
+
+ if not img2img_script_runner.scripts:
+ img2img_script_runner.initialize_scripts(True)
+ if not self.default_script_arg_img2img:
+ self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
+
+
+
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
@@ -314,8 +330,13 @@ class Api:
script_args[script.args_from:script.args_to] = ui_default_values
return script_args
- def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
+ def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
script_args = default_script_args.copy()
+
+ if input_script_args is not None:
+ for index, value in input_script_args.items():
+ script_args[index] = value
+
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
if selectable_scripts:
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
@@ -337,13 +358,83 @@ class Api:
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
+ def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
+ """Processes `infotext` field from the `request`, and sets other fields of the `request` accoring to what's in infotext.
+
+ If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
+
+ Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
+ """
+
+ if not request.infotext:
+ return {}
+
+ possible_fields = infotext.paste_fields[tabname]["fields"]
+        set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "model_dump") else request.dict(exclude_unset=True)  # pydantic v1 and v2 have different names for this
+ params = infotext.parse_generation_parameters(request.infotext)
+
+ def get_field_value(field, params):
+ value = field.function(params) if field.function else params.get(field.label)
+ if value is None:
+ return None
+
+ if field.api in request.__fields__:
+ target_type = request.__fields__[field.api].type_
+ else:
+ target_type = type(field.component.value)
+
+ if target_type == type(None):
+ return None
+
+ if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
+ value = value.get('value')
+
+ if value is not None and not isinstance(value, target_type):
+ value = target_type(value)
+
+ return value
+
+ for field in possible_fields:
+ if not field.api:
+ continue
+
+ if field.api in set_fields:
+ continue
+
+ value = get_field_value(field, params)
+ if value is not None:
+ setattr(request, field.api, value)
+
+ if request.override_settings is None:
+ request.override_settings = {}
+
+ overriden_settings = infotext.get_override_settings(params)
+ for _, setting_name, value in overriden_settings:
+ if setting_name not in request.override_settings:
+ request.override_settings[setting_name] = value
+
+ if script_runner is not None and mentioned_script_args is not None:
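+            # map each UI component to its index in the script runner's flattened input list so
+            # that values parsed from infotext can be placed into the matching script_args slot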
+ indexes = {v: i for i, v in enumerate(script_runner.inputs)}
+ script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
+
+ for field, index in script_fields:
+ value = get_field_value(field, params)
+
+ if value is None:
+ continue
+
+ mentioned_script_args[index] = value
+
+ return params
+
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
+ task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
+
script_runner = scripts.scripts_txt2img
- if not script_runner.scripts:
- script_runner.initialize_scripts(False)
- ui.create_ui()
- if not self.default_script_arg_txt2img:
- self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
+
+ infotext_script_args = {}
+ self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
+
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params
@@ -358,12 +449,15 @@ class Api:
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
+ args.pop('infotext', None)
- script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
+ script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
+ add_task_to_queue(task_id)
+
with self.queue_lock:
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
p.is_api = True
@@ -373,12 +467,14 @@ class Api:
try:
shared.state.begin(job="scripts_txt2img")
+ start_task(task_id)
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
+ finish_task(task_id)
finally:
shared.state.end()
shared.total_tqdm.clear()
@@ -388,6 +484,8 @@ class Api:
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
+ task_id = img2imgreq.force_task_id or create_task_id("img2img")
+
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@@ -397,11 +495,10 @@ class Api:
mask = decode_base64_to_image(mask)
script_runner = scripts.scripts_img2img
- if not script_runner.scripts:
- script_runner.initialize_scripts(True)
- ui.create_ui()
- if not self.default_script_arg_img2img:
- self.default_script_arg_img2img = self.init_default_script_args(script_runner)
+
+ infotext_script_args = {}
+ self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
+
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
populate = img2imgreq.copy(update={ # Override __init__ params
@@ -418,12 +515,15 @@ class Api:
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
+ args.pop('infotext', None)
- script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
+ script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
+ add_task_to_queue(task_id)
+
with self.queue_lock:
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images]
@@ -434,12 +534,14 @@ class Api:
try:
shared.state.begin(job="scripts_img2img")
+ start_task(task_id)
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
+ finish_task(task_id)
finally:
shared.state.end()
shared.total_tqdm.clear()
@@ -482,7 +584,7 @@ class Api:
if geninfo is None:
geninfo = ""
- params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
+ params = infotext.parse_generation_parameters(geninfo)
script_callbacks.infotext_pasted_callback(geninfo, params)
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
@@ -513,7 +615,7 @@ class Api:
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
- return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
+ return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
image_b64 = interrogatereq.image
@@ -675,19 +777,6 @@ class Api:
finally:
shared.state.end()
- def preprocess(self, args: dict):
- try:
- shared.state.begin(job="preprocess")
- preprocess(**args) # quick operation unless blip/booru interrogation is enabled
- shared.state.end()
- return models.PreprocessResponse(info='preprocess complete')
- except KeyError as e:
- return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
- except Exception as e:
- return models.PreprocessResponse(info=f"preprocess error: {e}")
- finally:
- shared.state.end()
-
def train_embedding(self, args: dict):
try:
shared.state.begin(job="train_embedding")
diff --git a/modules/api/models.py b/modules/api/models.py
index a0d80af8..16edf11c 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -107,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
+ {"key": "force_task_id", "type": str, "default": None},
+ {"key": "infotext", "type": str, "default": None},
]
).generate_model()
@@ -124,6 +126,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
+ {"key": "force_task_id", "type": str, "default": None},
+ {"key": "infotext", "type": str, "default": None},
]
).generate_model()
@@ -202,9 +206,6 @@ class TrainResponse(BaseModel):
class CreateResponse(BaseModel):
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
-class PreprocessResponse(BaseModel):
- info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
-
fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)
diff --git a/modules/cache.py b/modules/cache.py
index ff26a213..2d37e7b9 100644
--- a/modules/cache.py
+++ b/modules/cache.py
@@ -32,7 +32,7 @@ def dump_cache():
with cache_lock:
cache_filename_tmp = cache_filename + "-"
with open(cache_filename_tmp, "w", encoding="utf8") as file:
- json.dump(cache_data, file, indent=4)
+ json.dump(cache_data, file, indent=4, ensure_ascii=False)
os.replace(cache_filename_tmp, cache_filename)
diff --git a/modules/call_queue.py b/modules/call_queue.py
index ddf0d573..bcd7c546 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
shared.state.skipped = False
shared.state.interrupted = False
+ shared.state.stopping_generation = False
shared.state.job_count = 0
if not add_stats:
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index c6d4b612..3775e670 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -70,6 +70,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
@@ -109,7 +110,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
-parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
+parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
deleted file mode 100644
index 12db6814..00000000
--- a/modules/codeformer/codeformer_arch.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
-
-import math
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-from typing import Optional
-
-from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
-from basicsr.utils.registry import ARCH_REGISTRY
-
-def calc_mean_std(feat, eps=1e-5):
- """Calculate mean and std for adaptive_instance_normalization.
-
- Args:
- feat (Tensor): 4D tensor.
- eps (float): A small value added to the variance to avoid
- divide-by-zero. Default: 1e-5.
- """
- size = feat.size()
- assert len(size) == 4, 'The input feature should be 4D tensor.'
- b, c = size[:2]
- feat_var = feat.view(b, c, -1).var(dim=2) + eps
- feat_std = feat_var.sqrt().view(b, c, 1, 1)
- feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
- return feat_mean, feat_std
-
-
-def adaptive_instance_normalization(content_feat, style_feat):
- """Adaptive instance normalization.
-
- Adjust the reference features to have the similar color and illuminations
- as those in the degradate features.
-
- Args:
- content_feat (Tensor): The reference feature.
- style_feat (Tensor): The degradate features.
- """
- size = content_feat.size()
- style_mean, style_std = calc_mean_std(style_feat)
- content_mean, content_std = calc_mean_std(content_feat)
- normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
- return normalized_feat * style_std.expand(size) + style_mean.expand(size)
-
-
-class PositionEmbeddingSine(nn.Module):
- """
- This is a more standard version of the position embedding, very similar to the one
- used by the Attention is all you need paper, generalized to work on images.
- """
-
- def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
- super().__init__()
- self.num_pos_feats = num_pos_feats
- self.temperature = temperature
- self.normalize = normalize
- if scale is not None and normalize is False:
- raise ValueError("normalize should be True if scale is passed")
- if scale is None:
- scale = 2 * math.pi
- self.scale = scale
-
- def forward(self, x, mask=None):
- if mask is None:
- mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
- not_mask = ~mask
- y_embed = not_mask.cumsum(1, dtype=torch.float32)
- x_embed = not_mask.cumsum(2, dtype=torch.float32)
- if self.normalize:
- eps = 1e-6
- y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
- x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-
- dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
-
- pos_x = x_embed[:, :, :, None] / dim_t
- pos_y = y_embed[:, :, :, None] / dim_t
- pos_x = torch.stack(
- (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
- ).flatten(3)
- pos_y = torch.stack(
- (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
- ).flatten(3)
- pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
- return pos
-
-def _get_activation_fn(activation):
- """Return an activation function given a string"""
- if activation == "relu":
- return F.relu
- if activation == "gelu":
- return F.gelu
- if activation == "glu":
- return F.glu
- raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
-
-
-class TransformerSALayer(nn.Module):
- def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
- super().__init__()
- self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
- # Implementation of Feedforward model - MLP
- self.linear1 = nn.Linear(embed_dim, dim_mlp)
- self.dropout = nn.Dropout(dropout)
- self.linear2 = nn.Linear(dim_mlp, embed_dim)
-
- self.norm1 = nn.LayerNorm(embed_dim)
- self.norm2 = nn.LayerNorm(embed_dim)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
-
- self.activation = _get_activation_fn(activation)
-
- def with_pos_embed(self, tensor, pos: Optional[Tensor]):
- return tensor if pos is None else tensor + pos
-
- def forward(self, tgt,
- tgt_mask: Optional[Tensor] = None,
- tgt_key_padding_mask: Optional[Tensor] = None,
- query_pos: Optional[Tensor] = None):
-
- # self attention
- tgt2 = self.norm1(tgt)
- q = k = self.with_pos_embed(tgt2, query_pos)
- tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
- key_padding_mask=tgt_key_padding_mask)[0]
- tgt = tgt + self.dropout1(tgt2)
-
- # ffn
- tgt2 = self.norm2(tgt)
- tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
- tgt = tgt + self.dropout2(tgt2)
- return tgt
-
-class Fuse_sft_block(nn.Module):
- def __init__(self, in_ch, out_ch):
- super().__init__()
- self.encode_enc = ResBlock(2*in_ch, out_ch)
-
- self.scale = nn.Sequential(
- nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
-
- self.shift = nn.Sequential(
- nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
-
- def forward(self, enc_feat, dec_feat, w=1):
- enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
- scale = self.scale(enc_feat)
- shift = self.shift(enc_feat)
- residual = w * (dec_feat * scale + shift)
- out = dec_feat + residual
- return out
-
-
-@ARCH_REGISTRY.register()
-class CodeFormer(VQAutoEncoder):
- def __init__(self, dim_embd=512, n_head=8, n_layers=9,
- codebook_size=1024, latent_size=256,
- connect_list=('32', '64', '128', '256'),
- fix_modules=('quantize', 'generator')):
- super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
-
- if fix_modules is not None:
- for module in fix_modules:
- for param in getattr(self, module).parameters():
- param.requires_grad = False
-
- self.connect_list = connect_list
- self.n_layers = n_layers
- self.dim_embd = dim_embd
- self.dim_mlp = dim_embd*2
-
- self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
- self.feat_emb = nn.Linear(256, self.dim_embd)
-
- # transformer
- self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
- for _ in range(self.n_layers)])
-
- # logits_predict head
- self.idx_pred_layer = nn.Sequential(
- nn.LayerNorm(dim_embd),
- nn.Linear(dim_embd, codebook_size, bias=False))
-
- self.channels = {
- '16': 512,
- '32': 256,
- '64': 256,
- '128': 128,
- '256': 128,
- '512': 64,
- }
-
- # after second residual block for > 16, before attn layer for ==16
- self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
- # after first residual block for > 16, before attn layer for ==16
- self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
-
- # fuse_convs_dict
- self.fuse_convs_dict = nn.ModuleDict()
- for f_size in self.connect_list:
- in_ch = self.channels[f_size]
- self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
-
- def _init_weights(self, module):
- if isinstance(module, (nn.Linear, nn.Embedding)):
- module.weight.data.normal_(mean=0.0, std=0.02)
- if isinstance(module, nn.Linear) and module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.LayerNorm):
- module.bias.data.zero_()
- module.weight.data.fill_(1.0)
-
- def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
- # ################### Encoder #####################
- enc_feat_dict = {}
- out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
- for i, block in enumerate(self.encoder.blocks):
- x = block(x)
- if i in out_list:
- enc_feat_dict[str(x.shape[-1])] = x.clone()
-
- lq_feat = x
- # ################# Transformer ###################
- # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
- pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
- # BCHW -> BC(HW) -> (HW)BC
- feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
- query_emb = feat_emb
- # Transformer encoder
- for layer in self.ft_layers:
- query_emb = layer(query_emb, query_pos=pos_emb)
-
- # output logits
- logits = self.idx_pred_layer(query_emb) # (hw)bn
- logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
-
- if code_only: # for training stage II
- # logits doesn't need softmax before cross_entropy loss
- return logits, lq_feat
-
- # ################# Quantization ###################
- # if self.training:
- # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
- # # b(hw)c -> bc(hw) -> bchw
- # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
- # ------------
- soft_one_hot = F.softmax(logits, dim=2)
- _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
- quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
- # preserve gradients
- # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
-
- if detach_16:
- quant_feat = quant_feat.detach() # for training stage III
- if adain:
- quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
-
- # ################## Generator ####################
- x = quant_feat
- fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
-
- for i, block in enumerate(self.generator.blocks):
- x = block(x)
- if i in fuse_list: # fuse after i-th block
- f_size = str(x.shape[-1])
- if w>0:
- x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
- out = x
- # logits doesn't need softmax before cross_entropy loss
- return out, logits, lq_feat
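The deleted codeformer_arch.py above bundled a local copy of CodeFormer's transformer plus the AdaIN helpers (calc_mean_std, adaptive_instance_normalization); the architecture is now pulled in through the spandrel package instead (see the codeformer_model.py hunk below). For orientation, the AdaIN step those helpers implement simply re-normalizes one feature map with the channel statistics of another; a minimal standalone sketch, assuming 4D NCHW tensors as in the deleted code:

    import torch

    def adain(content: torch.Tensor, style: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # per-channel mean/std over the spatial dimensions -> shape (b, c, 1, 1)
        c_mean = content.mean(dim=(2, 3), keepdim=True)
        c_std = content.var(dim=(2, 3), keepdim=True).add(eps).sqrt()
        s_mean = style.mean(dim=(2, 3), keepdim=True)
        s_std = style.var(dim=(2, 3), keepdim=True).add(eps).sqrt()
        # strip the content statistics, then re-apply the style statistics
        return (content - c_mean) / c_std * s_std + s_mean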
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
deleted file mode 100644
index 09ee6660..00000000
--- a/modules/codeformer/vqgan_arch.py
+++ /dev/null
@@ -1,435 +0,0 @@
-# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
-
-'''
-VQGAN code, adapted from the original created by the Unleashing Transformers authors:
-https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
-
-'''
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from basicsr.utils import get_root_logger
-from basicsr.utils.registry import ARCH_REGISTRY
-
-def normalize(in_channels):
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
-
-@torch.jit.script
-def swish(x):
- return x*torch.sigmoid(x)
-
-
-# Define VQVAE classes
-class VectorQuantizer(nn.Module):
- def __init__(self, codebook_size, emb_dim, beta):
- super(VectorQuantizer, self).__init__()
- self.codebook_size = codebook_size # number of embeddings
- self.emb_dim = emb_dim # dimension of embedding
- self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
- self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
- self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
-
- def forward(self, z):
- # reshape z -> (batch, height, width, channel) and flatten
- z = z.permute(0, 2, 3, 1).contiguous()
- z_flattened = z.view(-1, self.emb_dim)
-
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
- d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
-
- mean_distance = torch.mean(d)
- # find closest encodings
- # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
- min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
- # [0-1], higher score, higher confidence
- min_encoding_scores = torch.exp(-min_encoding_scores/10)
-
- min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
- min_encodings.scatter_(1, min_encoding_indices, 1)
-
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
- # compute loss for embedding
- loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
- # preserve gradients
- z_q = z + (z_q - z).detach()
-
- # perplexity
- e_mean = torch.mean(min_encodings, dim=0)
- perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
- # reshape back to match original input shape
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
-
- return z_q, loss, {
- "perplexity": perplexity,
- "min_encodings": min_encodings,
- "min_encoding_indices": min_encoding_indices,
- "min_encoding_scores": min_encoding_scores,
- "mean_distance": mean_distance
- }
-
- def get_codebook_feat(self, indices, shape):
- # input indices: batch*token_num -> (batch*token_num)*1
- # shape: batch, height, width, channel
- indices = indices.view(-1,1)
- min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
- min_encodings.scatter_(1, indices, 1)
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
-
- if shape is not None: # reshape back to match original input shape
- z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
-
- return z_q
-
-
-class GumbelQuantizer(nn.Module):
- def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
- super().__init__()
- self.codebook_size = codebook_size # number of embeddings
- self.emb_dim = emb_dim # dimension of embedding
- self.straight_through = straight_through
- self.temperature = temp_init
- self.kl_weight = kl_weight
- self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
- self.embed = nn.Embedding(codebook_size, emb_dim)
-
- def forward(self, z):
- hard = self.straight_through if self.training else True
-
- logits = self.proj(z)
-
- soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
-
- z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
-
- # + kl divergence to the prior loss
- qy = F.softmax(logits, dim=1)
- diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
- min_encoding_indices = soft_one_hot.argmax(dim=1)
-
- return z_q, diff, {
- "min_encoding_indices": min_encoding_indices
- }
-
-
-class Downsample(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
-
- def forward(self, x):
- pad = (0, 1, 0, 1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- return x
-
-
-class Upsample(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-
- def forward(self, x):
- x = F.interpolate(x, scale_factor=2.0, mode="nearest")
- x = self.conv(x)
-
- return x
-
-
-class ResBlock(nn.Module):
- def __init__(self, in_channels, out_channels=None):
- super(ResBlock, self).__init__()
- self.in_channels = in_channels
- self.out_channels = in_channels if out_channels is None else out_channels
- self.norm1 = normalize(in_channels)
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
- self.norm2 = normalize(out_channels)
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
- if self.in_channels != self.out_channels:
- self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
- def forward(self, x_in):
- x = x_in
- x = self.norm1(x)
- x = swish(x)
- x = self.conv1(x)
- x = self.norm2(x)
- x = swish(x)
- x = self.conv2(x)
- if self.in_channels != self.out_channels:
- x_in = self.conv_out(x_in)
-
- return x + x_in
-
-
-class AttnBlock(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = normalize(in_channels)
- self.q = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = q.reshape(b, c, h*w)
- q = q.permute(0, 2, 1)
- k = k.reshape(b, c, h*w)
- w_ = torch.bmm(q, k)
- w_ = w_ * (int(c)**(-0.5))
- w_ = F.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b, c, h*w)
- w_ = w_.permute(0, 2, 1)
- h_ = torch.bmm(v, w_)
- h_ = h_.reshape(b, c, h, w)
-
- h_ = self.proj_out(h_)
-
- return x+h_
-
-
-class Encoder(nn.Module):
- def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
- super().__init__()
- self.nf = nf
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.attn_resolutions = attn_resolutions
-
- curr_res = self.resolution
- in_ch_mult = (1,)+tuple(ch_mult)
-
- blocks = []
- # initial convultion
- blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
-
- # residual and downsampling blocks, with attention on smaller res (16x16)
- for i in range(self.num_resolutions):
- block_in_ch = nf * in_ch_mult[i]
- block_out_ch = nf * ch_mult[i]
- for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(block_in_ch, block_out_ch))
- block_in_ch = block_out_ch
- if curr_res in attn_resolutions:
- blocks.append(AttnBlock(block_in_ch))
-
- if i != self.num_resolutions - 1:
- blocks.append(Downsample(block_in_ch))
- curr_res = curr_res // 2
-
- # non-local attention block
- blocks.append(ResBlock(block_in_ch, block_in_ch))
- blocks.append(AttnBlock(block_in_ch))
- blocks.append(ResBlock(block_in_ch, block_in_ch))
-
- # normalise and convert to latent size
- blocks.append(normalize(block_in_ch))
- blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
- self.blocks = nn.ModuleList(blocks)
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x)
-
- return x
-
-
-class Generator(nn.Module):
- def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
- super().__init__()
- self.nf = nf
- self.ch_mult = ch_mult
- self.num_resolutions = len(self.ch_mult)
- self.num_res_blocks = res_blocks
- self.resolution = img_size
- self.attn_resolutions = attn_resolutions
- self.in_channels = emb_dim
- self.out_channels = 3
- block_in_ch = self.nf * self.ch_mult[-1]
- curr_res = self.resolution // 2 ** (self.num_resolutions-1)
-
- blocks = []
- # initial conv
- blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
-
- # non-local attention block
- blocks.append(ResBlock(block_in_ch, block_in_ch))
- blocks.append(AttnBlock(block_in_ch))
- blocks.append(ResBlock(block_in_ch, block_in_ch))
-
- for i in reversed(range(self.num_resolutions)):
- block_out_ch = self.nf * self.ch_mult[i]
-
- for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(block_in_ch, block_out_ch))
- block_in_ch = block_out_ch
-
- if curr_res in self.attn_resolutions:
- blocks.append(AttnBlock(block_in_ch))
-
- if i != 0:
- blocks.append(Upsample(block_in_ch))
- curr_res = curr_res * 2
-
- blocks.append(normalize(block_in_ch))
- blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
-
- self.blocks = nn.ModuleList(blocks)
-
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x)
-
- return x
-
-
-@ARCH_REGISTRY.register()
-class VQAutoEncoder(nn.Module):
- def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
- beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
- super().__init__()
- logger = get_root_logger()
- self.in_channels = 3
- self.nf = nf
- self.n_blocks = res_blocks
- self.codebook_size = codebook_size
- self.embed_dim = emb_dim
- self.ch_mult = ch_mult
- self.resolution = img_size
- self.attn_resolutions = attn_resolutions or [16]
- self.quantizer_type = quantizer
- self.encoder = Encoder(
- self.in_channels,
- self.nf,
- self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
- self.attn_resolutions
- )
- if self.quantizer_type == "nearest":
- self.beta = beta #0.25
- self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
- elif self.quantizer_type == "gumbel":
- self.gumbel_num_hiddens = emb_dim
- self.straight_through = gumbel_straight_through
- self.kl_weight = gumbel_kl_weight
- self.quantize = GumbelQuantizer(
- self.codebook_size,
- self.embed_dim,
- self.gumbel_num_hiddens,
- self.straight_through,
- self.kl_weight
- )
- self.generator = Generator(
- self.nf,
- self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
- self.attn_resolutions
- )
-
- if model_path is not None:
- chkpt = torch.load(model_path, map_location='cpu')
- if 'params_ema' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
- logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
- elif 'params' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
- logger.info(f'vqgan is loaded from: {model_path} [params]')
- else:
- raise ValueError('Wrong params!')
-
-
- def forward(self, x):
- x = self.encoder(x)
- quant, codebook_loss, quant_stats = self.quantize(x)
- x = self.generator(quant)
- return x, codebook_loss, quant_stats
-
-
-
-# patch based discriminator
-@ARCH_REGISTRY.register()
-class VQGANDiscriminator(nn.Module):
- def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
- super().__init__()
-
- layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
- ndf_mult = 1
- ndf_mult_prev = 1
- for n in range(1, n_layers): # gradually increase the number of filters
- ndf_mult_prev = ndf_mult
- ndf_mult = min(2 ** n, 8)
- layers += [
- nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
- nn.BatchNorm2d(ndf * ndf_mult),
- nn.LeakyReLU(0.2, True)
- ]
-
- ndf_mult_prev = ndf_mult
- ndf_mult = min(2 ** n_layers, 8)
-
- layers += [
- nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
- nn.BatchNorm2d(ndf * ndf_mult),
- nn.LeakyReLU(0.2, True)
- ]
-
- layers += [
- nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
- self.main = nn.Sequential(*layers)
-
- if model_path is not None:
- chkpt = torch.load(model_path, map_location='cpu')
- if 'params_d' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
- elif 'params' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
- else:
- raise ValueError('Wrong params!')
-
- def forward(self, x):
- return self.main(x)
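vqgan_arch.py goes the same way. Its VectorQuantizer is the standard nearest-neighbour codebook lookup with a straight-through gradient; a condensed sketch of just that core step (simplified: no commitment loss or usage statistics, NCHW input with channel count equal to the embedding dimension):

    import torch
    import torch.nn as nn

    class NearestQuantizer(nn.Module):
        def __init__(self, codebook_size: int, emb_dim: int):
            super().__init__()
            self.embedding = nn.Embedding(codebook_size, emb_dim)

        def forward(self, z: torch.Tensor) -> torch.Tensor:
            b, c, h, w = z.shape
            flat = z.permute(0, 2, 3, 1).reshape(-1, c)                    # (b*h*w, c)
            idx = torch.cdist(flat, self.embedding.weight).argmin(dim=1)   # closest codebook entry
            z_q = self.embedding(idx).view(b, h, w, c).permute(0, 3, 1, 2)
            # straight-through estimator: the value is z_q, the gradient flows through z
            return z + (z_q - z).detach()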
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index da42b5e9..44b84618 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -1,132 +1,64 @@
-import os
+from __future__ import annotations
-import cv2
-import torch
-
-import modules.face_restoration
-import modules.shared
-from modules import shared, devices, modelloader, errors
-from modules.paths import models_path
-
-# codeformer people made a choice to include modified basicsr library to their project which makes
-# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
-# I am making a choice to include some files from codeformer to work around this issue.
-model_dir = "Codeformer"
-model_path = os.path.join(models_path, model_dir)
-model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-
-codeformer = None
-
-
-def setup_model(dirname):
- os.makedirs(model_path, exist_ok=True)
-
- path = modules.paths.paths.get("CodeFormer", None)
- if path is None:
- return
-
- try:
- from torchvision.transforms.functional import normalize
- from modules.codeformer.codeformer_arch import CodeFormer
- from basicsr.utils import img2tensor, tensor2img
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
- from facelib.detection.retinaface import retinaface
-
- net_class = CodeFormer
-
- class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
- def name(self):
- return "CodeFormer"
-
- def __init__(self, dirname):
- self.net = None
- self.face_helper = None
- self.cmd_dir = dirname
+import logging
- def create_models(self):
-
- if self.net is not None and self.face_helper is not None:
- self.net.to(devices.device_codeformer)
- return self.net, self.face_helper
- model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
- if len(model_paths) != 0:
- ckpt_path = model_paths[0]
- else:
- print("Unable to load codeformer model.")
- return None, None
- net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
- checkpoint = torch.load(ckpt_path)['params_ema']
- net.load_state_dict(checkpoint)
- net.eval()
-
- if hasattr(retinaface, 'device'):
- retinaface.device = devices.device_codeformer
- face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
-
- self.net = net
- self.face_helper = face_helper
-
- return net, face_helper
-
- def send_model_to(self, device):
- self.net.to(device)
- self.face_helper.face_det.to(device)
- self.face_helper.face_parse.to(device)
-
- def restore(self, np_image, w=None):
- np_image = np_image[:, :, ::-1]
-
- original_resolution = np_image.shape[0:2]
+import torch
- self.create_models()
- if self.net is None or self.face_helper is None:
- return np_image
+from modules import (
+ devices,
+ errors,
+ face_restoration,
+ face_restoration_utils,
+ modelloader,
+ shared,
+)
- self.send_model_to(devices.device_codeformer)
+logger = logging.getLogger(__name__)
- self.face_helper.clean_all()
- self.face_helper.read_image(np_image)
- self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
- self.face_helper.align_warp_face()
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_download_name = 'codeformer-v0.1.0.pth'
- for cropped_face in self.face_helper.cropped_faces:
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+# used by e.g. postprocessing_codeformer.py
+codeformer: face_restoration.FaceRestoration | None = None
- try:
- with torch.no_grad():
- output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- devices.torch_gc()
- except Exception:
- errors.report('Failed inference for CodeFormer', exc_info=True)
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
- restored_face = restored_face.astype('uint8')
- self.face_helper.add_restored_face(restored_face)
+class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
+ def name(self):
+ return "CodeFormer"
- self.face_helper.get_inverse_affine(None)
+ def load_net(self) -> torch.nn.Module:
+ for model_path in modelloader.load_models(
+ model_path=self.model_path,
+ model_url=model_url,
+ command_path=self.model_path,
+ download_name=model_download_name,
+ ext_filter=['.pth'],
+ ):
+ return modelloader.load_spandrel_model(
+ model_path,
+ device=devices.device_codeformer,
+ expected_architecture='CodeFormer',
+ ).model
+ raise ValueError("No codeformer model found")
- restored_img = self.face_helper.paste_faces_to_input_image()
- restored_img = restored_img[:, :, ::-1]
+ def get_device(self):
+ return devices.device_codeformer
- if original_resolution != restored_img.shape[0:2]:
- restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+ def restore(self, np_image, w: float | None = None):
+ if w is None:
+ w = getattr(shared.opts, "code_former_weight", 0.5)
- self.face_helper.clean_all()
+ def restore_face(cropped_face_t):
+ assert self.net is not None
+ return self.net(cropped_face_t, w=w, adain=True)[0]
- if shared.opts.face_restoration_unload:
- self.send_model_to(devices.cpu)
+ return self.restore_with_helper(np_image, restore_face)
- return restored_img
- global codeformer
+def setup_model(dirname: str) -> None:
+ global codeformer
+ try:
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
-
except Exception:
errors.report("Error setting up CodeFormer", exc_info=True)
-
- # sys.path = stored_sys_path
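After this rewrite the restorer is a thin subclass of the shared face_restoration_utils.CommonFaceRestoration helper: load_net pulls the checkpoint through modelloader.load_spandrel_model, and restore() only supplies the per-face callback plus the fidelity weight w. A usage sketch based solely on the interface visible in this hunk (the directory path and the input array are made-up placeholders, and a full webui environment is assumed):

    import numpy as np
    from modules import codeformer_model

    codeformer_model.setup_model("models/Codeformer")    # hypothetical model directory
    restorer = codeformer_model.codeformer
    if restorer is not None:
        image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder image array
        restored = restorer.restore(image, w=0.7)        # w=None falls back to opts.code_former_weight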
diff --git a/modules/devices.py b/modules/devices.py
index 1d4eb563..ff279ac5 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -4,10 +4,18 @@ from functools import lru_cache
import torch
from modules import errors, shared
+from modules import torch_utils
if sys.platform == "darwin":
from modules import mac_specific
+if shared.cmd_opts.use_ipex:
+ from modules import xpu_specific
+
+
+def has_xpu() -> bool:
+ return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
+
def has_mps() -> bool:
if sys.platform != "darwin":
@@ -16,6 +24,23 @@ def has_mps() -> bool:
return mac_specific.has_mps
+def cuda_no_autocast(device_id=None) -> bool:
+ if device_id is None:
+ device_id = get_cuda_device_id()
+ return (
+ torch.cuda.get_device_capability(device_id) == (7, 5)
+ and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+ )
+
+
+def get_cuda_device_id():
+ return (
+ int(shared.cmd_opts.device_id)
+ if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+ else 0
+ ) or torch.cuda.current_device()
+
+
def get_cuda_device_string():
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@@ -30,6 +55,9 @@ def get_optimal_device_name():
if has_mps():
return "mps"
+ if has_xpu():
+ return xpu_specific.get_xpu_device_string()
+
return "cpu"
@@ -38,7 +66,7 @@ def get_optimal_device():
def get_device_for(task):
- if task in shared.cmd_opts.use_cpu:
+ if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
return cpu
return get_optimal_device()
@@ -54,14 +82,16 @@ def torch_gc():
if has_mps():
mac_specific.torch_mps_gc()
+ if has_xpu():
+ xpu_specific.torch_xpu_gc()
+
def enable_tf32():
if torch.cuda.is_available():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
- if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+ if cuda_no_autocast():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -71,6 +101,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
+fp8: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@@ -91,12 +122,51 @@ def cond_cast_float(input):
nv_rng = None
+patch_module_list = [
+ torch.nn.Linear,
+ torch.nn.Conv2d,
+ torch.nn.MultiheadAttention,
+ torch.nn.GroupNorm,
+ torch.nn.LayerNorm,
+]
+
+
+def manual_cast_forward(self, *args, **kwargs):
+ org_dtype = torch_utils.get_param(self).dtype
+ self.to(dtype)
+ args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+ kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+ result = self.org_forward(*args, **kwargs)
+ self.to(org_dtype)
+ return result
+
+
+@contextlib.contextmanager
+def manual_cast():
+ for module_type in patch_module_list:
+ org_forward = module_type.forward
+ module_type.forward = manual_cast_forward
+ module_type.org_forward = org_forward
+ try:
+ yield None
+ finally:
+ for module_type in patch_module_list:
+ module_type.forward = module_type.org_forward
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
+ if fp8 and device==cpu:
+ return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
+ if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
+ return manual_cast()
+
+ if has_mps() and shared.cmd_opts.precision != "full":
+ return manual_cast()
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
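The new manual_cast() path works by temporarily replacing the forward method on a fixed list of layer classes with a wrapper that casts the module and its tensor arguments to the working dtype. A self-contained sketch of that patching pattern (generic names rather than the webui globals; the restore-to-original-dtype step from manual_cast_forward is omitted for brevity):

    import contextlib
    import torch

    @contextlib.contextmanager
    def cast_forward(module_types, target_dtype):
        originals = {mt: mt.forward for mt in module_types}
        for mt in module_types:
            def wrapper(self, *args, _orig=originals[mt], **kwargs):
                # cast the module and its tensor arguments before running the original forward
                self.to(target_dtype)
                args = [a.to(target_dtype) if isinstance(a, torch.Tensor) else a for a in args]
                kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
                return _orig(self, *args, **kwargs)
            mt.forward = wrapper
        try:
            yield
        finally:
            for mt, orig in originals.items():
                mt.forward = orig

With something like cast_forward([torch.nn.Linear], torch.float16), every Linear forward inside the block then runs in fp16, which is the same idea the patch applies to the patch_module_list above.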
diff --git a/modules/errors.py b/modules/errors.py
index 8c339464..48aa13a1 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -6,6 +6,21 @@ import traceback
exception_records = []
+def format_traceback(tb):
+ return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
+
+
+def format_exception(e, tb):
+ return {"exception": str(e), "traceback": format_traceback(tb)}
+
+
+def get_exceptions():
+ try:
+ return list(reversed(exception_records))
+ except Exception as e:
+ return str(e)
+
+
def record_exception():
_, e, tb = sys.exc_info()
if e is None:
@@ -14,8 +29,7 @@ def record_exception():
if exception_records and exception_records[-1] == e:
return
- from modules import sysinfo
- exception_records.append(sysinfo.format_exception(e, tb))
+ exception_records.append(format_exception(e, tb))
if len(exception_records) > 5:
exception_records.pop(0)
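With format_traceback/format_exception moved here, recording an exception no longer requires importing sysinfo. For reference, the shape of a recorded entry, using only the functions added above:

    import sys
    from modules import errors

    try:
        1 / 0
    except Exception:
        _, e, tb = sys.exc_info()
        record = errors.format_exception(e, tb)
        # roughly: {"exception": "division by zero",
        #           "traceback": [["<file>, line <n>, <module>", "1 / 0"]]}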
@@ -93,8 +107,8 @@ def check_versions():
import torch
import gradio
- expected_torch_version = "2.0.0"
- expected_xformers_version = "0.0.20"
+ expected_torch_version = "2.1.2"
+ expected_xformers_version = "0.0.23.post1"
expected_gradio_version = "3.41.2"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 02a1727d..70041ab0 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,121 +1,7 @@
-import sys
-
-import numpy as np
-import torch
-from PIL import Image
-
-import modules.esrgan_model_arch as arch
-from modules import modelloader, images, devices
+from modules import modelloader, devices, errors
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
-
-
-def mod2normal(state_dict):
- # this code is copied from https://github.com/victorca25/iNNfer
- if 'conv_first.weight' in state_dict:
- crt_net = {}
- items = list(state_dict)
-
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
- for k in items.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[ori_k] = state_dict[k]
- items.remove(k)
-
- crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
- crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
- crt_net['model.3.weight'] = state_dict['upconv1.weight']
- crt_net['model.3.bias'] = state_dict['upconv1.bias']
- crt_net['model.6.weight'] = state_dict['upconv2.weight']
- crt_net['model.6.bias'] = state_dict['upconv2.bias']
- crt_net['model.8.weight'] = state_dict['HRconv.weight']
- crt_net['model.8.bias'] = state_dict['HRconv.bias']
- crt_net['model.10.weight'] = state_dict['conv_last.weight']
- crt_net['model.10.bias'] = state_dict['conv_last.bias']
- state_dict = crt_net
- return state_dict
-
-
-def resrgan2normal(state_dict, nb=23):
- # this code is copied from https://github.com/victorca25/iNNfer
- if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
- re8x = 0
- crt_net = {}
- items = list(state_dict)
-
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
- for k in items.copy():
- if "rdb" in k:
- ori_k = k.replace('body.', 'model.1.sub.')
- ori_k = ori_k.replace('.rdb', '.RDB')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[ori_k] = state_dict[k]
- items.remove(k)
-
- crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
- crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
- crt_net['model.3.weight'] = state_dict['conv_up1.weight']
- crt_net['model.3.bias'] = state_dict['conv_up1.bias']
- crt_net['model.6.weight'] = state_dict['conv_up2.weight']
- crt_net['model.6.bias'] = state_dict['conv_up2.bias']
-
- if 'conv_up3.weight' in state_dict:
- # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
- re8x = 3
- crt_net['model.9.weight'] = state_dict['conv_up3.weight']
- crt_net['model.9.bias'] = state_dict['conv_up3.bias']
-
- crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
- crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
- crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
- crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
-
- state_dict = crt_net
- return state_dict
-
-
-def infer_params(state_dict):
- # this code is copied from https://github.com/victorca25/iNNfer
- scale2x = 0
- scalemin = 6
- n_uplayer = 0
- plus = False
-
- for block in list(state_dict):
- parts = block.split(".")
- n_parts = len(parts)
- if n_parts == 5 and parts[2] == "sub":
- nb = int(parts[3])
- elif n_parts == 3:
- part_num = int(parts[1])
- if (part_num > scalemin
- and parts[0] == "model"
- and parts[2] == "weight"):
- scale2x += 1
- if part_num > n_uplayer:
- n_uplayer = part_num
- out_nc = state_dict[block].shape[0]
- if not plus and "conv1x1" in block:
- plus = True
-
- nf = state_dict["model.0.weight"].shape[0]
- in_nc = state_dict["model.0.weight"].shape[1]
- out_nc = out_nc
- scale = 2 ** scale2x
-
- return in_nc, out_nc, nf, nb, plus, scale
+from modules.upscaler_utils import upscale_with_model
class UpscalerESRGAN(Upscaler):
@@ -143,12 +29,11 @@ class UpscalerESRGAN(Upscaler):
def do_upscale(self, img, selected_model):
try:
model = self.load_model(selected_model)
- except Exception as e:
- print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
+ except Exception:
+ errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
return img
model.to(devices.device_esrgan)
- img = esrgan_upscale(model, img)
- return img
+ return esrgan_upscale(model, img)
def load_model(self, path: str):
if path.startswith("http"):
@@ -161,69 +46,17 @@ class UpscalerESRGAN(Upscaler):
else:
filename = path
- state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
-
- if "params_ema" in state_dict:
- state_dict = state_dict["params_ema"]
- elif "params" in state_dict:
- state_dict = state_dict["params"]
- num_conv = 16 if "realesr-animevideov3" in filename else 32
- model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
- model.load_state_dict(state_dict)
- model.eval()
- return model
-
- if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
- nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
- state_dict = resrgan2normal(state_dict, nb)
- elif "conv_first.weight" in state_dict:
- state_dict = mod2normal(state_dict)
- elif "model.0.weight" not in state_dict:
- raise Exception("The file is not a recognized ESRGAN model.")
-
- in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
-
- model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
- model.load_state_dict(state_dict)
- model.eval()
-
- return model
-
-
-def upscale_without_tiling(model, img):
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- return Image.fromarray(output, 'RGB')
+ return modelloader.load_spandrel_model(
+ filename,
+ device=('cpu' if devices.device_esrgan.type == 'mps' else None),
+ expected_architecture='ESRGAN',
+ )
def esrgan_upscale(model, img):
- if opts.ESRGAN_tile == 0:
- return upscale_without_tiling(model, img)
-
- grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
- newtiles = []
- scale_factor = 1
-
- for y, h, row in grid.tiles:
- newrow = []
- for tiledata in row:
- x, w, tile = tiledata
-
- output = upscale_without_tiling(model, tile)
- scale_factor = output.width // tile.width
-
- newrow.append([x * scale_factor, w * scale_factor, output])
- newtiles.append([y * scale_factor, h * scale_factor, newrow])
-
- newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
- output = images.combine_grid(newgrid)
- return output
+ return upscale_with_model(
+ model,
+ img,
+ tile_size=opts.ESRGAN_tile,
+ tile_overlap=opts.ESRGAN_tile_overlap,
+ )
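esrgan_upscale() now delegates tiling to upscaler_utils.upscale_with_model instead of the hand-rolled grid code. The underlying idea is unchanged: split the image into overlapping tiles, upscale each tile, and paste the results back at the new scale. A rough sketch of that loop (not the actual helper; it assumes a PIL-in/PIL-out upscale callable and skips overlap blending):

    from PIL import Image

    def tiled_upscale(img: Image.Image, upscale, tile: int, overlap: int, scale: int) -> Image.Image:
        out = Image.new("RGB", (img.width * scale, img.height * scale))
        step = tile - overlap
        for y in range(0, img.height, step):
            for x in range(0, img.width, step):
                patch = img.crop((x, y, min(x + tile, img.width), min(y + tile, img.height)))
                out.paste(upscale(patch), (x * scale, y * scale))
        return out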
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
deleted file mode 100644
index 2b9888ba..00000000
--- a/modules/esrgan_model_arch.py
+++ /dev/null
@@ -1,465 +0,0 @@
-# this file is adapted from https://github.com/victorca25/iNNfer
-
-from collections import OrderedDict
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-####################
-# RRDBNet Generator
-####################
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
- act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
- finalact=None, gaussian_noise=False, plus=False):
- super(RRDBNet, self).__init__()
- n_upscale = int(math.log(upscale, 2))
- if upscale == 3:
- n_upscale = 1
-
- self.resrgan_scale = 0
- if in_nc % 16 == 0:
- self.resrgan_scale = 1
- elif in_nc != 4 and in_nc % 4 == 0:
- self.resrgan_scale = 2
-
- fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
- rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
- LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
-
- if upsample_mode == 'upconv':
- upsample_block = upconv_block
- elif upsample_mode == 'pixelshuffle':
- upsample_block = pixelshuffle_block
- else:
- raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
- if upscale == 3:
- upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
- else:
- upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
- HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
- HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
-
- outact = act(finalact) if finalact else None
-
- self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
- *upsampler, HR_conv0, HR_conv1, outact)
-
- def forward(self, x, outm=None):
- if self.resrgan_scale == 1:
- feat = pixel_unshuffle(x, scale=4)
- elif self.resrgan_scale == 2:
- feat = pixel_unshuffle(x, scale=2)
- else:
- feat = x
-
- return self.model(feat)
-
-
-class RRDB(nn.Module):
- """
- Residual in Residual Dense Block
- (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
- """
-
- def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
- spectral_norm=False, gaussian_noise=False, plus=False):
- super(RRDB, self).__init__()
- # This is for backwards compatibility with existing models
- if nr == 3:
- self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- else:
- RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
- self.RDBs = nn.Sequential(*RDB_list)
-
- def forward(self, x):
- if hasattr(self, 'RDB1'):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- else:
- out = self.RDBs(x)
- return out * 0.2 + x
-
-
-class ResidualDenseBlock_5C(nn.Module):
- """
- Residual Dense Block
- The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
- Modified options that can be used:
- - "Partial Convolution based Padding" arXiv:1811.11718
- - "Spectral normalization" arXiv:1802.05957
- - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
- {Rakotonirina} and A. {Rasoanaivo}
- """
-
- def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
- spectral_norm=False, gaussian_noise=False, plus=False):
- super(ResidualDenseBlock_5C, self).__init__()
-
- self.noise = GaussianNoise() if gaussian_noise else None
- self.conv1x1 = conv1x1(nf, gc) if plus else None
-
- self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- if mode == 'CNA':
- last_act = None
- else:
- last_act = act_type
- self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
-
- def forward(self, x):
- x1 = self.conv1(x)
- x2 = self.conv2(torch.cat((x, x1), 1))
- if self.conv1x1:
- x2 = x2 + self.conv1x1(x)
- x3 = self.conv3(torch.cat((x, x1, x2), 1))
- x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
- if self.conv1x1:
- x4 = x4 + x2
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- if self.noise:
- return self.noise(x5.mul(0.2) + x)
- else:
- return x5 * 0.2 + x
-
-
-####################
-# ESRGANplus
-####################
-
-class GaussianNoise(nn.Module):
- def __init__(self, sigma=0.1, is_relative_detach=False):
- super().__init__()
- self.sigma = sigma
- self.is_relative_detach = is_relative_detach
- self.noise = torch.tensor(0, dtype=torch.float)
-
- def forward(self, x):
- if self.training and self.sigma != 0:
- self.noise = self.noise.to(x.device)
- scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
- sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
- x = x + sampled_noise
- return x
-
-def conv1x1(in_planes, out_planes, stride=1):
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-####################
-# SRVGGNetCompact
-####################
-
-class SRVGGNetCompact(nn.Module):
- """A compact VGG-style network structure for super-resolution.
- This class is copied from https://github.com/xinntao/Real-ESRGAN
- """
-
- def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
- super(SRVGGNetCompact, self).__init__()
- self.num_in_ch = num_in_ch
- self.num_out_ch = num_out_ch
- self.num_feat = num_feat
- self.num_conv = num_conv
- self.upscale = upscale
- self.act_type = act_type
-
- self.body = nn.ModuleList()
- # the first conv
- self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
- # the first activation
- if act_type == 'relu':
- activation = nn.ReLU(inplace=True)
- elif act_type == 'prelu':
- activation = nn.PReLU(num_parameters=num_feat)
- elif act_type == 'leakyrelu':
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation)
-
- # the body structure
- for _ in range(num_conv):
- self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
- # activation
- if act_type == 'relu':
- activation = nn.ReLU(inplace=True)
- elif act_type == 'prelu':
- activation = nn.PReLU(num_parameters=num_feat)
- elif act_type == 'leakyrelu':
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation)
-
- # the last conv
- self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
- # upsample
- self.upsampler = nn.PixelShuffle(upscale)
-
- def forward(self, x):
- out = x
- for i in range(0, len(self.body)):
- out = self.body[i](out)
-
- out = self.upsampler(out)
- # add the nearest upsampled image, so that the network learns the residual
- base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
- out += base
- return out
-
-
-####################
-# Upsampler
-####################
-
-class Upsample(nn.Module):
- r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
- The input data is assumed to be of the form
- `minibatch x channels x [optional depth] x [optional height] x width`.
- """
-
- def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
- super(Upsample, self).__init__()
- if isinstance(scale_factor, tuple):
- self.scale_factor = tuple(float(factor) for factor in scale_factor)
- else:
- self.scale_factor = float(scale_factor) if scale_factor else None
- self.mode = mode
- self.size = size
- self.align_corners = align_corners
-
- def forward(self, x):
- return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
-
- def extra_repr(self):
- if self.scale_factor is not None:
- info = f'scale_factor={self.scale_factor}'
- else:
- info = f'size={self.size}'
- info += f', mode={self.mode}'
- return info
-
-
-def pixel_unshuffle(x, scale):
- """ Pixel unshuffle.
- Args:
- x (Tensor): Input feature with shape (b, c, hh, hw).
- scale (int): Downsample ratio.
- Returns:
- Tensor: the pixel unshuffled feature.
- """
- b, c, hh, hw = x.size()
- out_channel = c * (scale**2)
- assert hh % scale == 0 and hw % scale == 0
- h = hh // scale
- w = hw // scale
- x_view = x.view(b, c, h, scale, w, scale)
- return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
-
-
-def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
- """
- Pixel shuffle layer
- (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
- Neural Network, CVPR17)
- """
- conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
- pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
- pixel_shuffle = nn.PixelShuffle(upscale_factor)
-
- n = norm(norm_type, out_nc) if norm_type else None
- a = act(act_type) if act_type else None
- return sequential(conv, pixel_shuffle, n, a)
-
-
-def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
- """ Upconv layer """
- upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
- upsample = Upsample(scale_factor=upscale_factor, mode=mode)
- conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
- pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
- return sequential(upsample, conv)
-
-
-
-
-
-
-
-
-####################
-# Basic blocks
-####################
-
-
-def make_layer(basic_block, num_basic_block, **kwarg):
- """Make layers by stacking the same blocks.
- Args:
- basic_block (nn.module): nn.module class for basic block. (block)
- num_basic_block (int): number of blocks. (n_layers)
- Returns:
- nn.Sequential: Stacked blocks in nn.Sequential.
- """
- layers = []
- for _ in range(num_basic_block):
- layers.append(basic_block(**kwarg))
- return nn.Sequential(*layers)
-
-
-def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
- """ activation helper """
- act_type = act_type.lower()
- if act_type == 'relu':
- layer = nn.ReLU(inplace)
- elif act_type in ('leakyrelu', 'lrelu'):
- layer = nn.LeakyReLU(neg_slope, inplace)
- elif act_type == 'prelu':
- layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
- elif act_type == 'tanh': # [-1, 1] range output
- layer = nn.Tanh()
- elif act_type == 'sigmoid': # [0, 1] range output
- layer = nn.Sigmoid()
- else:
- raise NotImplementedError(f'activation layer [{act_type}] is not found')
- return layer
-
-
-class Identity(nn.Module):
- def __init__(self, *kwargs):
- super(Identity, self).__init__()
-
- def forward(self, x, *kwargs):
- return x
-
-
-def norm(norm_type, nc):
- """ Return a normalization layer """
- norm_type = norm_type.lower()
- if norm_type == 'batch':
- layer = nn.BatchNorm2d(nc, affine=True)
- elif norm_type == 'instance':
- layer = nn.InstanceNorm2d(nc, affine=False)
- elif norm_type == 'none':
- def norm_layer(x): return Identity()
- else:
- raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
- return layer
-
-
-def pad(pad_type, padding):
- """ padding layer helper """
- pad_type = pad_type.lower()
- if padding == 0:
- return None
- if pad_type == 'reflect':
- layer = nn.ReflectionPad2d(padding)
- elif pad_type == 'replicate':
- layer = nn.ReplicationPad2d(padding)
- elif pad_type == 'zero':
- layer = nn.ZeroPad2d(padding)
- else:
- raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
- return layer
-
-
-def get_valid_padding(kernel_size, dilation):
- kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
- padding = (kernel_size - 1) // 2
- return padding
-
-
-class ShortcutBlock(nn.Module):
- """ Elementwise sum the output of a submodule to its input """
- def __init__(self, submodule):
- super(ShortcutBlock, self).__init__()
- self.sub = submodule
-
- def forward(self, x):
- output = x + self.sub(x)
- return output
-
- def __repr__(self):
- return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
-
-
-def sequential(*args):
- """ Flatten Sequential. It unwraps nn.Sequential. """
- if len(args) == 1:
- if isinstance(args[0], OrderedDict):
- raise NotImplementedError('sequential does not support OrderedDict input.')
- return args[0] # No sequential is needed.
- modules = []
- for module in args:
- if isinstance(module, nn.Sequential):
- for submodule in module.children():
- modules.append(submodule)
- elif isinstance(module, nn.Module):
- modules.append(module)
- return nn.Sequential(*modules)
-
-
-def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
- spectral_norm=False):
- """ Conv layer with padding, normalization, activation """
- assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
- padding = get_valid_padding(kernel_size, dilation)
- p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
- padding = padding if pad_type == 'zero' else 0
-
- if convtype=='PartialConv2D':
- from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
- c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- elif convtype=='DeformConv2D':
- from torchvision.ops import DeformConv2d # not tested
- c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- elif convtype=='Conv3D':
- c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- else:
- c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
-
- if spectral_norm:
- c = nn.utils.spectral_norm(c)
-
- a = act(act_type) if act_type else None
- if 'CNA' in mode:
- n = norm(norm_type, out_nc) if norm_type else None
- return sequential(p, c, n, a)
- elif mode == 'NAC':
- if norm_type is None and act_type is not None:
- a = act(act_type, inplace=False)
- n = norm(norm_type, in_nc) if norm_type else None
- return sequential(n, a, p, c)
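The iNNfer-derived RRDBNet/SRVGGNet code above is likewise dropped in favour of spandrel-provided architectures. One piece worth noting from it is pixel_unshuffle, which folds a scale factor from the spatial dimensions into the channels before the RRDB trunk; current PyTorch ships this as a built-in, and the deleted view/permute trick should be equivalent to it (sketch assumes torch >= 1.8):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    manual = (
        x.view(1, 3, 4, 2, 4, 2)       # split each spatial dim into (out_size, scale)
         .permute(0, 1, 3, 5, 2, 4)    # move both scale factors next to the channels
         .reshape(1, 3 * 2 * 2, 4, 4)  # fold them into the channel dimension
    )
    assert torch.equal(manual, F.pixel_unshuffle(x, 2))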
diff --git a/modules/extensions.py b/modules/extensions.py
index bf9a1878..1899cd52 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,11 +1,14 @@
+from __future__ import annotations
+
+import configparser
import os
import threading
+import re
from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
-extensions = []
os.makedirs(extensions_dir, exist_ok=True)
@@ -19,11 +22,55 @@ def active():
return [x for x in extensions if x.enabled]
+class ExtensionMetadata:
+ filename = "metadata.ini"
+ config: configparser.ConfigParser
+ canonical_name: str
+ requires: list
+
+ def __init__(self, path, canonical_name):
+ self.config = configparser.ConfigParser()
+
+ filepath = os.path.join(path, self.filename)
+ if os.path.isfile(filepath):
+ try:
+ self.config.read(filepath)
+ except Exception:
+ errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
+
+ self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
+ self.canonical_name = self.canonical_name.lower().strip()
+
+ self.requires = self.get_script_requirements("Requires", "Extension")
+
+ def get_script_requirements(self, field, section, extra_section=None):
+ """reads a list of requirements from the config; field is the name of the field in the ini file,
+ like Requires or Before, and section is the name of the [section] in the ini file; additionally,
+ reads more requirements from [extra_section] if specified."""
+
+ x = self.config.get(section, field, fallback='')
+
+ if extra_section:
+ x = x + ', ' + self.config.get(extra_section, field, fallback='')
+
+ return self.parse_list(x.lower())
+
+ def parse_list(self, text):
+ """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
+
+ if not text:
+ return []
+
+ # both "," and " " are accepted as separator
+ return [x for x in re.split(r"[,\s]+", text.strip()) if x]
+
+
class Extension:
lock = threading.Lock()
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
+ metadata: ExtensionMetadata
- def __init__(self, name, path, enabled=True, is_builtin=False):
+ def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
self.name = name
self.path = path
self.enabled = enabled
@@ -36,6 +83,8 @@ class Extension:
self.branch = None
self.remote = None
self.have_info_from_repo = False
+ self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
+ self.canonical_name = self.metadata.canonical_name
def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields}
@@ -56,6 +105,7 @@ class Extension:
self.do_read_info_from_repo()
return self.to_dict()
+
try:
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
self.from_dict(d)
@@ -136,9 +186,6 @@ class Extension:
def list_extensions():
extensions.clear()
- if not os.path.isdir(extensions_dir):
- return
-
if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
elif shared.opts.disable_all_extensions == "all":
@@ -148,18 +195,43 @@ def list_extensions():
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
- extension_paths = []
- for dirname in [extensions_dir, extensions_builtin_dir]:
+ loaded_extensions = {}
+
+ # scan through extensions directory and load metadata
+ for dirname in [extensions_builtin_dir, extensions_dir]:
if not os.path.isdir(dirname):
- return
+ continue
for extension_dirname in sorted(os.listdir(dirname)):
path = os.path.join(dirname, extension_dirname)
if not os.path.isdir(path):
continue
- extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
+ canonical_name = extension_dirname
+ metadata = ExtensionMetadata(path, canonical_name)
+
+ # check for duplicated canonical names
+ already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
+ if already_loaded_extension is not None:
+ errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
+ continue
+
+ is_builtin = dirname == extensions_builtin_dir
+ extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
+ extensions.append(extension)
+ loaded_extensions[canonical_name] = extension
+
+ # check for requirements
+ for extension in extensions:
+ for req in extension.metadata.requires:
+ required_extension = loaded_extensions.get(req)
+ if required_extension is None:
+ errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
+ continue
+
+ if not required_extension.enabled:
+ errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
+ continue
+
- for dirname, path, is_builtin in extension_paths:
- extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
- extensions.append(extension)
+extensions: list[Extension] = []
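
The ExtensionMetadata class above reads an optional metadata.ini from each extension directory. The sketch below replays that parsing outside the webui so the accepted separators and fallbacks are easy to see; the ini content and directory name are hypothetical.

import configparser
import re

# Hypothetical metadata.ini an extension might ship.
example_ini = """
[Extension]
Name = my-extension
Requires = sd-webui-controlnet, another-extension
"""

config = configparser.ConfigParser()
config.read_string(example_ini)

# The canonical name falls back to the directory name when the ini omits [Extension] Name.
canonical_name = config.get("Extension", "Name", fallback="my-extension-dir").lower().strip()

# Requirements accept both "," and whitespace as separators, mirroring ExtensionMetadata.parse_list.
raw = config.get("Extension", "Requires", fallback="")
requires = [x for x in re.split(r"[,\s]+", raw.lower().strip()) if x]

print(canonical_name)  # my-extension
print(requires)        # ['sd-webui-controlnet', 'another-extension']
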
diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py
new file mode 100644
index 00000000..1cbac236
--- /dev/null
+++ b/modules/face_restoration_utils.py
@@ -0,0 +1,180 @@
+from __future__ import annotations
+
+import logging
+import os
+from functools import cached_property
+from typing import TYPE_CHECKING, Callable
+
+import cv2
+import numpy as np
+import torch
+
+from modules import devices, errors, face_restoration, shared
+
+if TYPE_CHECKING:
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+
+logger = logging.getLogger(__name__)
+
+
+def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
+ """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
+ assert img.shape[2] == 3, "image must have 3 channels"
+ if img.dtype == "float64":
+ img = img.astype("float32")
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return torch.from_numpy(img.transpose(2, 0, 1)).float()
+
+
+def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
+ """
+ Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
+ """
+ tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
+ assert tensor.dim() == 3, "tensor must be RGB"
+ img_np = tensor.numpy().transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image, no RGB/BGR required
+ return np.squeeze(img_np, axis=2)
+ return cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+
+
+def create_face_helper(device) -> FaceRestoreHelper:
+ from facexlib.detection import retinaface
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+ if hasattr(retinaface, 'device'):
+ retinaface.device = device
+ return FaceRestoreHelper(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ use_parse=True,
+ device=device,
+ )
+
+
+def restore_with_face_helper(
+ np_image: np.ndarray,
+ face_helper: FaceRestoreHelper,
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
+) -> np.ndarray:
+ """
+ Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
+
+ `restore_face` should take a cropped face image and return a restored face image.
+ """
+ from torchvision.transforms.functional import normalize
+ np_image = np_image[:, :, ::-1]
+ original_resolution = np_image.shape[0:2]
+
+ try:
+ logger.debug("Detecting faces...")
+ face_helper.clean_all()
+ face_helper.read_image(np_image)
+ face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+ face_helper.align_warp_face()
+ logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
+ for cropped_face in face_helper.cropped_faces:
+ cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+
+ try:
+ with torch.no_grad():
+ cropped_face_t = restore_face(cropped_face_t)
+ devices.torch_gc()
+ except Exception:
+ errors.report('Failed face-restoration inference', exc_info=True)
+
+ restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
+ restored_face = (restored_face * 255.0).astype('uint8')
+ face_helper.add_restored_face(restored_face)
+
+ logger.debug("Merging restored faces into image")
+ face_helper.get_inverse_affine(None)
+ img = face_helper.paste_faces_to_input_image()
+ img = img[:, :, ::-1]
+ if original_resolution != img.shape[0:2]:
+ img = cv2.resize(
+ img,
+ (0, 0),
+ fx=original_resolution[1] / img.shape[1],
+ fy=original_resolution[0] / img.shape[0],
+ interpolation=cv2.INTER_LINEAR,
+ )
+ logger.debug("Face restoration complete")
+ finally:
+ face_helper.clean_all()
+ return img
+
+
+class CommonFaceRestoration(face_restoration.FaceRestoration):
+ net: torch.nn.Module | None
+ model_url: str
+ model_download_name: str
+
+ def __init__(self, model_path: str):
+ super().__init__()
+ self.net = None
+ self.model_path = model_path
+ os.makedirs(model_path, exist_ok=True)
+
+ @cached_property
+ def face_helper(self) -> FaceRestoreHelper:
+ return create_face_helper(self.get_device())
+
+ def send_model_to(self, device):
+ if self.net:
+ logger.debug("Sending %s to %s", self.net, device)
+ self.net.to(device)
+ if self.face_helper:
+ logger.debug("Sending face helper to %s", device)
+ self.face_helper.face_det.to(device)
+ self.face_helper.face_parse.to(device)
+
+ def get_device(self):
+ raise NotImplementedError("get_device must be implemented by subclasses")
+
+ def load_net(self) -> torch.nn.Module:
+ raise NotImplementedError("load_net must be implemented by subclasses")
+
+ def restore_with_helper(
+ self,
+ np_image: np.ndarray,
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
+ ) -> np.ndarray:
+ try:
+ if self.net is None:
+ self.net = self.load_net()
+ except Exception:
+ logger.warning("Unable to load face-restoration model", exc_info=True)
+ return np_image
+
+ try:
+ self.send_model_to(self.get_device())
+ return restore_with_face_helper(np_image, self.face_helper, restore_face)
+ finally:
+ if shared.opts.face_restoration_unload:
+ self.send_model_to(devices.cpu)
+
+
+def patch_facexlib(dirname: str) -> None:
+ import facexlib.detection
+ import facexlib.parsing
+
+ det_facex_load_file_from_url = facexlib.detection.load_file_from_url
+ par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
+
+ def update_kwargs(kwargs):
+ return dict(kwargs, save_dir=dirname, model_dir=None)
+
+ def facex_load_file_from_url(**kwargs):
+ return det_facex_load_file_from_url(**update_kwargs(kwargs))
+
+ def facex_load_file_from_url2(**kwargs):
+ return par_facex_load_file_from_url(**update_kwargs(kwargs))
+
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
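
The two color-space helpers at the top of this new module are easy to sanity-check in isolation. The following sketch is not part of the patch; it assumes cv2, numpy and torch are installed, performs the same BGR-image-to-RGB-tensor conversion and back, and verifies the round trip.

import cv2
import numpy as np
import torch

# A tiny 2x2 BGR image in [0..1], the shape bgr_image_to_rgb_tensor expects.
bgr = np.random.rand(2, 2, 3).astype("float32")

# BGR HWC image -> RGB CHW float tensor (same steps as bgr_image_to_rgb_tensor).
tensor = torch.from_numpy(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)).float()

# RGB CHW tensor in [0..1] -> BGR HWC image (same steps as rgb_tensor_to_bgr_image with min_max=(0, 1)).
back = cv2.cvtColor(tensor.clamp(0.0, 1.0).numpy().transpose(1, 2, 0), cv2.COLOR_RGB2BGR)

assert np.allclose(back, bgr, atol=1e-6)  # the round trip preserves pixel values
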
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 8e0f13bd..445b0409 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,110 +1,71 @@
+from __future__ import annotations
+
+import logging
import os
-import facexlib
-import gfpgan
+import torch
-import modules.face_restoration
-from modules import paths, shared, devices, modelloader, errors
+from modules import (
+ devices,
+ errors,
+ face_restoration,
+ face_restoration_utils,
+ modelloader,
+ shared,
+)
-model_dir = "GFPGAN"
-user_path = None
-model_path = os.path.join(paths.models_path, model_dir)
+logger = logging.getLogger(__name__)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-have_gfpgan = False
-loaded_gfpgan_model = None
-
-
-def gfpgann():
- global loaded_gfpgan_model
- global model_path
- if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
- return loaded_gfpgan_model
-
- if gfpgan_constructor is None:
- return None
-
- models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
- if len(models) == 1 and models[0].startswith("http"):
- model_file = models[0]
- elif len(models) != 0:
- latest_file = max(models, key=os.path.getctime)
- model_file = latest_file
- else:
- print("Unable to load gfpgan model!")
- return None
- if hasattr(facexlib.detection.retinaface, 'device'):
- facexlib.detection.retinaface.device = devices.device_gfpgan
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
- loaded_gfpgan_model = model
-
- return model
-
-
-def send_model_to(model, device):
- model.gfpgan.to(device)
- model.face_helper.face_det.to(device)
- model.face_helper.face_parse.to(device)
+model_download_name = "GFPGANv1.4.pth"
+gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
+
+
+class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
+ def name(self):
+ return "GFPGAN"
+
+ def get_device(self):
+ return devices.device_gfpgan
+
+ def load_net(self) -> torch.nn.Module:
+ for model_path in modelloader.load_models(
+ model_path=self.model_path,
+ model_url=model_url,
+ command_path=self.model_path,
+ download_name=model_download_name,
+ ext_filter=['.pth'],
+ ):
+ if 'GFPGAN' in os.path.basename(model_path):
+ model = modelloader.load_spandrel_model(
+ model_path,
+ device=self.get_device(),
+ expected_architecture='GFPGAN',
+ ).model
+ model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
+ return model
+ raise ValueError("No GFPGAN model found")
+
+ def restore(self, np_image):
+ def restore_face(cropped_face_t):
+ assert self.net is not None
+ return self.net(cropped_face_t, return_rgb=False)[0]
+
+ return self.restore_with_helper(np_image, restore_face)
def gfpgan_fix_faces(np_image):
- model = gfpgann()
- if model is None:
- return np_image
-
- send_model_to(model, devices.device_gfpgan)
-
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- model.face_helper.clean_all()
-
- if shared.opts.face_restoration_unload:
- send_model_to(model, devices.cpu)
-
+ if gfpgan_face_restorer:
+ return gfpgan_face_restorer.restore(np_image)
+ logger.warning("GFPGAN face restorer not set up")
return np_image
-gfpgan_constructor = None
-
+def setup_model(dirname: str) -> None:
+ global gfpgan_face_restorer
-def setup_model(dirname):
try:
- os.makedirs(model_path, exist_ok=True)
- from gfpgan import GFPGANer
- from facexlib import detection, parsing # noqa: F401
- global user_path
- global have_gfpgan
- global gfpgan_constructor
-
- load_file_from_url_orig = gfpgan.utils.load_file_from_url
- facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
- facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-
- def my_load_file_from_url(**kwargs):
- return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
-
- def facex_load_file_from_url(**kwargs):
- return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
-
- def facex_load_file_from_url2(**kwargs):
- return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
-
- gfpgan.utils.load_file_from_url = my_load_file_from_url
- facexlib.detection.load_file_from_url = facex_load_file_from_url
- facexlib.parsing.load_file_from_url = facex_load_file_from_url2
- user_path = dirname
- have_gfpgan = True
- gfpgan_constructor = GFPGANer
-
- class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
- def name(self):
- return "GFPGAN"
-
- def restore(self, np_image):
- return gfpgan_fix_faces(np_image)
-
- shared.face_restorers.append(FaceRestorerGFPGAN())
+ face_restoration_utils.patch_facexlib(dirname)
+ gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
+ shared.face_restorers.append(gfpgan_face_restorer)
except Exception:
errors.report("Error setting up GFPGAN", exc_info=True)
diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py
index e6b6835a..7d88dc98 100644
--- a/modules/gradio_extensons.py
+++ b/modules/gradio_extensons.py
@@ -47,10 +47,20 @@ def Block_get_config(self):
def BlockContext_init(self, *args, **kwargs):
+ if scripts.scripts_current is not None:
+ scripts.scripts_current.before_component(self, **kwargs)
+
+ scripts.script_callbacks.before_component_callback(self, **kwargs)
+
res = original_BlockContext_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
+ scripts.script_callbacks.after_component_callback(self, **kwargs)
+
+ if scripts.scripts_current is not None:
+ scripts.scripts_current.after_component(self, **kwargs)
+
return res
diff --git a/modules/hat_model.py b/modules/hat_model.py
new file mode 100644
index 00000000..7f2abb41
--- /dev/null
+++ b/modules/hat_model.py
@@ -0,0 +1,43 @@
+import os
+import sys
+
+from modules import modelloader, devices
+from modules.shared import opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
+
+
+class UpscalerHAT(Upscaler):
+ def __init__(self, dirname):
+ self.name = "HAT"
+ self.scalers = []
+ self.user_path = dirname
+ super().__init__()
+ for file in self.find_models(ext_filter=[".pt", ".pth"]):
+ name = modelloader.friendly_name(file)
+ scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
+ scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
+ self.scalers.append(scaler_data)
+
+ def do_upscale(self, img, selected_model):
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
+ return img
+ model.to(devices.device_esrgan) # TODO: should probably be device_hat
+ return upscale_with_model(
+ model,
+ img,
+ tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
+ tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
+ )
+
+ def load_model(self, path: str):
+ if not os.path.isfile(path):
+ raise FileNotFoundError(f"Model file {path} not found")
+ return modelloader.load_spandrel_model(
+ path,
+ device=devices.device_esrgan, # TODO: should probably be device_hat
+ expected_architecture='HAT',
+ )
diff --git a/modules/images.py b/modules/images.py
index daf4eebe..87a7bf22 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -61,12 +61,17 @@ def image_grid(imgs, batch_size=1, rows=None):
return grid
-Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
+ @property
+ def tile_count(self) -> int:
+ """
+ The total number of tiles in the grid.
+ """
+ return sum(len(row[2]) for row in self.tiles)
-def split_grid(image, tile_w=512, tile_h=512, overlap=64):
- w = image.width
- h = image.height
+def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
+ w, h = image.size
non_overlap_width = tile_w - overlap
non_overlap_height = tile_h - overlap
@@ -791,3 +796,4 @@ def flatten(img, bgcolor):
img = background
return img.convert('RGB')
+
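
The tile_count property above assumes the tiles layout that split_grid builds elsewhere in images.py: each entry of grid.tiles is [y, tile_h, row_images], with row_images holding one entry per column. A standalone sketch of the counting, using stand-in values instead of PIL images:

# Stand-in for grid.tiles: two rows with three and two tiles respectively.
tiles = [
    [0, 512, ["tile_a", "tile_b", "tile_c"]],
    [448, 512, ["tile_d", "tile_e"]],
]

tile_count = sum(len(row[2]) for row in tiles)  # same expression as Grid.tile_count
print(tile_count)  # 5
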
diff --git a/modules/img2img.py b/modules/img2img.py
index 52cb577a..e7e8e251 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr
import gradio as gr
from modules import images as imgutil
-from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
+from modules.infotext import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
from modules.sd_models import get_closet_checkpoint_match
@@ -44,12 +44,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
steps = p.steps
override_settings = p.override_settings
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
+ batch_results = None
+ discard_further_results = False
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
if state.skipped:
state.skipped = False
- if state.interrupted:
+ if state.interrupted or state.stopping_generation:
break
try:
@@ -127,7 +129,21 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
if proc is None:
p.override_settings.pop('save_images_replace_action', None)
- process_images(p)
+ proc = process_images(p)
+
+ if not discard_further_results and proc:
+ if batch_results:
+ batch_results.images.extend(proc.images)
+ batch_results.infotexts.extend(proc.infotexts)
+ else:
+ batch_results = proc
+
+ if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
+ discard_further_results = True
+ batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
+ batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
+
+ return batch_results
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
@@ -212,10 +228,10 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
with closing(p):
if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+ processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
-
- processed = Processed(p, [], p.seed, "")
+ if processed is None:
+ processed = Processed(p, [], p.seed, "")
else:
processed = modules.scripts.scripts_img2img.run(p, *args)
if processed is None:
diff --git a/modules/import_hook.py b/modules/import_hook.py
index 28c67dfa..eba9a372 100644
--- a/modules/import_hook.py
+++ b/modules/import_hook.py
@@ -3,3 +3,14 @@ import sys
# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
if "--xformers" not in "".join(sys.argv):
sys.modules["xformers"] = None
+
+# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
+# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
+try:
+ import torchvision.transforms.functional_tensor # noqa: F401
+except ImportError:
+ try:
+ import torchvision.transforms.functional as functional
+ sys.modules["torchvision.transforms.functional_tensor"] = functional
+ except ImportError:
+ pass # shrug...
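
The shim above only has an effect when torchvision.transforms.functional_tensor is gone (torchvision 0.17+); the sketch below shows what the alias buys basicsr, assuming such a torchvision version is installed.

import sys
import torchvision.transforms.functional as functional

# What modules/import_hook.py installs when functional_tensor no longer exists:
sys.modules.setdefault("torchvision.transforms.functional_tensor", functional)

# basicsr's `from torchvision.transforms.functional_tensor import rgb_to_grayscale` now resolves:
from torchvision.transforms.functional_tensor import rgb_to_grayscale  # noqa: E402
assert rgb_to_grayscale is functional.rgb_to_grayscale
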
diff --git a/modules/generation_parameters_copypaste.py b/modules/infotext.py
index 0a606515..26e9b949 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/infotext.py
@@ -1,23 +1,24 @@
+from __future__ import annotations
import base64
import io
import json
import os
import re
+import sys
import gradio as gr
from modules.paths import data_path
-from modules import shared, ui_tempdir, script_callbacks, processing
+from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
from PIL import Image
+sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
+
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())
-paste_fields = {}
-registered_param_bindings = []
-
class ParamBinding:
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
@@ -30,6 +31,23 @@ class ParamBinding:
self.paste_field_names = paste_field_names or []
+class PasteField(tuple):
+ def __new__(cls, component, target, *, api=None):
+ return super().__new__(cls, (component, target))
+
+ def __init__(self, component, target, *, api=None):
+ super().__init__()
+
+ self.api = api
+ self.component = component
+ self.label = target if isinstance(target, str) else None
+ self.function = target if callable(target) else None
+
+
+paste_fields: dict[str, dict] = {}
+registered_param_bindings: list[ParamBinding] = []
+
+
def reset():
paste_fields.clear()
registered_param_bindings.clear()
@@ -82,6 +100,12 @@ def image_from_url_text(filedata):
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
+
+ if fields:
+ for i in range(len(fields)):
+ if not isinstance(fields[i], PasteField):
+ fields[i] = PasteField(*fields[i])
+
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
# backwards compatibility for existing extensions
@@ -113,7 +137,6 @@ def register_paste_params_button(binding: ParamBinding):
def connect_paste_params_buttons():
- binding: ParamBinding
for binding in registered_param_bindings:
destination_image_component = paste_fields[binding.tabname]["init_img"]
fields = paste_fields[binding.tabname]["fields"]
@@ -313,6 +336,17 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "VAE Decoder" not in res:
res["VAE Decoder"] = "Full"
+ if "FP8 weight" not in res:
+ res["FP8 weight"] = "Disable"
+
+ if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
+ res["Cache FP16 weight for LoRA"] = False
+
+ infotext_versions.backcompat(res)
+
+ skip = set(shared.opts.infotext_skip_pasting)
+ res = {k: v for k, v in res.items() if k not in skip}
+
return res
@@ -361,6 +395,48 @@ def create_override_settings_dict(text_pairs):
return res
+def get_override_settings(params, *, skip_fields=None):
+ """Returns a list of settings overrides from the infotext parameters dictionary.
+
+ This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
+ a list of tuples containing the parameter name, setting name, and new value cast to correct type.
+
+ It checks for conditions before adding an override:
+ - ignores settings that match the current value
+ - ignores parameter keys present in skip_fields argument.
+
+ Example input:
+ {"Clip skip": "2"}
+
+ Example output:
+ [("Clip skip", "CLIP_stop_at_last_layers", 2)]
+ """
+
+ res = []
+
+ mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
+ for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
+ if param_name in (skip_fields or {}):
+ continue
+
+ v = params.get(param_name, None)
+ if v is None:
+ continue
+
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
+ continue
+
+ v = shared.opts.cast_value(setting_name, v)
+ current_value = getattr(shared.opts, setting_name, None)
+
+ if v == current_value:
+ continue
+
+ res.append((param_name, setting_name, v))
+
+ return res
+
+
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
@@ -402,29 +478,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
already_handled_fields = {key: 1 for _, key in paste_fields}
def paste_settings(params):
- vals = {}
-
- mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
- for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
- if param_name in already_handled_fields:
- continue
-
- v = params.get(param_name, None)
- if v is None:
- continue
+ vals = get_override_settings(params, skip_fields=already_handled_fields)
- if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
- continue
-
- v = shared.opts.cast_value(setting_name, v)
- current_value = getattr(shared.opts, setting_name, None)
-
- if v == current_value:
- continue
-
- vals[param_name] = v
-
- vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+ vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
@@ -443,3 +499,4 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
outputs=[],
show_progress=False,
)
+
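
PasteField above subclasses tuple so that existing extensions that unpack (component, target) pairs keep working, while new code can read .api, .label and .function. A standalone copy of the class with dummy stand-ins, to illustrate that shape:

class PasteField(tuple):
    # mirror of the class added above, reproduced here only for illustration
    def __new__(cls, component, target, *, api=None):
        return super().__new__(cls, (component, target))

    def __init__(self, component, target, *, api=None):
        super().__init__()
        self.api = api
        self.component = component
        self.label = target if isinstance(target, str) else None
        self.function = target if callable(target) else None


dummy_component = object()                       # stands in for a gr.Textbox or similar
field = PasteField(dummy_component, "Prompt", api="prompt")

component, target = field                        # old-style 2-tuple unpacking still works
assert component is dummy_component and target == "Prompt"
assert field.label == "Prompt" and field.api == "prompt"
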
diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py
new file mode 100644
index 00000000..a5afeebf
--- /dev/null
+++ b/modules/infotext_versions.py
@@ -0,0 +1,39 @@
+from modules import shared
+from packaging import version
+import re
+
+
+v160 = version.parse("1.6.0")
+v170_tsnr = version.parse("v1.7.0-225")
+
+
+def parse_version(text):
+ if text is None:
+ return None
+
+ m = re.match(r'([^-]+-[^-]+)-.*', text)
+ if m:
+ text = m.group(1)
+
+ try:
+ return version.parse(text)
+ except Exception:
+ return None
+
+
+def backcompat(d):
+ """Checks infotext Version field, and enables backwards compatibility options according to it."""
+
+ if not shared.opts.auto_backcompat:
+ return
+
+ ver = parse_version(d.get("Version"))
+ if ver is None:
+ return
+
+ if ver < v160:
+ d["Old prompt editing timelines"] = True
+
+ if ver < v170_tsnr:
+ d["Downcast alphas_cumprod"] = True
+
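
A few examples of what parse_version above accepts: the regex trims everything after the second dash, so git-describe style strings compare cleanly against the v160 and v170_tsnr thresholds. Standalone sketch, assuming packaging >= 22 (where invalid strings raise and therefore yield None here):

import re
from packaging import version

def parse_version(text):
    # same logic as the function in modules/infotext_versions.py
    if text is None:
        return None
    m = re.match(r'([^-]+-[^-]+)-.*', text)
    if m:
        text = m.group(1)
    try:
        return version.parse(text)
    except Exception:
        return None

assert parse_version("1.6.0") == version.parse("1.6.0")
assert parse_version("v1.7.0-225-g1234abcd") == version.parse("v1.7.0-225")
assert parse_version("not a version") is None
assert parse_version(None) is None
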
diff --git a/modules/initialize.py b/modules/initialize.py
index ac95fc6f..4a3cd98c 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -54,9 +54,6 @@ def initialize():
initialize_util.configure_sigint_handler()
initialize_util.configure_opts_onchange()
- from modules import modelloader
- modelloader.cleanup_models()
-
from modules import sd_models
sd_models.setup_model()
startup_timer.record("setup SD model")
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index 2e9b6d89..b6767138 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -177,6 +177,8 @@ def configure_opts_onchange():
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+ shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+ shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
startup_timer.record("opts onchange")
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 3045560d..35a627ca 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -10,7 +10,7 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -131,7 +131,7 @@ class InterrogateModels:
self.clip_model = self.clip_model.to(devices.device_interrogate)
- self.dtype = next(self.clip_model.parameters()).dtype
+ self.dtype = torch_utils.get_param(self.clip_model).dtype
def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory:
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 8cdbafa5..c2cbd8ce 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -6,6 +6,7 @@ import os
import shutil
import sys
import importlib.util
+import importlib.metadata
import platform
import json
from functools import lru_cache
@@ -119,11 +120,16 @@ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_
def is_installed(package):
try:
- spec = importlib.util.find_spec(package)
- except ModuleNotFoundError:
- return False
+ dist = importlib.metadata.distribution(package)
+ except importlib.metadata.PackageNotFoundError:
+ try:
+ spec = importlib.util.find_spec(package)
+ except ModuleNotFoundError:
+ return False
+
+ return spec is not None
- return spec is not None
+ return dist is not None
def repo_dir(name):
@@ -308,24 +314,42 @@ def requirements_met(requirements_file):
def prepare_environment():
- torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
- torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
+ if args.use_ipex:
+ if platform.system() == "Windows":
+ # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
+ # This is NOT an Intel official release so please use it at your own risk!!
+ # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
+ #
+ # Strengths (over official IPEX 2.0.110 windows release):
+ # - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
+ # - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
+ # - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
+ # Limitation:
+ # - Only works for python 3.10
+ url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
+ else:
+ # Using official IPEX release for linux since it's already an AOT build.
+ # However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
+ # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
- xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
- codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
- codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try:
@@ -352,6 +376,8 @@ def prepare_environment():
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
startup_timer.record("install torch")
+ if args.use_ipex:
+ args.skip_torch_cuda_test = True
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
raise RuntimeError(
'Torch is not able to use GPU; '
@@ -380,15 +406,10 @@ def prepare_environment():
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
- git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
startup_timer.record("clone repositores")
- if not is_installed("lpips"):
- run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
- startup_timer.record("install CodeFormer requirements")
-
if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
@@ -441,7 +462,7 @@ def dump_sysinfo():
import datetime
text = sysinfo.get()
- filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
+ filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
with open(filename, "w", encoding="utf8") as file:
file.write(text)
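
The updated is_installed above consults distribution metadata first, so packages whose import name differs from the distribution name (for example "opencv-python" vs "cv2") are detected, and only falls back to find_spec for plain module names. A standalone sketch of that lookup order:

import importlib.metadata
import importlib.util

def is_installed(package: str) -> bool:
    # same fallback order as the updated launch_utils.is_installed
    try:
        dist = importlib.metadata.distribution(package)
    except importlib.metadata.PackageNotFoundError:
        try:
            spec = importlib.util.find_spec(package)
        except ModuleNotFoundError:
            return False
        return spec is not None
    return dist is not None

print(is_installed("pip"))                   # True: found via distribution metadata
print(is_installed("surely-not-a-package"))  # False: neither metadata nor an importable module
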
diff --git a/modules/logging_config.py b/modules/logging_config.py
index 7db23d4b..79269875 100644
--- a/modules/logging_config.py
+++ b/modules/logging_config.py
@@ -1,16 +1,41 @@
import os
import logging
+try:
+ from tqdm.auto import tqdm
+
+ class TqdmLoggingHandler(logging.Handler):
+ def __init__(self, level=logging.INFO):
+ super().__init__(level)
+
+ def emit(self, record):
+ try:
+ msg = self.format(record)
+ tqdm.write(msg)
+ self.flush()
+ except Exception:
+ self.handleError(record)
+
+ TQDM_IMPORTED = True
+except ImportError:
+ # tqdm does not exist before the first launch;
+ # it will be imported once the UI finishes setting up the environment and reloads.
+ TQDM_IMPORTED = False
def setup_logging(loglevel):
if loglevel is None:
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
+ loghandlers = []
+
+ if TQDM_IMPORTED:
+ loghandlers.append(TqdmLoggingHandler())
+
if loglevel:
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
logging.basicConfig(
level=log_level,
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
+ handlers=loghandlers
)
-
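
The handler added above routes log records through tqdm.write so that log lines do not tear active progress bars. A minimal standalone demonstration of the same wiring, assuming tqdm is installed:

import logging
from tqdm.auto import tqdm

class TqdmLoggingHandler(logging.Handler):
    # same behaviour as the handler defined in modules/logging_config.py
    def emit(self, record):
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)

logging.basicConfig(level=logging.INFO, handlers=[TqdmLoggingHandler()])
log = logging.getLogger("demo")

for _ in tqdm(range(3)):
    log.info("printed above the progress bar instead of breaking it")
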
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 89256c5b..d96d86d7 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,6 +1,7 @@
import logging
import torch
+from torch import Tensor
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
@@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
return cumsum_func(input, *args, **kwargs)
+# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
+ try:
+ return orig_func(*args, **kwargs)
+ except RuntimeError as e:
+ if "not implemented for" in str(e) and "Half" in str(e):
+ input_tensor = args[0]
+ return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
+ else:
+ print(f"An unexpected RuntimeError occurred: {str(e)}")
+ raise
+
if has_mps:
if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
@@ -77,6 +89,9 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
+ # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+ CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 098bcb79..a7194137 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,13 +1,20 @@
from __future__ import annotations
-import os
-import shutil
import importlib
+import logging
+import os
+from typing import TYPE_CHECKING
from urllib.parse import urlparse
+import torch
+
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
-from modules.paths import script_path, models_path
+
+if TYPE_CHECKING:
+ import spandrel
+
+logger = logging.getLogger(__name__)
def load_file_from_url(
@@ -90,54 +97,6 @@ def friendly_name(file: str):
return model_name
-def cleanup_models():
- # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
- # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
- # somehow auto-register and just do these things...
- root_path = script_path
- src_path = models_path
- dest_path = os.path.join(models_path, "Stable-diffusion")
- move_files(src_path, dest_path, ".ckpt")
- move_files(src_path, dest_path, ".safetensors")
- src_path = os.path.join(root_path, "ESRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
- move_files(src_path, dest_path)
- src_path = os.path.join(models_path, "BSRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
- move_files(src_path, dest_path, ".pth")
- src_path = os.path.join(root_path, "gfpgan")
- dest_path = os.path.join(models_path, "GFPGAN")
- move_files(src_path, dest_path)
- src_path = os.path.join(root_path, "SwinIR")
- dest_path = os.path.join(models_path, "SwinIR")
- move_files(src_path, dest_path)
- src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
- dest_path = os.path.join(models_path, "LDSR")
- move_files(src_path, dest_path)
-
-
-def move_files(src_path: str, dest_path: str, ext_filter: str = None):
- try:
- os.makedirs(dest_path, exist_ok=True)
- if os.path.exists(src_path):
- for file in os.listdir(src_path):
- fullpath = os.path.join(src_path, file)
- if os.path.isfile(fullpath):
- if ext_filter is not None:
- if ext_filter not in file:
- continue
- print(f"Moving {file} from {src_path} to {dest_path}.")
- try:
- shutil.move(fullpath, dest_path)
- except Exception:
- pass
- if len(os.listdir(src_path)) == 0:
- print(f"Removing empty folder: {src_path}")
- shutil.rmtree(src_path, True)
- except Exception:
- pass
-
-
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -177,3 +136,26 @@ def load_upscalers():
# Special case for UpscalerNone keeps it at the beginning of the list.
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
)
+
+
+def load_spandrel_model(
+ path: str,
+ *,
+ device: str | torch.device | None,
+ half: bool = False,
+ dtype: str | torch.dtype | None = None,
+ expected_architecture: str | None = None,
+) -> spandrel.ModelDescriptor:
+ import spandrel
+ model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
+ if expected_architecture and model_descriptor.architecture != expected_architecture:
+ logger.warning(
+ f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
+ )
+ if half:
+ model_descriptor.model.half()
+ if dtype:
+ model_descriptor.model.to(dtype=dtype)
+ model_descriptor.model.eval()
+ logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
+ return model_descriptor
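
load_spandrel_model above is the new common entry point that the upscalers and face restorers call. A hedged usage sketch; the checkpoint path is hypothetical and the webui modules must be importable:

# Sketch only; assumes an initialized webui environment and a downloaded ESRGAN checkpoint.
from modules import devices, modelloader

descriptor = modelloader.load_spandrel_model(
    "models/ESRGAN/ESRGAN_4x.pth",       # hypothetical path
    device=devices.device_esrgan,
    expected_architecture="ESRGAN",      # mismatches only log a warning, they do not fail
)
model = descriptor.model                 # the underlying torch module, already switched to eval()
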
diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py
index b892d5fc..6db340da 100644
--- a/modules/models/diffusion/ddpm_edit.py
+++ b/modules/models/diffusion/ddpm_edit.py
@@ -24,10 +24,15 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
+try:
+ from ldm.models.autoencoder import VQModelInterface
+except Exception:
+ class VQModelInterface:
+ pass
__conditioning_keys__ = {'concat': 'c_concat',
'crossattn': 'c_crossattn',
diff --git a/modules/options.py b/modules/options.py
index e270a42d..09ff9403 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -1,5 +1,6 @@
import json
import sys
+from dataclasses import dataclass
import gradio as gr
@@ -8,13 +9,14 @@ from modules.shared_cmd_options import cmd_opts
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = section
+ self.category_id = category_id
self.refresh = refresh
self.do_not_save = False
@@ -63,7 +65,11 @@ class OptionHTML(OptionInfo):
def options_section(section_identifier, options_dict):
for v in options_dict.values():
- v.section = section_identifier
+ if len(section_identifier) == 2:
+ v.section = section_identifier
+ elif len(section_identifier) == 3:
+ v.section = section_identifier[0:2]
+ v.category_id = section_identifier[2]
return options_dict
@@ -76,7 +82,7 @@ class Options:
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
self.data_labels = data_labels
- self.data = {k: v.default for k, v in self.data_labels.items()}
+ self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save}
self.restricted_opts = restricted_opts
def __setattr__(self, key, value):
@@ -175,7 +181,7 @@ class Options:
assert not cmd_opts.freeze_settings, "saving settings is disabled"
with open(filename, "w", encoding="utf8") as file:
- json.dump(self.data, file, indent=4)
+ json.dump(self.data, file, indent=4, ensure_ascii=False)
def same_type(self, x, y):
if x is None or y is None:
@@ -223,21 +229,59 @@ class Options:
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
+
+ item_categories = {}
+ for item in self.data_labels.values():
+ category = categories.mapping.get(item.category_id)
+ category = "Uncategorized" if category is None else category.label
+ if category not in item_categories:
+ item_categories[category] = item.section[1]
+
+ # _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text.
+ d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]]
+
return json.dumps(d)
def add_option(self, key, info):
self.data_labels[key] = info
+ if key not in self.data and not info.do_not_save:
+ self.data[key] = info.default
def reorder(self):
- """reorder settings so that all items related to section always go together"""
+ """Reorder settings so that:
+ - all items related to section always go together
+ - all sections belonging to a category go together
+ - sections inside a category are ordered alphabetically
+ - categories are ordered by creation order
+
+ A category groups multiple sections: for the category "postprocessing" there could be several sections, such as "face restoration" and "upscaling".
+
+ This function also changes items' category_id so that all items belonging to a section have the same category_id.
+ """
+
+ category_ids = {}
+ section_categories = {}
- section_ids = {}
settings_items = self.data_labels.items()
for _, item in settings_items:
- if item.section not in section_ids:
- section_ids[item.section] = len(section_ids)
+ if item.section not in section_categories:
+ section_categories[item.section] = item.category_id
+
+ for _, item in settings_items:
+ item.category_id = section_categories.get(item.section)
+
+ for category_id in categories.mapping:
+ if category_id not in category_ids:
+ category_ids[category_id] = len(category_ids)
- self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
+ def sort_key(x):
+ item: OptionInfo = x[1]
+ category_order = category_ids.get(item.category_id, len(category_ids))
+ section_order = item.section[1]
+
+ return category_order, section_order
+
+ self.data_labels = dict(sorted(settings_items, key=sort_key))
def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key
@@ -260,3 +304,22 @@ class Options:
value = expected_type(value)
return value
+
+
+@dataclass
+class OptionsCategory:
+ id: str
+ label: str
+
+class OptionsCategories:
+ def __init__(self):
+ self.mapping = {}
+
+ def register_category(self, category_id, label):
+ if category_id in self.mapping:
+ return category_id
+
+ self.mapping[category_id] = OptionsCategory(category_id, label)
+ return category_id
+
+
+categories = OptionsCategories()
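
Together with the three-element form of options_section above, the category registry lets settings pages be grouped under headings. A hedged sketch of the intended call pattern; the option name and labels are hypothetical, and the webui's modules.options must be importable:

# Sketch only; assumes the webui environment.
from modules import options

options.categories.register_category("postprocessing", "Postprocessing")

section = options.options_section(
    ("face_restoration", "Face restoration", "postprocessing"),   # (section id, section label, category_id)
    {"demo_option": options.OptionInfo(True, "Hypothetical example option")},
)

info = section["demo_option"]
assert info.section == ("face_restoration", "Face restoration")
assert info.category_id == "postprocessing"
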
diff --git a/modules/paths.py b/modules/paths.py
index 187b9496..03064651 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -38,7 +38,6 @@ mute_sdxl_imports()
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
- (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 89131a54..b86ecd7f 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -28,5 +28,6 @@ models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states")
+default_output_dir = os.path.join(data_path, "output")
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index cf04d38b..facea899 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -2,7 +2,7 @@ import os
from PIL import Image
-from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
+from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext as infotext_utils
from modules.shared import opts
@@ -29,11 +29,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
image_list = shared.listfiles(input_dir)
for filename in image_list:
- try:
- image = Image.open(filename)
- except Exception:
- continue
- yield image, filename
+ yield filename, filename
else:
assert image, 'image not selected'
yield image, None
@@ -45,43 +41,97 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
infotext = ''
- for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
+ data_to_process = list(get_images(extras_mode, image, image_folder, input_dir))
+ shared.state.job_count = len(data_to_process)
+
+ for image_placeholder, name in data_to_process:
image_data: Image.Image
+ shared.state.nextjob()
shared.state.textinfo = name
+ shared.state.skipped = False
+
+ if shared.state.interrupted:
+ break
+
+ if isinstance(image_placeholder, str):
+ try:
+ image_data = Image.open(image_placeholder)
+ except Exception:
+ continue
+ else:
+ image_data = image_placeholder
+
+ shared.state.assign_current_image(image_data)
parameters, existing_pnginfo = images.read_info_from_image(image_data)
if parameters:
existing_pnginfo["parameters"] = parameters
- pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
+ initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
- scripts.scripts_postproc.run(pp, args)
+ scripts.scripts_postproc.run(initial_pp, args)
- if opts.use_original_name_batch and name is not None:
- basename = os.path.splitext(os.path.basename(name))[0]
- else:
- basename = ''
+ if shared.state.skipped:
+ continue
+
+ used_suffixes = {}
+ for pp in [initial_pp, *initial_pp.extra_images]:
+ suffix = pp.get_suffix(used_suffixes)
- infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+ if opts.use_original_name_batch and name is not None:
+ basename = os.path.splitext(os.path.basename(name))[0]
+ forced_filename = basename + suffix
+ else:
+ basename = ''
+ forced_filename = None
- if opts.enable_pnginfo:
- pp.image.info = existing_pnginfo
- pp.image.info["postprocessing"] = infotext
+ infotext = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in pp.info.items() if v is not None])
- if save_output:
- images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
+ if opts.enable_pnginfo:
+ pp.image.info = existing_pnginfo
+ pp.image.info["postprocessing"] = infotext
- if extras_mode != 2 or show_extras_results:
- outputs.append(pp.image)
+ if save_output:
+ fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
+
+ if pp.caption:
+ caption_filename = os.path.splitext(fullfn)[0] + ".txt"
+ if os.path.isfile(caption_filename):
+ with open(caption_filename, encoding="utf8") as file:
+ existing_caption = file.read().strip()
+ else:
+ existing_caption = ""
+
+ action = shared.opts.postprocessing_existing_caption_action
+ if action == 'Prepend' and existing_caption:
+ caption = f"{existing_caption} {pp.caption}"
+ elif action == 'Append' and existing_caption:
+ caption = f"{pp.caption} {existing_caption}"
+ elif action == 'Keep' and existing_caption:
+ caption = existing_caption
+ else:
+ caption = pp.caption
+
+ caption = caption.strip()
+ if caption:
+ with open(caption_filename, "w", encoding="utf8") as file:
+ file.write(caption)
+
+ if extras_mode != 2 or show_extras_results:
+ outputs.append(pp.image)
image_data.close()
devices.torch_gc()
-
+ shared.state.end()
return outputs, ui_common.plaintext_to_html(infotext), ''
+def run_postprocessing_webui(id_task, *args, **kwargs):
+ return run_postprocessing(*args, **kwargs)
+
+
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
"""old handler for API"""
@@ -97,9 +147,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
"upscaler_2_visibility": extras_upscaler_2_visibility,
},
"GFPGAN": {
+ "enable": True,
"gfpgan_visibility": gfpgan_visibility,
},
"CodeFormer": {
+ "enable": True,
"codeformer_visibility": codeformer_visibility,
"codeformer_weight": codeformer_weight,
},
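The caption handling in run_postprocessing above merges a freshly generated caption with any pre-existing .txt sidecar file according to the postprocessing_existing_caption_action option. A minimal standalone sketch of that merge rule (plain Python, no webui imports; the option value is passed in directly):

    def merge_caption(new_caption, existing_caption, action):
        # mirrors the Prepend/Append/Keep logic used when writing caption .txt files
        if action == 'Prepend' and existing_caption:
            caption = f"{existing_caption} {new_caption}"
        elif action == 'Append' and existing_caption:
            caption = f"{new_caption} {existing_caption}"
        elif action == 'Keep' and existing_caption:
            caption = existing_caption
        else:
            caption = new_caption
        return caption.strip()

    assert merge_caption("a cat", "tabby, indoors", "Prepend") == "tabby, indoors a cat"
    assert merge_caption("a cat", "", "Keep") == "a cat"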
diff --git a/modules/processing.py b/modules/processing.py
index 40598f5c..f55b85ed 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -16,7 +16,7 @@ from skimage import exposure
from typing import Any
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -62,18 +62,22 @@ def apply_color_correction(correction, original_image):
return image.convert('RGB')
-def apply_overlay(image, paste_loc, index, overlays):
- if overlays is None or index >= len(overlays):
- return image
+def uncrop(image, dest_size, paste_loc):
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', dest_size)
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ return image
- overlay = overlays[index]
+
+def apply_overlay(image, paste_loc, overlay):
+ if overlay is None:
+ return image
if paste_loc is not None:
- x, y, w, h = paste_loc
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
+ image = uncrop(image, (overlay.width, overlay.height), paste_loc)
image = image.convert('RGBA')
image.alpha_composite(overlay)
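The new uncrop helper above is what pastes an inpainted crop back into a full-size canvas before the overlay is composited. A minimal Pillow-only sketch of the same idea (standalone; it uses Image.resize directly instead of the webui's images.resize_image):

    from PIL import Image

    def uncrop_sketch(crop, dest_size, paste_loc):
        # paste_loc is (x, y, w, h): where the crop was taken from in the full image
        x, y, w, h = paste_loc
        base = Image.new('RGBA', dest_size)
        base.paste(crop.resize((w, h)), (x, y))
        return base

    full = uncrop_sketch(Image.new('RGB', (64, 64), 'red'), dest_size=(512, 512), paste_loc=(100, 120, 64, 64))
    print(full.size)  # (512, 512)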
@@ -81,9 +85,12 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
-def create_binary_mask(image):
+def create_binary_mask(image, round=True):
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
- image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ if round:
+ image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ else:
+ image = image.split()[-1].convert("L")
else:
image = image.convert('L')
return image
@@ -106,6 +113,21 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
else:
+ sd = sd_model.model.state_dict()
+ diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+ if diffusion_model_input is not None:
+ if diffusion_model_input.shape[1] == 9:
+ # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+ image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+ image_conditioning = images_tensor_to_samples(image_conditioning,
+ approximation_indexes.get(opts.sd_vae_encode_method))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
# Dummy zero conditioning if we're not using inpainting or unclip models.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
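The shape[1] == 9 check above is how an inpainting UNet is recognized: its first convolution takes 9 input channels (4 latent + 4 masked-image latent + 1 mask) instead of 4. A hedged sketch of the same test against a raw checkpoint state dict (at checkpoint level the key carries an extra 'model.' prefix compared to the module-level key used in the diff; the file path is hypothetical):

    def is_inpainting_unet(state_dict):
        # 9 input channels -> inpainting model, 4 -> standard model
        w = state_dict.get('model.diffusion_model.input_blocks.0.0.weight')
        return w is not None and w.shape[1] == 9

    # usage sketch (hypothetical path):
    #   import torch
    #   sd = torch.load('v1-5-inpainting.ckpt', map_location='cpu')['state_dict']
    #   print(is_inpainting_unet(sd))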
@@ -296,7 +318,7 @@ class StableDiffusionProcessing:
return conditioning
def edit_image_conditioning(self, source_image):
- conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
+ conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
return conditioning_image
@@ -308,7 +330,7 @@ class StableDiffusionProcessing:
c_adm = torch.cat((c_adm, noise_level_emb), 1)
return c_adm
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
@@ -320,8 +342,10 @@ class StableDiffusionProcessing:
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
+ if round_image_mask:
+ # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+
else:
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
@@ -345,7 +369,7 @@ class StableDiffusionProcessing:
return image_conditioning
- def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
source_image = devices.cond_cast_float(source_image)
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -357,11 +381,17 @@ class StableDiffusionProcessing:
return self.edit_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
+ sd = self.sampler.model_wrap.inner_model.model.state_dict()
+ diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+ if diffusion_model_input is not None:
+ if diffusion_model_input.shape[1] == 9:
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -422,6 +452,8 @@ class StableDiffusionProcessing:
opts.sdxl_crop_top,
self.width,
self.height,
+ opts.fp8_storage,
+ opts.cache_fp16_weight,
)
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
@@ -596,20 +628,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
sample = decode_first_stage(model, batch[i:i + 1])[0]
if check_for_nans:
+
try:
devices.test_for_nans(sample, "vae")
except devices.NansException as e:
- if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+ if shared.opts.auto_vae_precision_bfloat16:
+ autofix_dtype = torch.bfloat16
+ autofix_dtype_text = "bfloat16"
+ autofix_dtype_setting = "Automatically convert VAE to bfloat16"
+ autofix_dtype_comment = ""
+ elif shared.opts.auto_vae_precision:
+ autofix_dtype = torch.float32
+ autofix_dtype_text = "32-bit float"
+ autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
+ autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
+ else:
+ raise e
+
+ if devices.dtype_vae == autofix_dtype:
raise e
errors.print_error_explanation(
"A tensor with all NaNs was produced in VAE.\n"
- "Web UI will now convert VAE into 32-bit float and retry.\n"
- "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
- "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+ f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
+ f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
)
- devices.dtype_vae = torch.float32
+ devices.dtype_vae = autofix_dtype
model.first_stage_model.to(devices.dtype_vae)
batch = batch.to(devices.dtype_vae)
@@ -679,8 +724,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}",
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
- "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
- "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
+ "FP8 weight": opts.fp8_storage if devices.fp8 else None,
+ "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
+ "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
+ "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -699,7 +746,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"User": p.user if opts.add_user_name_to_info else None,
}
- generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None])
prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else ""
@@ -799,7 +846,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
-
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -819,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.skipped:
state.skipped = False
- if state.interrupted:
+ if state.interrupted or state.stopping_generation:
break
sd_models.reload_model_weights() # model can be changed for example by refiner
@@ -865,15 +911,47 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+ def rescale_zero_terminal_snr_abar(alphas_cumprod):
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas_bar[-1] = 4.8973451890853435e-08
+ return alphas_bar
+
+ if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
+
+ if opts.use_downcasted_alpha_bar:
+ p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
+ if opts.sd_noise_schedule == "Zero Terminal SNR":
+ p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+ p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
+
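The rescale_zero_terminal_snr_abar helper above shifts and rescales the alpha-bar schedule so that the first value is preserved and the last one becomes (numerically almost) zero. A quick standalone check on a toy linear beta schedule (torch only; the schedule is arbitrary, not the model's real one):

    import torch

    def rescale_zero_terminal_snr_abar(alphas_cumprod):
        alphas_bar_sqrt = alphas_cumprod.sqrt()
        a0, aT = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
        alphas_bar_sqrt -= aT                    # shift: last timestep becomes zero
        alphas_bar_sqrt *= a0 / (a0 - aT)        # scale: first timestep keeps its old value
        alphas_bar = alphas_bar_sqrt ** 2
        alphas_bar[-1] = 4.8973451890853435e-08  # tiny epsilon instead of exact zero
        return alphas_bar

    betas = torch.linspace(1e-4, 2e-2, 1000)
    abar = torch.cumprod(1.0 - betas, dim=0)
    rescaled = rescale_zero_terminal_snr_abar(abar.clone())
    print(float(abar[0]), float(rescaled[0]))  # first value unchanged
    print(float(rescaled[-1]))                 # ~4.9e-08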
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+ if p.scripts is not None:
+ ps = scripts.PostSampleArgs(samples_ddim)
+ p.scripts.post_sample(p, ps)
+ samples_ddim = ps.samples
+
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
-
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -886,6 +964,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
+ state.nextjob()
+
if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
@@ -922,13 +1002,31 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp)
image = pp.image
+
+ mask_for_overlay = getattr(p, "mask_for_overlay", None)
+ overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None
+
+ if p.scripts is not None:
+ ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
+ p.scripts.postprocess_maskoverlay(p, ppmo)
+ mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
+
if p.color_corrections is not None and i < len(p.color_corrections):
if save_samples and opts.save_images_before_color_correction:
- image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
- image = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ # If the intention is to show the output from the model
+ # that is being composited over the original image,
+ # we need to keep the original image around
+ # and use it in the composite step.
+ original_denoised_image = image.copy()
+
+ if p.paste_to is not None:
+ original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)
+
+ image = apply_overlay(image, p.paste_to, overlay_image)
if save_samples:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -938,28 +1036,26 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
- if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
- image_mask = p.mask_for_overlay.convert('RGB')
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
-
- if opts.save_mask:
- images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
- if opts.save_mask_composite:
- images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
-
- if opts.return_mask:
- output_images.append(image_mask)
-
- if opts.return_mask_composite:
- output_images.append(image_mask_composite)
+ if mask_for_overlay is not None:
+ if opts.return_mask or opts.save_mask:
+ image_mask = mask_for_overlay.convert('RGB')
+ if save_samples and opts.save_mask:
+ images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
+ if opts.return_mask:
+ output_images.append(image_mask)
+
+ if opts.return_mask_composite or opts.save_mask_composite:
+ image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+ if save_samples and opts.save_mask_composite:
+ images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
+ if opts.return_mask_composite:
+ output_images.append(image_mask_composite)
del x_samples_ddim
devices.torch_gc()
- state.nextjob()
-
if not infotexts:
infotexts.append(Processed(p, []).infotext(p, 0))
@@ -1028,6 +1124,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_sampler_name: str = None
hr_prompt: str = ''
hr_negative_prompt: str = ''
+ force_task_id: str = None
cached_hr_uc = [None, None]
cached_hr_c = [None, None]
@@ -1100,7 +1197,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
- if self.hr_checkpoint_name:
+ if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
if self.hr_checkpoint_info is None:
@@ -1147,6 +1244,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr:
return samples
+ devices.torch_gc()
if self.latent_scale_mode is None:
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
@@ -1156,8 +1254,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
- devices.torch_gc()
-
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
@@ -1165,7 +1261,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
self.is_hr_pass = True
-
target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y
@@ -1254,7 +1349,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
self.is_hr_pass = False
-
return decoded_samples
def close(self):
@@ -1357,12 +1451,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
+ mask_round: bool = True
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
inpainting_mask_invert: int = 0
initial_noise_multiplier: float = None
latent_mask: Image = None
+ force_task_id: str = None
image_mask: Any = field(default=None, init=False)
@@ -1402,7 +1498,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if image_mask is not None:
# image_mask is passed in as RGBA by Gradio to support alpha masks,
# but we still want to support binary masks.
- image_mask = create_binary_mask(image_mask)
+ image_mask = create_binary_mask(image_mask, round=self.mask_round)
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
@@ -1448,7 +1544,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# Save init image
if opts.save_init_img:
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
- images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+ images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)
image = images.flatten(img, opts.img2img_background_color)
@@ -1509,7 +1605,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
latmask = latmask[0]
- latmask = np.around(latmask)
+ if self.mask_round:
+ latmask = np.around(latmask)
latmask = np.tile(latmask[None], (4, 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
@@ -1521,7 +1618,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
+ self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = self.rng.next()
@@ -1533,7 +1630,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
- samples = samples * self.nmask + self.init_latent * self.mask
+ blended_samples = samples * self.nmask + self.init_latent * self.mask
+
+ if self.scripts is not None:
+ mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
+ self.scripts.on_mask_blend(self, mba)
+ blended_samples = mba.blended_latent
+
+ samples = blended_samples
del x
devices.torch_gc()
diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py
index 29ccb78f..e9941413 100644
--- a/modules/processing_scripts/refiner.py
+++ b/modules/processing_scripts/refiner.py
@@ -1,6 +1,7 @@
import gradio as gr
from modules import scripts, sd_models
+from modules.infotext import PasteField
from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion
@@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
return None if info is None else info.title
self.infotext_fields = [
- (enable_refiner, lambda d: 'Refiner' in d),
- (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
- (refiner_switch_at, 'Refiner switch at'),
+ PasteField(enable_refiner, lambda d: 'Refiner' in d),
+ PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
+ PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
]
return enable_refiner, refiner_checkpoint, refiner_switch_at
diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index dc9c2da5..60293278 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -3,6 +3,7 @@ import json
import gradio as gr
from modules import scripts, ui, errors
+from modules.infotext import PasteField
from modules.shared import cmd_opts
from modules.ui_components import ToolButton
@@ -51,12 +52,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
self.infotext_fields = [
- (self.seed, "Seed"),
- (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
- (subseed, "Variation seed"),
- (subseed_strength, "Variation seed strength"),
- (seed_resize_from_w, "Seed resize from-1"),
- (seed_resize_from_h, "Seed resize from-2"),
+ PasteField(self.seed, "Seed", api="seed"),
+ PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
+ PasteField(subseed, "Variation seed", api="subseed"),
+ PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"),
+ PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_w"),
+ PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_h"),
]
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
diff --git a/modules/progress.py b/modules/progress.py
index 69921de7..85255e82 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -8,10 +8,13 @@ from pydantic import BaseModel, Field
from modules.shared import opts
import modules.shared as shared
-
+from collections import OrderedDict
+import string
+import random
+from typing import List
current_task = None
-pending_tasks = {}
+pending_tasks = OrderedDict()
finished_tasks = []
recorded_results = []
recorded_results_limit = 2
@@ -34,6 +37,11 @@ def finish_task(id_task):
if len(finished_tasks) > 16:
finished_tasks.pop(0)
+def create_task_id(task_type):
+ N = 7
+ res = ''.join(random.choices(string.ascii_uppercase +
+ string.digits, k=N))
+ return f"task({task_type}-{res})"
def record_results(id_task, res):
recorded_results.append((id_task, res))
@@ -44,6 +52,9 @@ def record_results(id_task, res):
def add_task_to_queue(id_job):
pending_tasks[id_job] = time.time()
+class PendingTasksResponse(BaseModel):
+ size: int = Field(title="Pending task size")
+ tasks: List[str] = Field(title="Pending task ids")
class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
@@ -63,9 +74,16 @@ class ProgressResponse(BaseModel):
def setup_progress_api(app):
+ app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"])
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
+def get_pending_tasks():
+ pending_tasks_ids = list(pending_tasks)
+ pending_len = len(pending_tasks_ids)
+ return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)
+
+
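With the route above registered, the queue can be inspected over plain HTTP. A minimal client sketch (assumes the webui runs locally on the default port with the API enabled, and that the requests package is installed):

    import requests

    resp = requests.get("http://127.0.0.1:7860/internal/pending-tasks", timeout=5)
    data = resp.json()
    print(data["size"])   # number of queued tasks
    print(data["tasks"])  # e.g. ["task(txt2img-AB12CD3)", ...] -- ids in the create_task_id format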
def progressapi(req: ProgressRequest):
active = req.id_task == current_task
queued = req.id_task in pending_tasks
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index ddf4d2dd..cba13455 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -4,7 +4,7 @@ import re
from collections import namedtuple
import lark
-# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
+# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 02841c30..4d35b695 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,12 +1,9 @@
import os
-import numpy as np
-from PIL import Image
-from realesrgan import RealESRGANer
-
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
class UpscalerRealESRGAN(Upscaler):
@@ -14,29 +11,20 @@ class UpscalerRealESRGAN(Upscaler):
self.name = "RealESRGAN"
self.user_path = path
super().__init__()
- try:
- from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
- from realesrgan import RealESRGANer # noqa: F401
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
- self.enable = True
- self.scalers = []
- scalers = self.load_models(path)
+ self.enable = True
+ self.scalers = []
+ scalers = get_realesrgan_models(self)
- local_model_paths = self.find_models(ext_filter=[".pth"])
- for scaler in scalers:
- if scaler.local_data_path.startswith("http"):
- filename = modelloader.friendly_name(scaler.local_data_path)
- local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
- if local_model_candidates:
- scaler.local_data_path = local_model_candidates[0]
+ local_model_paths = self.find_models(ext_filter=[".pth"])
+ for scaler in scalers:
+ if scaler.local_data_path.startswith("http"):
+ filename = modelloader.friendly_name(scaler.local_data_path)
+ local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
+ if local_model_candidates:
+ scaler.local_data_path = local_model_candidates[0]
- if scaler.name in opts.realesrgan_enabled_models:
- self.scalers.append(scaler)
-
- except Exception:
- errors.report("Error importing Real-ESRGAN", exc_info=True)
- self.enable = False
- self.scalers = []
+ if scaler.name in opts.realesrgan_enabled_models:
+ self.scalers.append(scaler)
def do_upscale(self, img, path):
if not self.enable:
@@ -48,20 +36,19 @@ class UpscalerRealESRGAN(Upscaler):
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
- upsampler = RealESRGANer(
- scale=info.scale,
- model_path=info.local_data_path,
- model=info.model(),
- half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
- tile=opts.ESRGAN_tile,
- tile_pad=opts.ESRGAN_tile_overlap,
+ model_descriptor = modelloader.load_spandrel_model(
+ info.local_data_path,
device=self.device,
+ half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+ expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
+ )
+ return upscale_with_model(
+ model_descriptor,
+ img,
+ tile_size=opts.ESRGAN_tile,
+ tile_overlap=opts.ESRGAN_tile_overlap,
+ # TODO: `outscale`?
)
-
- upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
-
- image = Image.fromarray(upsampled)
- return image
def load_model(self, path):
for scaler in self.scalers:
@@ -76,58 +63,43 @@ class UpscalerRealESRGAN(Upscaler):
return scaler
raise ValueError(f"Unable to find model info: {path}")
- def load_models(self, _):
- return get_realesrgan_models(self)
-
-def get_realesrgan_models(scaler):
- try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
- models = [
- UpscalerData(
- name="R-ESRGAN General 4xV3",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN General WDN 4xV3",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN AnimeVideo",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN 4x+",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
- ),
- UpscalerData(
- name="R-ESRGAN 4x+ Anime6B",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
- ),
- UpscalerData(
- name="R-ESRGAN 2x+",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
- scale=2,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
- ),
- ]
- return models
- except Exception:
- errors.report("Error making Real-ESRGAN models list", exc_info=True)
+def get_realesrgan_models(scaler: UpscalerRealESRGAN):
+ return [
+ UpscalerData(
+ name="R-ESRGAN General 4xV3",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN General WDN 4xV3",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN AnimeVideo",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN 4x+",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN 4x+ Anime6B",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN 2x+",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+ scale=2,
+ upscaler=scaler,
+ ),
+ ]
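A hedged sketch of the new Spandrel-based upscaling path, using the same two helpers that the rewritten do_upscale calls (modules.modelloader.load_spandrel_model and modules.upscaler_utils.upscale_with_model). The model path, input file and tile sizes are placeholders, and the snippet assumes it runs inside a webui environment where those modules import cleanly:

    from PIL import Image
    from modules import modelloader
    from modules.upscaler_utils import upscale_with_model

    model_descriptor = modelloader.load_spandrel_model(
        "models/RealESRGAN/RealESRGAN_x4plus.pth",  # placeholder path
        device="cuda",
        half=True,
        expected_architecture="ESRGAN",
    )

    img = Image.open("input.png").convert("RGB")    # placeholder input
    result = upscale_with_model(model_descriptor, img, tile_size=192, tile_overlap=8)
    result.save("output.png")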
diff --git a/modules/rng.py b/modules/rng.py
index 9e8ba2ee..8934d39b 100644
--- a/modules/rng.py
+++ b/modules/rng.py
@@ -110,7 +110,7 @@ class ImageRNG:
self.is_first = True
def first(self):
- noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
+ noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w) // 8)
xs = []
diff --git a/modules/scripts.py b/modules/scripts.py
index 5c6e0226..017aed5a 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object()
+class MaskBlendArgs:
+ def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
+ self.current_latent = current_latent
+ self.nmask = nmask
+ self.init_latent = init_latent
+ self.mask = mask
+ self.blended_latent = blended_latent
+
+ self.denoiser = denoiser
+ self.is_final_blend = denoiser is None
+ self.sigma = sigma
+
+class PostSampleArgs:
+ def __init__(self, samples):
+ self.samples = samples
class PostprocessImageArgs:
def __init__(self, image):
self.image = image
+class PostProcessMaskOverlayArgs:
+ def __init__(self, index, mask_for_overlay, overlay_image):
+ self.index = index
+ self.mask_for_overlay = mask_for_overlay
+ self.overlay_image = overlay_image
class PostprocessBatchListArgs:
def __init__(self, images):
@@ -206,6 +226,25 @@ class Script:
pass
+ def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
+ """
+ Called in inpainting mode when the original content is blended with the inpainted content.
+ This is called at every step in the denoising process and once at the end.
+ If is_final_blend is true, this is called for the final blending stage.
+ Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
+ """
+
+ pass
+
+ def post_sample(self, p, ps: PostSampleArgs, *args):
+ """
+ Called after the samples have been generated,
+ but before they have been decoded by the VAE, if applicable.
+ Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
+ """
+
+ pass
+
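A hedged sketch of an extension script that hooks the two new callbacks; the class, title and blend math are purely illustrative (not any built-in script), and it only runs inside the webui environment:

    import torch
    from modules import scripts

    class ExampleBlendScript(scripts.Script):  # hypothetical extension script
        def title(self):
            return "Example mask blend"

        def show(self, is_img2img):
            return scripts.AlwaysVisible if is_img2img else False

        def post_sample(self, p, ps: scripts.PostSampleArgs, *args):
            # clamp latents before the VAE decode (illustrative only)
            ps.samples = torch.clamp(ps.samples, -10, 10)

        def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, *args):
            # soften the binary blend slightly, but only at the final blend stage
            if mba.is_final_blend:
                soft_mask = mba.mask * 0.9
                mba.blended_latent = mba.current_latent * (1 - soft_mask) + mba.init_latent * soft_mask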
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
@@ -213,6 +252,13 @@ class Script:
pass
+ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
+ """
+ Called for every image after it has been generated, before the inpainting mask and overlay are composited onto it.
+ Scripts may modify ppmo.mask_for_overlay and ppmo.overlay_image here.
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -311,20 +357,113 @@ scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
+def topological_sort(dependencies):
+ """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
+ Ignores errors relating to missing dependencies or circular dependencies
+ """
+
+ visited = {}
+ result = []
+
+ def inner(name):
+ visited[name] = True
+
+ for dep in dependencies.get(name, []):
+ if dep in dependencies and dep not in visited:
+ inner(dep)
+
+ result.append(name)
+
+ for depname in dependencies:
+ if depname not in visited:
+ inner(depname)
+
+ return result
+
+
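A quick standalone check of the topological_sort helper above (copied verbatim so it runs outside the webui); the script names and dependencies are made up:

    def topological_sort(dependencies):
        visited = {}
        result = []

        def inner(name):
            visited[name] = True
            for dep in dependencies.get(name, []):
                if dep in dependencies and dep not in visited:
                    inner(dep)
            result.append(name)

        for depname in dependencies:
            if depname not in visited:
                inner(depname)

        return result

    # 'b.py' loads after 'a.py'; 'c.py' after 'b.py'; unknown dependencies are ignored
    print(topological_sort({
        "a.py": [],
        "b.py": ["a.py"],
        "c.py": ["b.py", "not_installed.py"],
    }))
    # ['a.py', 'b.py', 'c.py']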
+@dataclass
+class ScriptWithDependencies:
+ script_canonical_name: str
+ file: ScriptFile
+ requires: list
+ load_before: list
+ load_after: list
+
def list_scripts(scriptdirname, extension, *, include_extensions=True):
- scripts_list = []
+ scripts = {}
+
+ loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
+ loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}
- basedir = os.path.join(paths.script_path, scriptdirname)
- if os.path.exists(basedir):
- for filename in sorted(os.listdir(basedir)):
- scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
+ # build script dependency map
+ root_script_basedir = os.path.join(paths.script_path, scriptdirname)
+ if os.path.exists(root_script_basedir):
+ for filename in sorted(os.listdir(root_script_basedir)):
+ if not os.path.isfile(os.path.join(root_script_basedir, filename)):
+ continue
+
+ if os.path.splitext(filename)[1].lower() != extension:
+ continue
+
+ script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
+ scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])
if include_extensions:
for ext in extensions.active():
- scripts_list += ext.list_files(scriptdirname, extension)
-
- scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+ extension_scripts_list = ext.list_files(scriptdirname, extension)
+ for extension_script in extension_scripts_list:
+ if not os.path.isfile(extension_script.path):
+ continue
+
+ script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
+ relative_path = scriptdirname + "/" + extension_script.filename
+
+ script = ScriptWithDependencies(
+ script_canonical_name=script_canonical_name,
+ file=extension_script,
+ requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
+ load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
+ load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
+ )
+
+ scripts[script_canonical_name] = script
+ loaded_extensions_scripts[ext.canonical_name].append(script)
+
+ for script_canonical_name, script in scripts.items():
+ # "load before" is an inverse dependency:
+ # append this script's name to the load_after list of the script it must precede
+ for load_before in script.load_before:
+ # if this requires an individual script to be loaded before
+ other_script = scripts.get(load_before)
+ if other_script:
+ other_script.load_after.append(script_canonical_name)
+
+ # if this requires an extension
+ other_extension_scripts = loaded_extensions_scripts.get(load_before)
+ if other_extension_scripts:
+ for other_script in other_extension_scripts:
+ other_script.load_after.append(script_canonical_name)
+
+ # if After mentions an extension, remove it and instead add all of its scripts
+ for load_after in list(script.load_after):
+ if load_after not in scripts and load_after in loaded_extensions_scripts:
+ script.load_after.remove(load_after)
+
+ for other_script in loaded_extensions_scripts.get(load_after, []):
+ script.load_after.append(other_script.script_canonical_name)
+
+ dependencies = {}
+
+ for script_canonical_name, script in scripts.items():
+ for required_script in script.requires:
+ if required_script not in scripts and required_script not in loaded_extensions:
+ errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)
+
+ dependencies[script_canonical_name] = script.load_after
+
+ ordered_scripts = topological_sort(dependencies)
+ scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]
return scripts_list
@@ -365,15 +504,9 @@ def load_scripts():
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
- def orderby(basedir):
- # 1st webui, 2nd extensions-builtin, 3rd extensions
- priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
- for key in priority:
- if basedir.startswith(key):
- return priority[key]
- return 9999
-
- for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
+ # here the scripts_list is already ordered
+ # the ordering of postprocessing scripts is not considered here, though
+ for scriptfile in scripts_list:
try:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
@@ -433,7 +566,12 @@ class ScriptRunner:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
for script_data in auto_processing_scripts + scripts_data:
- script = script_data.script_class()
+ try:
+ script = script_data.script_class()
+ except Exception:
+ errors.report(f"Error initializing script {script_data.module}", exc_info=True)
+ continue
+
script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
@@ -473,17 +611,25 @@ class ScriptRunner:
on_after.clear()
def create_script_ui(self, script):
- import modules.api.models as api_models
script.args_from = len(self.inputs)
script.args_to = len(self.inputs)
+ try:
+ self.create_script_ui_inner(script)
+ except Exception:
+ errors.report(f"Error creating UI for {script.name}: ", exc_info=True)
+
+ def create_script_ui_inner(self, script):
+ import modules.api.models as api_models
+
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
if controls is None:
return
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+
api_args = []
for control in controls:
@@ -550,6 +696,8 @@ class ScriptRunner:
self.setup_ui_for_section(None, self.selectable_scripts)
def select_script(script_index):
+ if script_index is None:
+ script_index = 0
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
@@ -593,7 +741,7 @@ class ScriptRunner:
def run(self, p, *args):
script_index = args[0]
- if script_index == 0:
+ if script_index == 0 or script_index is None:
return None
script = self.selectable_scripts[script_index-1]
@@ -672,6 +820,22 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
+ def post_sample(self, p, ps: PostSampleArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.post_sample(p, ps, *script_args)
+ except Exception:
+ errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
+ def on_mask_blend(self, p, mba: MaskBlendArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.on_mask_blend(p, mba, *script_args)
+ except Exception:
+ errors.report(f"Error running on_mask_blend: {script.filename}", exc_info=True)
+
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
@@ -680,6 +844,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
+ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_maskoverlay(p, ppmo, *script_args)
+ except Exception:
+ errors.report(f"Error running postprocess_maskoverlay: {script.filename}", exc_info=True)
+
def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try:
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
index bac1335d..901cad08 100644
--- a/modules/scripts_postprocessing.py
+++ b/modules/scripts_postprocessing.py
@@ -1,13 +1,56 @@
+import dataclasses
import os
import gradio as gr
from modules import errors, shared
+@dataclasses.dataclass
+class PostprocessedImageSharedInfo:
+ target_width: int = None
+ target_height: int = None
+
+
class PostprocessedImage:
def __init__(self, image):
self.image = image
self.info = {}
+ self.shared = PostprocessedImageSharedInfo()
+ self.extra_images = []
+ self.nametags = []
+ self.disable_processing = False
+ self.caption = None
+
+ def get_suffix(self, used_suffixes=None):
+ used_suffixes = {} if used_suffixes is None else used_suffixes
+ suffix = "-".join(self.nametags)
+ if suffix:
+ suffix = "-" + suffix
+
+ if suffix not in used_suffixes:
+ used_suffixes[suffix] = 1
+ return suffix
+
+ for i in range(1, 100):
+ proposed_suffix = suffix + "-" + str(i)
+
+ if proposed_suffix not in used_suffixes:
+ used_suffixes[proposed_suffix] = 1
+ return proposed_suffix
+
+ return suffix
+
+ def create_copy(self, new_image, *, nametags=None, disable_processing=False):
+ pp = PostprocessedImage(new_image)
+ pp.shared = self.shared
+ pp.nametags = self.nametags.copy()
+ pp.info = self.info.copy()
+ pp.disable_processing = disable_processing
+
+ if nametags is not None:
+ pp.nametags += nametags
+
+ return pp
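A small standalone check of the suffix de-duplication that get_suffix performs, reimplemented as a free function so it runs without Pillow or the webui (the 'upscaled' nametag is made up):

    def get_suffix(nametags, used_suffixes):
        suffix = "-".join(nametags)
        if suffix:
            suffix = "-" + suffix
        if suffix not in used_suffixes:
            used_suffixes[suffix] = 1
            return suffix
        for i in range(1, 100):
            proposed = f"{suffix}-{i}"
            if proposed not in used_suffixes:
                used_suffixes[proposed] = 1
                return proposed
        return suffix

    used = {}
    print(get_suffix(["upscaled"], used))  # -upscaled
    print(get_suffix(["upscaled"], used))  # -upscaled-1
    print(get_suffix([], used))            # '' -- the main image gets no suffix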
class ScriptPostprocessing:
@@ -42,10 +85,17 @@ class ScriptPostprocessing:
pass
- def image_changed(self):
- pass
+ def process_firstpass(self, pp: PostprocessedImage, **args):
+ """
+ Called for all scripts before calling process(). Scripts can examine the image here and set fields
+ of the pp object to communicate things to other scripts.
+ args contains a dictionary with all values returned by components from ui()
+ """
+ pass
+ def image_changed(self):
+ pass
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
@@ -118,16 +168,42 @@ class ScriptPostprocessingRunner:
return inputs
def run(self, pp: PostprocessedImage, args):
- for script in self.scripts_in_preferred_order():
- shared.state.job = script.name
+ scripts = []
+ for script in self.scripts_in_preferred_order():
script_args = args[script.args_from:script.args_to]
process_args = {}
for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value
- script.process(pp, **process_args)
+ scripts.append((script, process_args))
+
+ for script, process_args in scripts:
+ script.process_firstpass(pp, **process_args)
+
+ all_images = [pp]
+
+ for script, process_args in scripts:
+ if shared.state.skipped:
+ break
+
+ shared.state.job = script.name
+
+ for single_image in all_images.copy():
+
+ if not single_image.disable_processing:
+ script.process(single_image, **process_args)
+
+ for extra_image in single_image.extra_images:
+ if not isinstance(extra_image, PostprocessedImage):
+ extra_image = single_image.create_copy(extra_image)
+
+ all_images.append(extra_image)
+
+ single_image.extra_images.clear()
+
+ pp.extra_images = all_images[1:]
def create_args_for_run(self, scripts_args):
if not self.ui_created:
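A hedged sketch of a postprocessing script that uses the new two-pass flow: process_firstpass publishes shared info, and process emits an extra image through create_copy so the runner saves it with its own suffix. The class, name and mirroring step are hypothetical, and it only runs inside the webui environment:

    from PIL import ImageOps
    from modules import scripts_postprocessing

    class ScriptPostprocessingMirror(scripts_postprocessing.ScriptPostprocessing):  # hypothetical example
        name = "Mirror"

        def ui(self):
            return {}

        def process_firstpass(self, pp, **args):
            # publish the input size so other scripts can read it from pp.shared
            pp.shared.target_width, pp.shared.target_height = pp.image.size

        def process(self, pp, **args):
            # emit a mirrored copy; the runner collects it from extra_images and
            # saves it with a "-mirrored" suffix, skipping further processing on it
            mirrored = pp.create_copy(ImageOps.mirror(pp.image), nametags=["mirrored"], disable_processing=True)
            pp.extra_images.append(mirrored)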
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 8863107a..273a7edd 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
would be on the meta device.
"""
- if state_dict == sd:
+ if state_dict is sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
original(module, state_dict, strict=strict)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index bc5fbcd3..e139d996 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -38,8 +38,12 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
-ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
-sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
+ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
+ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
+
+sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
+sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
+
def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()
@@ -184,6 +188,20 @@ class StableDiffusionModelHijack:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
+ def convert_sdxl_to_ssd(self, m):
+ """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
+
+ delattr(m.model.diffusion_model.middle_block, '1')
+ delattr(m.model.diffusion_model.middle_block, '2')
+ for i in ['9', '8', '7', '6', '5', '4']:
+ delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
+ delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
+ devices.torch_gc()
+
def hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
@@ -242,8 +260,12 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
+ import modules.models.diffusion.ddpm_edit
+
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
+ elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
+ sd_unet.original_forward = ldm_original_forward
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
sd_unet.original_forward = sgm_original_forward
else:
@@ -285,8 +307,6 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
- sd_unet.original_forward = None
-
def apply_circular(self, enable):
if self.circular_enabled == enable:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3b6cdea1..50bc209e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -230,15 +230,19 @@ def select_checkpoint():
return checkpoint_info
-checkpoint_dict_replacements = {
+checkpoint_dict_replacements_sd1 = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
+checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format.
+ 'conditioner.embedders.0.': 'cond_stage_model.',
+}
+
-def transform_checkpoint_dict_key(k):
- for text, replacement in checkpoint_dict_replacements.items():
+def transform_checkpoint_dict_key(k, replacements):
+ for text, replacement in replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
@@ -249,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd):
pl_sd = pl_sd.pop("state_dict", pl_sd)
pl_sd.pop("state_dict", None)
+ is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
+
sd = {}
for k, v in pl_sd.items():
- new_key = transform_checkpoint_dict_key(k)
+ if is_sd2_turbo:
+ new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
+ else:
+ new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
if new_key is not None:
sd[new_key] = v
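A tiny standalone illustration of the SD 2.1 Turbo key remapping added above (the replacement table and function body are taken from the diff; the sample key is representative):

    checkpoint_dict_replacements_sd2_turbo = {
        'conditioner.embedders.0.': 'cond_stage_model.',
    }

    def transform_checkpoint_dict_key(k, replacements):
        for text, replacement in replacements.items():
            if k.startswith(text):
                k = replacement + k[len(text):]
        return k

    key = 'conditioner.embedders.0.model.ln_final.weight'
    print(transform_checkpoint_dict_key(key, checkpoint_dict_replacements_sd2_turbo))
    # cond_stage_model.model.ln_final.weight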
@@ -339,10 +348,28 @@ class SkipWritingToConfig:
SkipWritingToConfig.skip = self.previous
+def check_fp8(model):
+ if model is None:
+ return None
+ if devices.get_optimal_device_name() == "mps":
+ enable_fp8 = False
+ elif shared.opts.fp8_storage == "Enable":
+ enable_fp8 = True
+ elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+ enable_fp8 = True
+ else:
+ enable_fp8 = False
+ return enable_fp8
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
+ if devices.fp8:
+ # prevent the model from loading the state dict in fp8
+ model.half()
+
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@@ -352,10 +379,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
-
+ model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
+ if model.is_ssd:
+ sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
+
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
@@ -371,6 +401,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.no_half:
model.float()
+ model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
@@ -384,7 +415,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None
+ alphas_cumprod = model.alphas_cumprod
+ model.alphas_cumprod = None
model.half()
+ model.alphas_cumprod = alphas_cumprod
+ model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
@@ -392,6 +427,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
+ for module in model.modules():
+ if hasattr(module, 'fp16_weight'):
+ del module.fp16_weight
+ if hasattr(module, 'fp16_bias'):
+ del module.fp16_bias
+
+ if check_fp8(model):
+ devices.fp8 = True
+ first_stage = model.first_stage_model
+ model.first_stage_model = None
+ for module in model.modules():
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+ if shared.opts.cache_fp16_weight:
+ module.fp16_weight = module.weight.data.clone().cpu().half()
+ if module.bias is not None:
+ module.fp16_bias = module.bias.data.clone().cpu().half()
+ module.to(torch.float8_e4m3fn)
+ model.first_stage_model = first_stage
+ timer.record("apply fp8")
+ else:
+ devices.fp8 = False
+
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
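A hedged sketch of the FP8 storage step in isolation: cast Linear/Conv2d weights to torch.float8_e4m3fn and optionally keep an fp16 copy on the CPU for later LoRA math. It requires a PyTorch build with float8 dtypes, and the toy module below stands in for the UNet:

    import torch

    def apply_fp8_storage(model, cache_fp16_weight=True):
        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                if cache_fp16_weight:
                    module.fp16_weight = module.weight.data.clone().cpu().half()
                    if module.bias is not None:
                        module.fp16_bias = module.bias.data.clone().cpu().half()
                module.to(torch.float8_e4m3fn)

    toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Conv2d(3, 3, 1))
    apply_fp8_storage(toy)
    print(toy[0].weight.dtype)  # torch.float8_e4m3fn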
@@ -639,6 +696,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
weight_dtype_conversion = {
'first_stage_model': None,
+ 'alphas_cumprod': None,
'': torch.float16,
}
@@ -734,7 +792,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
return None
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
checkpoint_info = info or select_checkpoint()
timer = Timer()
@@ -746,11 +804,14 @@ def reload_model_weights(sd_model=None, info=None):
current_checkpoint_info = None
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
- if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+ if check_fp8(sd_model) != devices.fp8:
+ # load from state dict again to prevent extra numerical errors
+ forced_reload = True
+ elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
return sd_model
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
- if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
return sd_model
if sd_model is not None:
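
The reload_model_weights change above forces a full reload whenever the model's fp8 state no longer matches the current setting, so quantization is not applied on top of already-rounded weights. A rough sketch of that decision with hypothetical helper names (condensed from the checks in the hunk, not an exact copy):

def needs_full_reload(loaded_filename: str, requested_filename: str,
                      model_uses_fp8: bool, fp8_enabled: bool,
                      forced_reload: bool = False) -> bool:
    if model_uses_fp8 != fp8_enabled:
        # fp8 setting changed: reload weights from the state dict to avoid
        # compounding rounding errors on already-quantized tensors.
        return True
    if forced_reload:
        return True
    return loaded_filename != requested_filename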
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index deab2f6e..b38137eb 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
+config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename):
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
- return config_sdxl
+ if diffusion_model_input.shape[1] == 9:
+ return config_sdxl_inpainting
+ else:
+ return config_sdxl
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
return config_sdxl_refiner
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
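
The config guess above distinguishes a regular SDXL checkpoint from an inpainting one by the input channel count of the UNet's first convolution: 9 channels (4 noisy latent + 4 masked-image latent + 1 mask) indicate inpainting. A small sketch of that check, assuming the usual LDM state-dict key for the first conv:

def looks_like_inpainting_model(sd: dict) -> bool:
    # First conv of the UNet; weight shape is (out_channels, in_channels, k, k).
    w = sd.get('model.diffusion_model.input_blocks.0.0.weight')
    return w is not None and w.shape[1] == 9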
diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py
index 5ffd2f4f..f911fbb6 100644
--- a/modules/sd_models_types.py
+++ b/modules/sd_models_types.py
@@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion):
"""structure with additional information about the file with model's weights"""
is_sdxl: bool
- """True if the model's architecture is SDXL"""
+ """True if the model's architecture is SDXL or SSD"""
+
+ is_ssd: bool
+ """True if the model is SSD"""
is_sd2: bool
"""True if the model's architecture is SD 2.x"""
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 01123321..0de17af3 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -6,6 +6,7 @@ import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
+from modules import torch_utils
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@@ -34,6 +35,12 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+ sd = self.model.state_dict()
+ diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+ if diffusion_model_input is not None:
+ if diffusion_model_input.shape[1] == 9:
+ x = torch.cat([x] + cond['c_concat'], dim=1)
+
return self.model(x, t, cond)
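
For such a 9-channel UNet, apply_model above concatenates the 4-channel noisy latent with the conditioning channels (masked-image latent plus mask) along the channel dimension before calling the model. An illustrative shape check with assumed tensor sizes:

import torch

x = torch.randn(1, 4, 64, 64)           # noisy latent
c_concat = [torch.randn(1, 5, 64, 64)]  # masked-image latent (4) + mask (1)
x_in = torch.cat([x] + c_concat, dim=1)
assert x_in.shape[1] == 9               # matches the inpainting UNet's input conv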
@@ -84,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
- dtype = next(model.model.diffusion_model.parameters()).dtype
+ dtype = torch_utils.get_param(model.model.diffusion_model).dtype
model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt'
@@ -93,7 +100,7 @@ def extend_sdxl(model):
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
- model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+ model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
model.conditioner.wrapped = torch.nn.Module()
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index b8101d38..eb9d5daf 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -56,6 +56,9 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler
self.model_wrap = None
self.p = None
+
+ # NOTE: masking before denoising can cause the original latents to be oversmoothed
+ # as the original latents do not have noise
self.mask_before_denoising = False
@property
@@ -105,8 +108,21 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ # If we use masks, blending between the denoised and original latent images occurs here.
+ def apply_blend(current_latent):
+ blended_latent = current_latent * self.nmask + self.init_latent * self.mask
+
+ if self.p.scripts is not None:
+ from modules import scripts
+ mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
+ self.p.scripts.on_mask_blend(self.p, mba)
+ blended_latent = mba.blended_latent
+
+ return blended_latent
+
+ # Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
- x = self.init_latent * self.mask + self.nmask * x
+ x = apply_blend(x)
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +223,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ # Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
+ denoised = apply_blend(denoised)
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
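
apply_blend above performs the usual latent inpainting blend: the preserved region comes from the original latents (weighted by mask) and the inpainted region from the denoised latents (weighted by nmask), with a script hook allowed to replace the result. A minimal sketch of the blend with assumed tensors:

import torch

init_latent = torch.randn(1, 4, 64, 64)  # latents of the original image
denoised = torch.randn(1, 4, 64, 64)     # latents coming out of the sampler step
mask = torch.zeros(1, 1, 64, 64)
mask[..., :32] = 1.0                     # keep the left half of the image
nmask = 1.0 - mask

blended = denoised * nmask + init_latent * mask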
diff --git a/modules/sd_samplers_extra.py b/modules/sd_samplers_extra.py
index 1b981ca8..72fd0aa5 100644
--- a/modules/sd_samplers_extra.py
+++ b/modules/sd_samplers_extra.py
@@ -60,7 +60,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
while restart_times > 0:
restart_times -= 1
- step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
+ step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))
last_sigma = None
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index b17a8f93..777dd8d0 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
self.inner_model = model
def predict_eps_from_z_and_v(self, x_t, t, v):
- return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+ return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
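
The rewritten predict_eps_from_z_and_v computes the square roots directly from alphas_cumprod instead of relying on the model's precomputed sqrt buffers, which may have been downcast. A sketch of the conversion eps = sqrt(alpha_bar_t) * v + sqrt(1 - alpha_bar_t) * x_t, with assumed tensors:

import torch

def eps_from_v(x_t: torch.Tensor, v: torch.Tensor,
               alphas_cumprod: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # eps = sqrt(alpha_bar_t) * v + sqrt(1 - alpha_bar_t) * x_t
    a = alphas_cumprod[t.to(torch.int), None, None, None]
    return torch.sqrt(a) * v + torch.sqrt(1 - a) * x_t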
@@ -80,6 +80,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.eta_default = 0.0
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
+ self.model_wrap = self.model_wrap_cfg.inner_model
def get_timesteps(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index a72daafd..930a64af 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
- alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
@@ -43,7 +43,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
- alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
extra_args = {} if extra_args is None else extra_args
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
index 6a7bc9e2..a771849c 100644
--- a/modules/sd_unet.py
+++ b/modules/sd_unet.py
@@ -5,8 +5,7 @@ from modules import script_callbacks, shared, devices
unet_options = []
current_unet_option = None
current_unet = None
-original_forward = None
-
+original_forward = None # not used, only left temporarily for compatibility
def list_unets():
new_unets = script_callbacks.list_unets_callback()
@@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module):
pass
-def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
- if current_unet is not None:
- return current_unet.forward(x, timesteps, context, *args, **kwargs)
+def create_unet_forward(original_forward):
+ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+ if current_unet is not None:
+ return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+ return original_forward(self, x, timesteps, context, *args, **kwargs)
- return original_forward(self, x, timesteps, context, *args, **kwargs)
+ return UNetModel_forward
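
create_unet_forward above replaces the old module-level original_forward global with a closure, so each patched forward keeps its own reference to the function it wraps. A generic sketch of the pattern with hypothetical names:

def create_patched_forward(original_forward, override=None):
    def patched_forward(self, *args, **kwargs):
        # Route through the override when one is active, otherwise fall back
        # to the captured original implementation.
        if override is not None:
            return override(*args, **kwargs)
        return original_forward(self, *args, **kwargs)
    return patched_forward

# usage (hypothetical): SomeModel.forward = create_patched_forward(SomeModel.forward)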
diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py
index 822db0a9..b6dc3145 100644
--- a/modules/shared_gradio_themes.py
+++ b/modules/shared_gradio_themes.py
@@ -65,3 +65,7 @@ def reload_gradio_theme(theme_name=None):
except Exception as e:
errors.display(e, "changing gradio theme")
shared.gradio_theme = gr.themes.Default(**default_theme_args)
+
+ # append additional values to gradio_theme
+ shared.gradio_theme.sd_webui_modal_lightbox_toolbar_opacity = shared.opts.sd_webui_modal_lightbox_toolbar_opacity
+ shared.gradio_theme.sd_webui_modal_lightbox_icon_opacity = shared.opts.sd_webui_modal_lightbox_icon_opacity
diff --git a/modules/shared_items.py b/modules/shared_items.py
index b1459f8c..e1392472 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -66,7 +66,25 @@ def reload_hypernetworks():
shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+def get_infotext_names():
+ from modules import infotext, shared
+ res = {}
+
+ for info in shared.opts.data_labels.values():
+ if info.infotext:
+ res[info.infotext] = 1
+
+ for tab_data in infotext.paste_fields.values():
+ for _, name in tab_data.get("fields") or []:
+ if isinstance(name, str):
+ res[name] = 1
+
+ return list(res)
+
+
ui_reorder_categories_builtin_items = [
+ "prompt",
+ "image",
"inpaint",
"sampler",
"accordions",
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 0a82216f..cca3f7be 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -1,9 +1,10 @@
+import os
import gradio as gr
-from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
+from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401
from modules.shared_cmd_options import cmd_opts
-from modules.options import options_section, OptionInfo, OptionHTML
+from modules.options import options_section, OptionInfo, OptionHTML, categories
options_templates = {}
hide_dirs = shared.hide_dirs
@@ -21,7 +22,14 @@ restricted_opts = {
"outdir_init_images"
}
-options_templates.update(options_section(('saving-images', "Saving images/grids"), {
+categories.register_category("saving", "Saving images")
+categories.register_category("sd", "Stable Diffusion")
+categories.register_category("ui", "User Interface")
+categories.register_category("system", "System")
+categories.register_category("postprocessing", "Postprocessing")
+categories.register_category("training", "Training")
+
+options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), {
"samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'),
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
@@ -39,8 +47,6 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
- "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
- "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
@@ -64,21 +70,22 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
"notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(),
+ "notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"),
}))
-options_templates.update(options_section(('saving-paths', "Paths for saving"), {
+options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), {
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
- "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
- "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
- "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
+ "outdir_txt2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs),
+ "outdir_img2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs),
+ "outdir_extras_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs),
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
- "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
- "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
- "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
- "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
+ "outdir_txt2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs),
+ "outdir_img2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs),
+ "outdir_save": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), "Directory for saving images using the Save button", component_args=hide_dirs),
+ "outdir_init_images": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), "Directory for saving init images when using img2img", component_args=hide_dirs),
}))
-options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
+options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), {
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
@@ -86,21 +93,21 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))
-options_templates.update(options_section(('upscaling', "Upscaling"), {
+options_templates.update(options_section(('upscaling', "Upscaling", "postprocessing"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
}))
-options_templates.update(options_section(('face-restoration', "Face restoration"), {
+options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), {
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
}))
-options_templates.update(options_section(('system', "System"), {
+options_templates.update(options_section(('system', "System", "system"), {
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
"enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."),
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
@@ -115,13 +122,13 @@ options_templates.update(options_section(('system', "System"), {
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
}))
-options_templates.update(options_section(('API', "API"), {
+options_templates.update(options_section(('API', "API", "system"), {
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
}))
-options_templates.update(options_section(('training', "Training"), {
+options_templates.update(options_section(('training', "Training", "training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
@@ -136,7 +143,7 @@ options_templates.update(options_section(('training', "Training"), {
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
}))
-options_templates.update(options_section(('sd', "Stable Diffusion"), {
+options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
@@ -153,14 +160,14 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
}))
-options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
+options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), {
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
}))
-options_templates.update(options_section(('vae', "VAE"), {
+options_templates.update(options_section(('vae', "VAE", "sd"), {
"sd_vae_explanation": OptionHTML("""
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
@@ -170,12 +177,13 @@ For img2img, VAE is used to process user's input image before the sampling, and
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
+ "auto_vae_precision_bfloat16": OptionInfo(False, "Automatically convert VAE to bfloat16").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below"),
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
}))
-options_templates.update(options_section(('img2img', "img2img"), {
+options_templates.update(options_section(('img2img', "img2img", "sd"), {
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
@@ -188,9 +196,10 @@ options_templates.update(options_section(('img2img', "img2img"), {
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
+ "img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
}))
-options_templates.update(options_section(('optimizations', "Optimizations"), {
+options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
@@ -199,9 +208,12 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+ "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
+ "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
}))
-options_templates.update(options_section(('compatibility', "Compatibility"), {
+options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
+ "auto_backcompat": OptionInfo(True, "Automatic backward compatibility").info("automatically enable options for backwards compatibility when importing generation parameters from infotext that has program version."),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
@@ -209,6 +221,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
+ "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod")
}))
options_templates.update(options_section(('interrogate', "Interrogate"), {
@@ -226,14 +239,17 @@ options_templates.update(options_section(('interrogate', "Interrogate"), {
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
}))
-options_templates.update(options_section(('extra_networks', "Extra Networks"), {
+options_templates.update(options_section(('extra_networks', "Extra Networks", "sd"), {
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
+ "extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."),
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
+ "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
+ "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
@@ -241,44 +257,69 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
}))
-options_templates.update(options_section(('ui', "User interface"), {
- "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
- "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
- "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
- "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(),
- "return_grid": OptionInfo(True, "Show grid in results for web"),
- "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
- "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
- "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
- "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
- "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
- "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
- "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
- "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
- "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
- "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
- "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
- "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
- "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"),
+options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "ui"), {
+ "keyedit_precision_attention": OptionInfo(0.1, "Precision for (attention:1.1) when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
- "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
- "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
- "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
- "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
+ "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
+}))
+
+options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {
+ "return_grid": OptionInfo(True, "Show grid in gallery"),
+ "do_not_show_images": OptionInfo(False, "Do not show any images in gallery"),
+ "js_modal_lightbox": OptionInfo(True, "Full page image viewer: enable"),
+ "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"),
+ "js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"),
+ "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"),
+ "sd_webui_modal_lightbox_icon_opacity": OptionInfo(1, "Full page image viewer: control icon unfocused opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
+ "sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
+ "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(),
+}))
+
+options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), {
+ "compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(),
+ "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
+ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
"sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
- "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
+ "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
+ "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
+ "interrupt_after_current": OptionInfo(True, "Don't Interrupt in the middle").info("when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately"),
+}))
+
+options_templates.update(options_section(('ui', "User interface", "ui"), {
+ "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
+ "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
+ "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
+ "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
+ "ui_reorder_list": OptionInfo([], "UI item order for txt2img/img2img tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
+ "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
+ "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
+ "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
+ "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
+ "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
}))
-options_templates.update(options_section(('infotext', "Infotext"), {
- "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
- "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
- "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
- "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
+options_templates.update(options_section(('infotext', "Infotext", "ui"), {
+ "infotext_explanation": OptionHTML("""
+Infotext is what this software calls the text that contains generation parameters and can be used to generate the same picture again.
+It is displayed in the UI below the image. To use infotext, paste it into the prompt and click the ↙️ paste button.
+"""),
+ "enable_pnginfo": OptionInfo(True, "Write infotext to metadata of the generated image"),
+ "save_txt": OptionInfo(False, "Create a text file with infotext next to every generated image"),
+
+ "add_model_name_to_info": OptionInfo(True, "Add model name to infotext"),
+ "add_model_hash_to_info": OptionInfo(True, "Add model hash to infotext"),
+ "add_vae_name_to_info": OptionInfo(True, "Add VAE name to infotext"),
+ "add_vae_hash_to_info": OptionInfo(True, "Add VAE hash to infotext"),
+ "add_user_name_to_info": OptionInfo(False, "Add user name to infotext when authenticated"),
+ "add_version_to_infotext": OptionInfo(True, "Add program version to infotext"),
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
+ "infotext_skip_pasting": OptionInfo([], "Disregard fields from pasted infotext", ui_components.DropdownMulti, lambda: {"choices": shared_items.get_infotext_names()}),
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
<li>Ignore: keep prompt and styles dropdown as it is.</li>
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
@@ -288,7 +329,7 @@ options_templates.update(options_section(('infotext', "Infotext"), {
}))
-options_templates.update(options_section(('ui', "Live previews"), {
+options_templates.update(options_section(('ui', "Live previews", "ui"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
@@ -299,9 +340,10 @@ options_templates.update(options_section(('ui', "Live previews"), {
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
+ "js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"),
}))
-options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
+options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
@@ -321,12 +363,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
+ 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models")
}))
-options_templates.update(options_section(('postprocessing', "Postprocessing"), {
+options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ 'postprocessing_existing_caption_action': OptionInfo("Ignore", "Action for existing captions", gr.Radio, {"choices": ["Ignore", "Keep", "Prepend", "Append"]}).info("when generating captions using postprocessing; Ignore = use generated; Keep = use original; Prepend/Append = combine both"),
}))
options_templates.update(options_section((None, "Hidden options"), {
diff --git a/modules/shared_state.py b/modules/shared_state.py
index a68789cc..33996691 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -12,6 +12,7 @@ log = logging.getLogger(__name__)
class State:
skipped = False
interrupted = False
+ stopping_generation = False
job = ""
job_no = 0
job_count = 0
@@ -79,6 +80,10 @@ class State:
self.interrupted = True
log.info("Received interrupt request")
+ def stop_generating(self):
+ self.stopping_generation = True
+ log.info("Received stop generating request")
+
def nextjob(self):
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
@@ -91,6 +96,7 @@ class State:
obj = {
"skipped": self.skipped,
"interrupted": self.interrupted,
+ "stopping_generation": self.stopping_generation,
"job": self.job,
"job_count": self.job_count,
"job_timestamp": self.job_timestamp,
@@ -114,6 +120,7 @@ class State:
self.id_live_preview = 0
self.skipped = False
self.interrupted = False
+ self.stopping_generation = False
self.textinfo = None
self.job = job
devices.torch_gc()
diff --git a/modules/styles.py b/modules/styles.py
index 0740fe1b..026c4300 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,7 +1,7 @@
import csv
+import fnmatch
import os
import os.path
-import re
import typing
import shutil
@@ -10,6 +10,7 @@ class PromptStyle(typing.NamedTuple):
name: str
prompt: str
negative_prompt: str
+ path: str = None
def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -29,12 +30,17 @@ def apply_styles_to_prompt(prompt, styles):
return prompt
-re_spaces = re.compile(" +")
+def extract_style_text_from_prompt(style_text, prompt):
+ """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.
+ extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
+ extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
+ extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
+ """
+
+ stripped_prompt = prompt.strip()
+ stripped_style_text = style_text.strip()
-def extract_style_text_from_prompt(style_text, prompt):
- stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
- stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
if "{prompt}" in stripped_style_text:
left, right = stripped_style_text.split("{prompt}", 2)
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
@@ -52,7 +58,12 @@ def extract_style_text_from_prompt(style_text, prompt):
return False, prompt
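
A compact sketch of the {prompt}-placeholder matching that the docstring above describes (simplified: a single placeholder, no whitespace normalization, hypothetical helper name):

def strip_style(style_text: str, prompt: str):
    prompt, style_text = prompt.strip(), style_text.strip()
    if "{prompt}" in style_text:
        left, right = style_text.split("{prompt}", 1)
        if prompt.startswith(left) and prompt.endswith(right):
            # Return whatever sat in place of the placeholder.
            return True, prompt[len(left):len(prompt) - len(right)]
        return False, prompt
    if prompt.endswith(style_text):
        # Style text appended at the end of the prompt.
        return True, prompt[:-len(style_text)].rstrip(" ,")
    return False, prompt

print(strip_style("masterpiece, {prompt}", "masterpiece, 1girl, art by greg"))
# -> (True, '1girl, art by greg')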
-def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
+def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
+ """
+ Takes a style and compares it to the prompt and negative prompt. If the style
+ matches, returns True plus the prompt and negative prompt with the style text
+ removed. Otherwise, returns False with the original prompt and negative prompt.
+ """
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt
@@ -69,25 +80,84 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
class StyleDatabase:
def __init__(self, path: str):
- self.no_style = PromptStyle("None", "", "")
+ self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
self.path = path
+ folder, file = os.path.split(self.path)
+ filename, _, ext = file.partition('*')
+ self.default_path = os.path.join(folder, filename + ext)
+
+ self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
+
self.reload()
def reload(self):
+ """
+ Clears the style database and reloads the styles from the CSV file(s)
+ matching the path used to initialize the database.
+ """
self.styles.clear()
- if not os.path.exists(self.path):
+ path, filename = os.path.split(self.path)
+
+ if "*" in filename:
+ fileglob = filename.split("*")[0] + "*.csv"
+ filelist = []
+ for file in os.listdir(path):
+ if fnmatch.fnmatch(file, fileglob):
+ filelist.append(file)
+ # Add a visible divider to the style list
+ half_len = round(len(file) / 2)
+ divider = f"{'-' * (20 - half_len)} {file.upper()}"
+ divider = f"{divider} {'-' * (40 - len(divider))}"
+ self.styles[divider] = PromptStyle(
+ f"{divider}", None, None, "do_not_save"
+ )
+ # Add styles from this CSV file
+ self.load_from_csv(os.path.join(path, file))
+ if len(filelist) == 0:
+ print(f"No styles found in {path} matching {fileglob}")
+ return
+ elif not os.path.exists(self.path):
+ print(f"Style database not found: {self.path}")
return
+ else:
+ self.load_from_csv(self.path)
- with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
+ def load_from_csv(self, path: str):
+ with open(path, "r", encoding="utf-8-sig", newline="") as file:
reader = csv.DictReader(file, skipinitialspace=True)
for row in reader:
+ # Ignore empty rows or rows starting with a comment
+ if not row or row["name"].startswith("#"):
+ continue
# Support loading old CSV format with "name, text"-columns
prompt = row["prompt"] if "prompt" in row else row["text"]
negative_prompt = row.get("negative_prompt", "")
- self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
+ # Add style to database
+ self.styles[row["name"]] = PromptStyle(
+ row["name"], prompt, negative_prompt, path
+ )
+
+ def get_style_paths(self) -> set:
+ """Returns a set of all distinct paths of files that styles are loaded from."""
+ # Update any styles without a path to the default path
+ for style in list(self.styles.values()):
+ if not style.path:
+ self.styles[style.name] = style._replace(path=self.default_path)
+
+ # Create a list of all distinct paths, including the default path
+ style_paths = set()
+ style_paths.add(self.default_path)
+ for _, style in self.styles.items():
+ if style.path:
+ style_paths.add(style.path)
+
+ # Remove any paths for styles that are just list dividers
+ style_paths.discard("do_not_save")
+
+ return style_paths
def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles]
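
The reload logic above now accepts a wildcard in the configured styles path, loading every matching CSV next to the default file and inserting a named divider entry between files. A rough sketch of just the file matching, with assumed filenames:

import fnmatch
import os

def matching_style_files(path_pattern: str) -> list[str]:
    folder, filename = os.path.split(path_pattern)
    if "*" not in filename:
        return [path_pattern]
    # Everything before the "*" is the common prefix; only .csv files match.
    fileglob = filename.split("*")[0] + "*.csv"
    return [os.path.join(folder, name)
            for name in sorted(os.listdir(folder or "."))
            if fnmatch.fnmatch(name, fileglob)]

# e.g. matching_style_files("styles*.csv") might return
# ["styles.csv", "styles-characters.csv", "styles-lighting.csv"]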
@@ -96,20 +166,40 @@ class StyleDatabase:
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
def apply_styles_to_prompt(self, prompt, styles):
- return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
+ return apply_styles_to_prompt(
+ prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
+ )
def apply_negative_styles_to_prompt(self, prompt, styles):
- return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
-
- def save_styles(self, path: str) -> None:
- # Always keep a backup file around
- if os.path.exists(path):
- shutil.copy(path, f"{path}.bak")
-
- with open(path, "w", encoding="utf-8-sig", newline='') as file:
- writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
- writer.writeheader()
- writer.writerows(style._asdict() for k, style in self.styles.items())
+ return apply_styles_to_prompt(
+ prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
+ )
+
+ def save_styles(self, path: str = None) -> None:
+ # The path argument is deprecated, but kept for backwards compatibility
+ _ = path
+
+ style_paths = self.get_style_paths()
+
+ csv_names = [os.path.split(path)[1].lower() for path in style_paths]
+
+ for style_path in style_paths:
+ # Always keep a backup file around
+ if os.path.exists(style_path):
+ shutil.copy(style_path, f"{style_path}.bak")
+
+ # Write the styles to the CSV file
+ with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
+ writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
+ writer.writeheader()
+ for style in (s for s in self.styles.values() if s.path == style_path):
+ # Skip style list dividers, e.g. "STYLES.CSV"
+ if style.name.lower().strip("# ") in csv_names:
+ continue
+ # Write style fields, ignoring the path field
+ writer.writerow(
+ {k: v for k, v in style._asdict().items() if k != "path"}
+ )
def extract_styles_from_prompt(self, prompt, negative_prompt):
extracted = []
@@ -120,7 +210,9 @@ class StyleDatabase:
found_style = None
for style in applicable_styles:
- is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
+ is_match, new_prompt, new_neg_prompt = extract_original_prompts(
+ style, prompt, negative_prompt
+ )
if is_match:
found_style = style
prompt = new_prompt
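With the styles.py changes above, styles can be spread across multiple CSV files matched by a wildcard, and each PromptStyle remembers the file it was loaded from so save_styles() writes it back to that same file (with a .bak backup). A minimal usage sketch, assuming the database is constructed with a path that may contain a wildcard such as "styles*.csv" (the constructor details are outside this excerpt):

from modules.styles import StyleDatabase

db = StyleDatabase("styles*.csv")                 # loads styles.csv plus any other styles*.csv files found next to it
prompt = db.apply_styles_to_prompt("a cat", ["cinematic"])
db.save_styles()                                  # the path argument is deprecated; each source CSV is rewritten in place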
diff --git a/modules/sysinfo.py b/modules/sysinfo.py
index 2db7551d..5abf616b 100644
--- a/modules/sysinfo.py
+++ b/modules/sysinfo.py
@@ -1,7 +1,6 @@
import json
import os
import sys
-import traceback
import platform
import hashlib
@@ -27,11 +26,9 @@ environment_whitelist = {
"OPENCLIP_PACKAGE",
"STABLE_DIFFUSION_REPO",
"K_DIFFUSION_REPO",
- "CODEFORMER_REPO",
"BLIP_REPO",
"STABLE_DIFFUSION_COMMIT_HASH",
"K_DIFFUSION_COMMIT_HASH",
- "CODEFORMER_COMMIT_HASH",
"BLIP_COMMIT_HASH",
"COMMANDLINE_ARGS",
"IGNORE_CMD_ARGS_ERRORS",
@@ -84,7 +81,7 @@ def get_dict():
"Checksum": checksum_token,
"Commandline": get_argv(),
"Torch env info": get_torch_sysinfo(),
- "Exceptions": get_exceptions(),
+ "Exceptions": errors.get_exceptions(),
"CPU": {
"model": platform.processor(),
"count logical": psutil.cpu_count(logical=True),
@@ -104,21 +101,6 @@ def get_dict():
return res
-def format_traceback(tb):
- return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
-
-
-def format_exception(e, tb):
- return {"exception": str(e), "traceback": format_traceback(tb)}
-
-
-def get_exceptions():
- try:
- return list(reversed(errors.exception_records))
- except Exception as e:
- return str(e)
-
-
def get_environment():
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
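The traceback helpers deleted from sysinfo.py are now expected to live in modules/errors, which already owns exception_records; the report above calls errors.get_exceptions() instead. A hedged sketch of the shape that call site assumes (the real definition in modules/errors.py is outside this diff):

# hypothetical modules.errors.get_exceptions(); mirrors the code removed above
def get_exceptions():
    try:
        return list(reversed(exception_records))  # exception_records is the module-level list of recorded exceptions
    except Exception as e:
        return str(e)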
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index 1675e39a..e223a2e0 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -3,6 +3,8 @@ import requests
import os
import numpy as np
from PIL import ImageDraw
+from modules import paths_internal
+from pkg_resources import parse_version
GREEN = "#0F0"
BLUE = "#00F"
@@ -25,7 +27,6 @@ def crop_image(im, settings):
elif is_portrait(settings.crop_width, settings.crop_height):
scale_by = settings.crop_height / im.height
-
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
im_debug = im.copy()
@@ -69,6 +70,7 @@ def crop_image(im, settings):
return results
+
def focal_point(im, settings):
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
@@ -78,118 +80,120 @@ def focal_point(im, settings):
weight_pref_total = 0
if corner_points:
- weight_pref_total += settings.corner_points_weight
+ weight_pref_total += settings.corner_points_weight
if entropy_points:
- weight_pref_total += settings.entropy_points_weight
+ weight_pref_total += settings.entropy_points_weight
if face_points:
- weight_pref_total += settings.face_points_weight
+ weight_pref_total += settings.face_points_weight
corner_centroid = None
if corner_points:
- corner_centroid = centroid(corner_points)
- corner_centroid.weight = settings.corner_points_weight / weight_pref_total
- pois.append(corner_centroid)
+ corner_centroid = centroid(corner_points)
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
+ pois.append(corner_centroid)
entropy_centroid = None
if entropy_points:
- entropy_centroid = centroid(entropy_points)
- entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
- pois.append(entropy_centroid)
+ entropy_centroid = centroid(entropy_points)
+ entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
+ pois.append(entropy_centroid)
face_centroid = None
if face_points:
- face_centroid = centroid(face_points)
- face_centroid.weight = settings.face_points_weight / weight_pref_total
- pois.append(face_centroid)
+ face_centroid = centroid(face_points)
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
+ pois.append(face_centroid)
average_point = poi_average(pois, settings)
if settings.annotate_image:
- d = ImageDraw.Draw(im)
- max_size = min(im.width, im.height) * 0.07
- if corner_centroid is not None:
- color = BLUE
- box = corner_centroid.bounding(max_size * corner_centroid.weight)
- d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
- d.ellipse(box, outline=color)
- if len(corner_points) > 1:
- for f in corner_points:
- d.rectangle(f.bounding(4), outline=color)
- if entropy_centroid is not None:
- color = "#ff0"
- box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
- d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
- d.ellipse(box, outline=color)
- if len(entropy_points) > 1:
- for f in entropy_points:
- d.rectangle(f.bounding(4), outline=color)
- if face_centroid is not None:
- color = RED
- box = face_centroid.bounding(max_size * face_centroid.weight)
- d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
- d.ellipse(box, outline=color)
- if len(face_points) > 1:
- for f in face_points:
- d.rectangle(f.bounding(4), outline=color)
-
- d.ellipse(average_point.bounding(max_size), outline=GREEN)
+ d = ImageDraw.Draw(im)
+ max_size = min(im.width, im.height) * 0.07
+ if corner_centroid is not None:
+ color = BLUE
+ box = corner_centroid.bounding(max_size * corner_centroid.weight)
+ d.text((box[0], box[1] - 15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
+ d.ellipse(box, outline=color)
+ if len(corner_points) > 1:
+ for f in corner_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if entropy_centroid is not None:
+ color = "#ff0"
+ box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
+ d.text((box[0], box[1] - 15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
+ d.ellipse(box, outline=color)
+ if len(entropy_points) > 1:
+ for f in entropy_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if face_centroid is not None:
+ color = RED
+ box = face_centroid.bounding(max_size * face_centroid.weight)
+ d.text((box[0], box[1] - 15), f"Face: {face_centroid.weight:.02f}", fill=color)
+ d.ellipse(box, outline=color)
+ if len(face_points) > 1:
+ for f in face_points:
+ d.rectangle(f.bounding(4), outline=color)
+
+ d.ellipse(average_point.bounding(max_size), outline=GREEN)
return average_point
def image_face_points(im, settings):
if settings.dnn_model_path is not None:
- detector = cv2.FaceDetectorYN.create(
- settings.dnn_model_path,
- "",
- (im.width, im.height),
- 0.9, # score threshold
- 0.3, # nms threshold
- 5000 # keep top k before nms
- )
- faces = detector.detect(np.array(im))
- results = []
- if faces[1] is not None:
- for face in faces[1]:
- x = face[0]
- y = face[1]
- w = face[2]
- h = face[3]
- results.append(
- PointOfInterest(
- int(x + (w * 0.5)), # face focus left/right is center
- int(y + (h * 0.33)), # face focus up/down is close to the top of the head
- size = w,
- weight = 1/len(faces[1])
- )
- )
- return results
+ detector = cv2.FaceDetectorYN.create(
+ settings.dnn_model_path,
+ "",
+ (im.width, im.height),
+ 0.9, # score threshold
+ 0.3, # nms threshold
+ 5000 # keep top k before nms
+ )
+ faces = detector.detect(np.array(im))
+ results = []
+ if faces[1] is not None:
+ for face in faces[1]:
+ x = face[0]
+ y = face[1]
+ w = face[2]
+ h = face[3]
+ results.append(
+ PointOfInterest(
+ int(x + (w * 0.5)), # face focus left/right is center
+ int(y + (h * 0.33)), # face focus up/down is close to the top of the head
+ size=w,
+ weight=1 / len(faces[1])
+ )
+ )
+ return results
else:
- np_im = np.array(im)
- gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
-
- tries = [
- [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
- ]
- for t in tries:
- classifier = cv2.CascadeClassifier(t[0])
- minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
- try:
- faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
- minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- except Exception:
- continue
-
- if faces:
- rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
- return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+
+ tries = [
+ [f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01],
+ [f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05],
+ [f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05],
+ [f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05],
+ [f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05],
+ [f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05],
+ [f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05],
+ [f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05]
+ ]
+ for t in tries:
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize),
+ flags=cv2.CASCADE_SCALE_IMAGE)
+ except Exception:
+ continue
+
+ if faces:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]),
+ weight=1 / len(rects)) for r in rects]
return []
@@ -198,7 +202,7 @@ def image_corner_points(im, settings):
# naive attempt at preventing focal points from collecting at watermarks near the bottom
gd = ImageDraw.Draw(grayscale)
- gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+ gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999")
np_im = np.array(grayscale)
@@ -206,7 +210,7 @@ def image_corner_points(im, settings):
np_im,
maxCorners=100,
qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.06,
+ minDistance=min(grayscale.width, grayscale.height) * 0.06,
useHarrisDetector=False,
)
@@ -215,8 +219,8 @@ def image_corner_points(im, settings):
focal_points = []
for point in points:
- x, y = point.ravel()
- focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points)))
return focal_points
@@ -225,13 +229,13 @@ def image_entropy_points(im, settings):
landscape = im.height < im.width
portrait = im.height > im.width
if landscape:
- move_idx = [0, 2]
- move_max = im.size[0]
+ move_idx = [0, 2]
+ move_max = im.size[0]
elif portrait:
- move_idx = [1, 3]
- move_max = im.size[1]
+ move_idx = [1, 3]
+ move_max = im.size[1]
else:
- return []
+ return []
e_max = 0
crop_current = [0, 0, settings.crop_width, settings.crop_height]
@@ -241,14 +245,14 @@ def image_entropy_points(im, settings):
e = image_entropy(crop)
if (e > e_max):
- e_max = e
- crop_best = list(crop_current)
+ e_max = e
+ crop_best = list(crop_current)
crop_current[move_idx[0]] += 4
crop_current[move_idx[1]] += 4
- x_mid = int(crop_best[0] + settings.crop_width/2)
- y_mid = int(crop_best[1] + settings.crop_height/2)
+ x_mid = int(crop_best[0] + settings.crop_width / 2)
+ y_mid = int(crop_best[1] + settings.crop_height / 2)
return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
@@ -294,22 +298,23 @@ def is_square(w, h):
return w == h
-def download_and_cache_models(dirname):
- download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
- model_file_name = 'face_detection_yunet.onnx'
+model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv')
+if parse_version(cv2.__version__) >= parse_version('4.8'):
+ model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx')
+ model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true'
+else:
+ model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx')
+ model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
- os.makedirs(dirname, exist_ok=True)
- cache_file = os.path.join(dirname, model_file_name)
- if not os.path.exists(cache_file):
- print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
- response = requests.get(download_url)
- with open(cache_file, "wb") as f:
+def download_and_cache_models():
+ if not os.path.exists(model_file_path):
+ os.makedirs(model_dir_opencv, exist_ok=True)
+ print(f"downloading face detection model from '{model_url}' to '{model_file_path}'")
+ response = requests.get(model_url)
+ with open(model_file_path, "wb") as f:
f.write(response.content)
-
- if os.path.exists(cache_file):
- return cache_file
- return None
+ return model_file_path
class PointOfInterest:
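download_and_cache_models() now takes no arguments: it picks a YuNet model matching the installed OpenCV version (2023mar for OpenCV >= 4.8, 2022mar otherwise), caches it under models/opencv, and returns the file path. A usage sketch, with Settings field names taken from the removed preprocess.py below and an already-loaded PIL image assumed:

from modules.textual_inversion import autocrop

dnn_model_path = None
try:
    dnn_model_path = autocrop.download_and_cache_models()   # downloads on first call, then reuses the cached .onnx
except Exception as e:
    print("Unable to load face detection model, falling back to haar cascades.", e)

settings = autocrop.Settings(
    crop_width=512,
    crop_height=512,
    face_points_weight=0.9,
    entropy_points_weight=0.15,
    corner_points_weight=0.5,
    annotate_image=False,
    dnn_model_path=dnn_model_path,
)
for cropped in autocrop.crop_image(pil_image, settings):     # `pil_image` assumed loaded elsewhere
    cropped.save("cropped.png")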
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
deleted file mode 100644
index dbd856bd..00000000
--- a/modules/textual_inversion/preprocess.py
+++ /dev/null
@@ -1,232 +0,0 @@
-import os
-from PIL import Image, ImageOps
-import math
-import tqdm
-
-from modules import paths, shared, images, deepbooru
-from modules.textual_inversion import autocrop
-
-
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
- try:
- if process_caption:
- shared.interrogator.load()
-
- if process_caption_deepbooru:
- deepbooru.model.start()
-
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
-
- finally:
-
- if process_caption:
- shared.interrogator.send_blip_to_ram()
-
- if process_caption_deepbooru:
- deepbooru.model.stop()
-
-
-def listfiles(dirname):
- return os.listdir(dirname)
-
-
-class PreprocessParams:
- src = None
- dstdir = None
- subindex = 0
- flip = False
- process_caption = False
- process_caption_deepbooru = False
- preprocess_txt_action = None
-
-
-def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
- caption = ""
-
- if params.process_caption:
- caption += shared.interrogator.generate_caption(image)
-
- if params.process_caption_deepbooru:
- if caption:
- caption += ", "
- caption += deepbooru.model.tag_multi(image)
-
- filename_part = params.src
- filename_part = os.path.splitext(filename_part)[0]
- filename_part = os.path.basename(filename_part)
-
- basename = f"{index:05}-{params.subindex}-{filename_part}"
- image.save(os.path.join(params.dstdir, f"{basename}.png"))
-
- if params.preprocess_txt_action == 'prepend' and existing_caption:
- caption = f"{existing_caption} {caption}"
- elif params.preprocess_txt_action == 'append' and existing_caption:
- caption = f"{caption} {existing_caption}"
- elif params.preprocess_txt_action == 'copy' and existing_caption:
- caption = existing_caption
-
- caption = caption.strip()
-
- if caption:
- with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
- file.write(caption)
-
- params.subindex += 1
-
-
-def save_pic(image, index, params, existing_caption=None):
- save_pic_with_caption(image, index, params, existing_caption=existing_caption)
-
- if params.flip:
- save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
-
-
-def split_pic(image, inverse_xy, width, height, overlap_ratio):
- if inverse_xy:
- from_w, from_h = image.height, image.width
- to_w, to_h = height, width
- else:
- from_w, from_h = image.width, image.height
- to_w, to_h = width, height
- h = from_h * to_w // from_w
- if inverse_xy:
- image = image.resize((h, to_w))
- else:
- image = image.resize((to_w, h))
-
- split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
- y_step = (h - to_h) / (split_count - 1)
- for i in range(split_count):
- y = int(y_step * i)
- if inverse_xy:
- splitted = image.crop((y, 0, y + to_h, to_w))
- else:
- splitted = image.crop((0, y, to_w, y + to_h))
- yield splitted
-
-# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
-def center_crop(image: Image, w: int, h: int):
- iw, ih = image.size
- if ih / h < iw / w:
- sw = w * ih / h
- box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
- else:
- sh = h * iw / w
- box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
- return image.resize((w, h), Image.Resampling.LANCZOS, box)
-
-
-def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
- iw, ih = image.size
- err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
- wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
- if minarea <= w * h <= maxarea and err(w, h) <= threshold),
- key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
- default=None
- )
- return wh and center_crop(image, *wh)
-
-
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
- width = process_width
- height = process_height
- src = os.path.abspath(process_src)
- dst = os.path.abspath(process_dst)
- split_threshold = max(0.0, min(1.0, split_threshold))
- overlap_ratio = max(0.0, min(0.9, overlap_ratio))
-
- assert src != dst, 'same directory specified as source and destination'
-
- os.makedirs(dst, exist_ok=True)
-
- files = listfiles(src)
-
- shared.state.job = "preprocess"
- shared.state.textinfo = "Preprocessing..."
- shared.state.job_count = len(files)
-
- params = PreprocessParams()
- params.dstdir = dst
- params.flip = process_flip
- params.process_caption = process_caption
- params.process_caption_deepbooru = process_caption_deepbooru
- params.preprocess_txt_action = preprocess_txt_action
-
- pbar = tqdm.tqdm(files)
- for index, imagefile in enumerate(pbar):
- params.subindex = 0
- filename = os.path.join(src, imagefile)
- try:
- img = Image.open(filename)
- img = ImageOps.exif_transpose(img)
- img = img.convert("RGB")
- except Exception:
- continue
-
- description = f"Preprocessing [Image {index}/{len(files)}]"
- pbar.set_description(description)
- shared.state.textinfo = description
-
- params.src = filename
-
- existing_caption = None
- existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
- if os.path.exists(existing_caption_filename):
- with open(existing_caption_filename, 'r', encoding="utf8") as file:
- existing_caption = file.read()
-
- if shared.state.interrupted:
- break
-
- if img.height > img.width:
- ratio = (img.width * height) / (img.height * width)
- inverse_xy = False
- else:
- ratio = (img.height * width) / (img.width * height)
- inverse_xy = True
-
- process_default_resize = True
-
- if process_split and ratio < 1.0 and ratio <= split_threshold:
- for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
- save_pic(splitted, index, params, existing_caption=existing_caption)
- process_default_resize = False
-
- if process_focal_crop and img.height != img.width:
-
- dnn_model_path = None
- try:
- dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
- except Exception as e:
- print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
-
- autocrop_settings = autocrop.Settings(
- crop_width = width,
- crop_height = height,
- face_points_weight = process_focal_crop_face_weight,
- entropy_points_weight = process_focal_crop_entropy_weight,
- corner_points_weight = process_focal_crop_edges_weight,
- annotate_image = process_focal_crop_debug,
- dnn_model_path = dnn_model_path,
- )
- for focal in autocrop.crop_image(img, autocrop_settings):
- save_pic(focal, index, params, existing_caption=existing_caption)
- process_default_resize = False
-
- if process_multicrop:
- cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
- if cropped is not None:
- save_pic(cropped, index, params, existing_caption=existing_caption)
- else:
- print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
- process_default_resize = False
-
- if process_keep_original_size:
- save_pic(img, index, params, existing_caption=existing_caption)
- process_default_resize = False
-
- if process_default_resize:
- img = images.resize_image(1, img, width, height)
- save_pic(img, index, params, existing_caption=existing_caption)
-
- shared.state.nextjob()
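For reference, the auto-sized crop logic deleted above scored each candidate (w, h) with err(w, h) = 1 - min(r, 1/r), where r = (iw/ih) / (w/h) compares the source and candidate aspect ratios, and kept only candidates whose error stayed within the threshold. Worked example of the removed code: for a 1024x768 source and a 512x512 candidate, r = (4/3)/1 ≈ 1.33, so err = 1 - 0.75 = 0.25, which the default threshold of 0.1 in the (also removed) UI would reject.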
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 04dda585..c6bcab15 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,7 +11,6 @@ import safetensors.torch
import numpy as np
from PIL import Image, PngImagePlugin
-from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
@@ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
def tensorboard_setup(log_directory):
+ from torch.utils.tensorboard import SummaryWriter
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
return SummaryWriter(
log_dir=os.path.join(log_directory, "tensorboard"),
@@ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
+ tensorboard_writer = None
if shared.opts.training_enable_tensorboard:
- tensorboard_writer = tensorboard_setup(log_directory)
+ try:
+ tensorboard_writer = tensorboard_setup(log_directory)
+ except ImportError:
+ errors.report("Error initializing tensorboard", exc_info=True)
pin_memory = shared.opts.pin_memory
@@ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ if tensorboard_writer and shared.opts.training_tensorboard_save_images:
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 35c4feef..f149ad1f 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -3,7 +3,6 @@ import html
import gradio as gr
import modules.textual_inversion.textual_inversion
-import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
@@ -15,12 +14,6 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old):
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
-def preprocess(*args):
- modules.textual_inversion.preprocess.preprocess(*args)
-
- return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
-
-
def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
diff --git a/modules/torch_utils.py b/modules/torch_utils.py
new file mode 100644
index 00000000..e5b52393
--- /dev/null
+++ b/modules/torch_utils.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import torch.nn
+
+
+def get_param(model) -> torch.nn.Parameter:
+ """
+ Find the first parameter in a model or module.
+ """
+ if hasattr(model, "model") and hasattr(model.model, "parameters"):
+ # Unpeel a model descriptor to get at the actual Torch module.
+ model = model.model
+
+ for param in model.parameters():
+ return param
+
+ raise ValueError(f"No parameters found in model {model!r}")
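A typical use of the new helper is to discover the device and dtype of whatever object wraps the actual torch module (usage sketch, not part of the patch; `sd_model` and `x` are assumed to exist elsewhere):

from modules.torch_utils import get_param

p = get_param(sd_model)                     # first torch.nn.Parameter of the underlying module
x = x.to(device=p.device, dtype=p.dtype)    # move inputs to match the model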
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e4e18ceb..3a481915 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -2,7 +2,7 @@ from contextlib import closing
import modules.scripts
from modules import processing
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules.infotext import create_override_settings_dict
from modules.shared import opts
import modules.shared as shared
from modules.ui import plaintext_to_html
diff --git a/modules/ui.py b/modules/ui.py
index bcf39199..378529c7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -4,6 +4,7 @@ import os
import sys
from functools import reduce
import warnings
+from contextlib import ExitStack
import gradio as gr
import gradio.utils
@@ -12,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path
from modules.ui_common import create_refresh_button
@@ -20,15 +21,14 @@ from modules.ui_gradio_extensions import reload_javascript
from modules.shared import opts, cmd_opts
-import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.infotext as parameters_copypaste
import modules.hypernetworks.ui as hypernetworks_ui
import modules.textual_inversion.ui as textual_inversion_ui
import modules.textual_inversion.textual_inversion as textual_inversion
import modules.shared as shared
-import modules.images
from modules import prompt_parser
from modules.sd_hijack import model_hijack
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext import image_from_url_text, PasteField
create_setting_component = ui_settings.create_setting_component
@@ -177,80 +177,6 @@ def update_negative_prompt_token_counter(text, steps):
return update_token_counter(text, steps, is_positive=False)
-class Toprow:
- """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
-
- def __init__(self, is_img2img):
- id_part = "img2img" if is_img2img else "txt2img"
- self.id_part = id_part
-
- with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
- with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
- self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
-
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
- self.button_interrogate = None
- self.button_deepbooru = None
- if is_img2img:
- with gr.Column(scale=1, elem_classes="interrogate-col"):
- self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
- with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
- with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
- self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
- self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
- self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
- self.skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
-
- self.interrupt.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
- with gr.Row(elem_id=f"{id_part}_tools"):
- self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.")
- self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt", tooltip="Clear prompt")
- self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{id_part}_style_apply", tooltip="Apply all selected styles to prompts.")
- self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False, tooltip="Restore progress")
-
- self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
- self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
- self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
- self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-
- self.clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[self.prompt, self.negative_prompt],
- outputs=[self.prompt, self.negative_prompt],
- )
-
- self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
- self.ui_styles.setup_apply_button(self.apply_styles)
-
- self.prompt_img.change(
- fn=modules.images.image_data,
- inputs=[self.prompt_img],
- outputs=[self.prompt, self.prompt_img],
- show_progress=False,
- )
-
-
def setup_progressbar(*args, **kwargs):
pass
@@ -288,8 +214,8 @@ def apply_setting(key, value):
return getattr(opts, key)
-def create_output_panel(tabname, outdir):
- return ui_common.create_output_panel(tabname, outdir)
+def create_output_panel(tabname, outdir, toprow=None):
+ return ui_common.create_output_panel(tabname, outdir, toprow)
def create_sampler_and_steps_selection(choices, tabname):
@@ -336,7 +262,7 @@ def create_ui():
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- toprow = Toprow(is_img2img=False)
+ toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)
dummy_component = gr.Label(visible=False)
@@ -344,10 +270,17 @@ def create_ui():
extra_tabs.__enter__()
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
- with gr.Column(variant='compact', elem_id="txt2img_settings"):
+ with ExitStack() as stack:
+ if shared.opts.txt2img_settings_accordion:
+ stack.enter_context(gr.Accordion("Open for Settings", open=False))
+ stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))
+
scripts.scripts_txt2img.prepare_ui()
for category in ordered_ui_categories():
+ if category == "prompt":
+ toprow.create_inline_toprow_prompts()
+
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
@@ -442,7 +375,7 @@ def create_ui():
show_progress=False,
)
- txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
@@ -502,28 +435,28 @@ def create_ui():
)
txt2img_paste_fields = [
- (toprow.prompt, "Prompt"),
- (toprow.negative_prompt, "Negative prompt"),
- (steps, "Steps"),
- (sampler_name, "Sampler"),
- (cfg_scale, "CFG scale"),
- (width, "Size-1"),
- (height, "Size-2"),
- (batch_size, "Batch size"),
- (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
- (denoising_strength, "Denoising strength"),
- (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
- (hr_scale, "Hires upscale"),
- (hr_upscaler, "Hires upscaler"),
- (hr_second_pass_steps, "Hires steps"),
- (hr_resize_x, "Hires resize-1"),
- (hr_resize_y, "Hires resize-2"),
- (hr_checkpoint_name, "Hires checkpoint"),
- (hr_sampler_name, "Hires sampler"),
- (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
- (hr_prompt, "Hires prompt"),
- (hr_negative_prompt, "Hires negative prompt"),
- (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
+ PasteField(toprow.prompt, "Prompt", api="prompt"),
+ PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
+ PasteField(steps, "Steps", api="steps"),
+ PasteField(sampler_name, "Sampler", api="sampler_name"),
+ PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
+ PasteField(width, "Size-1", api="width"),
+ PasteField(height, "Size-2", api="height"),
+ PasteField(batch_size, "Batch size", api="batch_size"),
+ PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"),
+ PasteField(denoising_strength, "Denoising strength", api="denoising_strength"),
+ PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"),
+ PasteField(hr_scale, "Hires upscale", api="hr_scale"),
+ PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"),
+ PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"),
+ PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
+ PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
+ PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
+ PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
+ PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
+ PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
+ PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
+ PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
*scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
@@ -554,13 +487,17 @@ def create_ui():
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- toprow = Toprow(is_img2img=True)
+ toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
extra_tabs.__enter__()
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
- with gr.Column(variant='compact', elem_id="img2img_settings"):
+ with ExitStack() as stack:
+ if shared.opts.img2img_settings_accordion:
+ stack.enter_context(gr.Accordion("Open for Settings", open=False))
+ stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))
+
copy_image_buttons = []
copy_image_destinations = {}
@@ -577,85 +514,89 @@ def create_ui():
button = gr.Button(title)
copy_image_buttons.append((button, name, elem))
- with gr.Tabs(elem_id="mode_img2img"):
- img2img_selected_tab = gr.State(0)
-
- with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
- add_copy_image_controls('img2img', init_img)
-
- with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
- sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
- add_copy_image_controls('sketch', sketch)
-
- with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
- add_copy_image_controls('inpaint', init_img_with_mask)
-
- with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
- inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
- inpaint_color_sketch_orig = gr.State(None)
- add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
-
- def update_orig(image, state):
- if image is not None:
- same_size = state is not None and state.size == image.size
- has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
- edited = same_size and has_exact_match
- return image if not edited or state is None else state
-
- inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
-
- with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
- init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
- init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
-
- with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
- hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(
- "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
- "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
- f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
- f"{hidden}</p>"
- )
- img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
- img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
- img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
- with gr.Accordion("PNG info", open=False):
- img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
- img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
- img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
-
- img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
-
- for i, tab in enumerate(img2img_tabs):
- tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
-
- def copy_image(img):
- if isinstance(img, dict) and 'image' in img:
- return img['image']
-
- return img
-
- for button, name, elem in copy_image_buttons:
- button.click(
- fn=copy_image,
- inputs=[elem],
- outputs=[copy_image_destinations[name]],
- )
- button.click(
- fn=lambda: None,
- _js=f"switch_to_{name.replace(' ', '_')}",
- inputs=[],
- outputs=[],
- )
-
- with FormRow():
- resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
-
scripts.scripts_img2img.prepare_ui()
for category in ordered_ui_categories():
+ if category == "prompt":
+ toprow.create_inline_toprow_prompts()
+
+ if category == "image":
+ with gr.Tabs(elem_id="mode_img2img"):
+ img2img_selected_tab = gr.State(0)
+
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
+ add_copy_image_controls('img2img', init_img)
+
+ with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
+ add_copy_image_controls('sketch', sketch)
+
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
+ add_copy_image_controls('inpaint', init_img_with_mask)
+
+ with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
+ inpaint_color_sketch_orig = gr.State(None)
+ add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
+
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
+
+ inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
+
+ with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
+
+ with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
+ hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+ gr.HTML(
+ "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
+ "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
+ f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
+ f"{hidden}</p>"
+ )
+ img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
+ img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+ img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+ with gr.Accordion("PNG info", open=False):
+ img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
+ img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
+ img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
+
+ img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
+
+ for i, tab in enumerate(img2img_tabs):
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
+
+ def copy_image(img):
+ if isinstance(img, dict) and 'image' in img:
+ return img['image']
+
+ return img
+
+ for button, name, elem in copy_image_buttons:
+ button.click(
+ fn=copy_image,
+ inputs=[elem],
+ outputs=[copy_image_destinations[name]],
+ )
+ button.click(
+ fn=lambda: None,
+ _js=f"switch_to_{name.replace(' ', '_')}",
+ inputs=[],
+ outputs=[],
+ )
+
+ with FormRow():
+ resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
+
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
@@ -693,12 +634,6 @@ def create_ui():
scale_by.release(**on_change_args)
button_update_resize_to.click(**on_change_args)
- # the code below is meant to update the resolution label after the image in the image selection UI has changed.
- # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
- # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
- for component in [init_img, sketch]:
- component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
-
tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
@@ -756,20 +691,26 @@ def create_ui():
with gr.Column(scale=4):
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
- def select_img2img_tab(tab):
- return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
-
- for i, elem in enumerate(img2img_tabs):
- elem.select(
- fn=lambda tab=i: select_img2img_tab(tab),
- inputs=[],
- outputs=[inpaint_controls, mask_alpha],
- )
-
if category not in {"accordions"}:
scripts.scripts_img2img.setup_ui_for_section(category)
- img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
+ # the code below is meant to update the resolution label after the image in the image selection UI has changed.
+ # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
+ # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
+ for component in [init_img, sketch]:
+ component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
+
+ def select_img2img_tab(tab):
+ return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
+
+ for i, elem in enumerate(img2img_tabs):
+ elem.select(
+ fn=lambda tab=i: select_img2img_tab(tab),
+ inputs=[],
+ outputs=[inpaint_controls, mask_alpha],
+ )
+
+ img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
@@ -970,71 +911,6 @@ def create_ui():
with gr.Column():
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
- with gr.Tab(label="Preprocess images", id="preprocess_images"):
- process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
- process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
- process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
- process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
- preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
-
- with gr.Row():
- process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size")
- process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
- process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
- process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
- process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
- process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
- process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
-
- with gr.Row(visible=False) as process_split_extra_row:
- process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
- process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
-
- with gr.Row(visible=False) as process_focal_crop_row:
- process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
- process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
- process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
- process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
-
- with gr.Column(visible=False) as process_multicrop_col:
- gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
- with gr.Row():
- process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
- process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
- with gr.Row():
- process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
- process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
- with gr.Row():
- process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
- process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
-
- with gr.Row():
- with gr.Column(scale=3):
- gr.HTML(value="")
-
- with gr.Column():
- with gr.Row():
- interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
- run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
-
- process_split.change(
- fn=lambda show: gr_show(show),
- inputs=[process_split],
- outputs=[process_split_extra_row],
- )
-
- process_focal_crop.change(
- fn=lambda show: gr_show(show),
- inputs=[process_focal_crop],
- outputs=[process_focal_crop_row],
- )
-
- process_multicrop.change(
- fn=lambda show: gr_show(show),
- inputs=[process_multicrop],
- outputs=[process_multicrop_col],
- )
-
def get_textual_inversion_template_names():
return sorted(textual_inversion.textual_inversion_templates)
@@ -1135,42 +1011,6 @@ def create_ui():
]
)
- run_preprocess.click(
- fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
- _js="start_training_textual_inversion",
- inputs=[
- dummy_component,
- process_src,
- process_dst,
- process_width,
- process_height,
- preprocess_txt_action,
- process_keep_original_size,
- process_flip,
- process_split,
- process_caption,
- process_caption_deepbooru,
- process_split_threshold,
- process_overlap_ratio,
- process_focal_crop,
- process_focal_crop_face_weight,
- process_focal_crop_entropy_weight,
- process_focal_crop_edges_weight,
- process_focal_crop_debug,
- process_multicrop,
- process_multicrop_mindim,
- process_multicrop_maxdim,
- process_multicrop_minarea,
- process_multicrop_maxarea,
- process_multicrop_objective,
- process_multicrop_threshold,
- ],
- outputs=[
- ti_output,
- ti_outcome,
- ],
- )
-
train_embedding.click(
fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
@@ -1244,13 +1084,8 @@ def create_ui():
outputs=[],
)
- interrupt_preprocessing.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
-
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
+ ui_settings_from_file = loadsave.ui_settings.copy()
settings = ui_settings.UiSettings()
settings.create_ui(loadsave, dummy_component)
@@ -1311,7 +1146,8 @@ def create_ui():
modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
- loadsave.dump_defaults()
+ if ui_settings_from_file != loadsave.ui_settings:
+ loadsave.dump_defaults()
demo.ui_loadsave = loadsave
return demo
@@ -1366,7 +1202,7 @@ def setup_ui_api(app):
from fastapi.responses import PlainTextResponse
text = sysinfo.get()
- filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
+ filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 84a7d7f2..fd32676f 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -8,10 +8,10 @@ import gradio as gr
import subprocess as sp
from modules import call_queue, shared
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext import image_from_url_text
import modules.images
from modules.ui_components import ToolButton
-import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.infotext as parameters_copypaste
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@@ -104,7 +104,7 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
-def create_output_panel(tabname, outdir):
+def create_output_panel(tabname, outdir, toprow=None):
def open_folder(f):
if not os.path.exists(f):
@@ -130,12 +130,15 @@ Requested path was: {f}
else:
sp.Popen(["xdg-open", path])
- with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
- with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
+ with gr.Column(elem_id=f"{tabname}_results"):
+ if toprow:
+ toprow.create_inline_toprow_image()
- generation_info = None
- with gr.Column():
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
+
+ generation_info = None
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index c0a73b57..dc1e34c8 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -65,7 +65,7 @@ def save_config_state(name):
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f:
- json.dump(current_config_state, f, indent=4)
+ json.dump(current_config_state, f, indent=4, ensure_ascii=False)
config_states.list_config_states()
new_value = next(iter(config_states.all_config_states.keys()), "Current")
new_choices = ["Current"] + list(config_states.all_config_states.keys())
@@ -335,6 +335,11 @@ def normalize_git_url(url):
return url
+def get_extension_dirname_from_url(url):
+ *parts, last_part = url.split('/')
+ return normalize_git_url(last_part)
+
+
def install_extension_from_url(dirname, url, branch_name=None):
check_access()
@@ -346,10 +351,7 @@ def install_extension_from_url(dirname, url, branch_name=None):
assert url, 'No URL specified'
if dirname is None or dirname == "":
- *parts, last_part = url.split('/')
- last_part = normalize_git_url(last_part)
-
- dirname = last_part
+ dirname = get_extension_dirname_from_url(url)
target_dir = os.path.join(extensions.extensions_dir, dirname)
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
@@ -449,7 +451,8 @@ def get_date(info: dict, key):
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
extlist = available_extensions["extensions"]
- installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+ installed_extensions = {extension.name for extension in extensions.extensions}
+ installed_extension_urls = {normalize_git_url(extension.remote) for extension in extensions.extensions if extension.remote is not None}
tags = available_extensions.get("tags", {})
tags_to_hide = set(hide_tags)
@@ -482,7 +485,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
if url is None:
continue
- existing = installed_extension_urls.get(normalize_git_url(url), None)
+ existing = get_extension_dirname_from_url(url) in installed_extensions or normalize_git_url(url) in installed_extension_urls
extension_tags = extension_tags + ["installed"] if existing else extension_tags
if any(x for x in extension_tags if x in tags_to_hide):
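 
The reworked check above treats a catalog entry as installed if either its URL basename matches an installed extension's directory name or its normalized URL matches an installed extension's remote, which also catches extensions that were installed without a recorded remote. A minimal sketch of that matching, with a simplified stand-in for normalize_git_url and illustrative names:

    def normalize_git_url(url: str) -> str:
        # simplified stand-in; the real helper also tolerates None
        return url.replace(".git", "")

    def get_extension_dirname_from_url(url: str) -> str:
        *_, last_part = url.split('/')
        return normalize_git_url(last_part)

    installed_dirnames = {"sd-webui-controlnet", "stable-diffusion-webui-images-browser"}  # extension.name values
    installed_urls = {"https://github.com/Mikubill/sd-webui-controlnet"}                   # normalized extension.remote values

    def is_installed(url: str) -> bool:
        # installed if either the would-be directory name or the normalized URL matches
        return (get_extension_dirname_from_url(url) in installed_dirnames
                or normalize_git_url(url) in installed_urls)

    print(is_installed("https://github.com/Mikubill/sd-webui-controlnet.git"))                    # True (URL match)
    print(is_installed("https://github.com/AlUlkesh/stable-diffusion-webui-images-browser.git"))  # True (dirname match)
    print(is_installed("https://example.com/user/unknown-extension"))                             # False
 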
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 59d6ecc6..790af135 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -10,7 +10,7 @@ import json
import html
from fastapi.exceptions import HTTPException
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext import image_from_url_text
from modules.ui_components import ToolButton
extra_pages = []
@@ -103,6 +103,7 @@ class ExtraNetworksPage:
self.name = title.lower()
self.id_page = self.name.replace(" ", "_")
self.card_page = shared.html("extra-networks-card.html")
+ self.allow_prompt = True
self.allow_negative_prompt = False
self.metadata = {}
self.items = {}
@@ -150,8 +151,13 @@ class ExtraNetworksPage:
continue
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
- while subdir.startswith("/"):
- subdir = subdir[1:]
+
+ if shared.opts.extra_networks_dir_button_function:
+ if not subdir.startswith("/"):
+ subdir = "/" + subdir
+ else:
+ while subdir.startswith("/"):
+ subdir = subdir[1:]
is_empty = len(os.listdir(x)) == 0
if not is_empty and not subdir.endswith("/"):
@@ -217,7 +223,10 @@ class ExtraNetworksPage:
onclick = item.get("onclick", None)
if onclick is None:
- onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ if "negative_prompt" in item:
+ onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {item["negative_prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ else:
+ onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {'""'}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
@@ -278,6 +287,7 @@ class ExtraNetworksPage:
"date_created": int(stat.st_ctime or 0),
"date_modified": int(stat.st_mtime or 0),
"name": pth.name.lower(),
+ "path": str(pth.parent).lower(),
}
def find_preview(self, path):
@@ -367,7 +377,10 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
related_tabs = []
for page in ui.stored_extra_pages:
- with gr.Tab(page.title, id=page.id_page) as tab:
+ with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
+ with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]):
+ pass
+
elem_id = f"{tabname}_{page.id_page}_cards_html"
page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
@@ -381,19 +394,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
related_tabs.append(tab)
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
- dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
- button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False, tooltip="Invert sort order")
+ dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
+ button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
+ tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs]
+
for tab in unrelated_tabs:
- tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
+ tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False)
+
+ for page, tab in zip(ui.stored_extra_pages, related_tabs):
+ allow_prompt = "true" if page.allow_prompt else "false"
+ allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
+
+ jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
+
+ tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
- for tab in related_tabs:
- tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
+ dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }")
def pages_html():
if not ui.pages_contents:
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index ca6c2607..1693e71f 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -10,11 +10,16 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def __init__(self):
super().__init__('Checkpoints')
+ self.allow_prompt = False
+
def refresh(self):
shared.refresh_checkpoints()
def create_item(self, name, index=None, enable_filter=True):
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
+ if checkpoint is None:
+ return
+
path, ext = os.path.splitext(checkpoint.filename)
return {
"name": checkpoint.name_for_extra,
@@ -30,9 +35,12 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
}
def list_items(self):
+ # instantiate a list to protect against concurrent modification
names = list(sd_models.checkpoints_list)
for index, name in enumerate(names):
- yield self.create_item(name, index)
+ item = self.create_item(name, index)
+ if item is not None:
+ yield item
def allowed_directories_for_previews(self):
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
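 
The checkpoint, hypernetwork and textual-inversion pages now share the same defensive listing pattern: the model registry is snapshotted into a list before iterating, and entries that disappear between the snapshot and the lookup are skipped instead of raising. A condensed sketch with an illustrative registry:

    registry = {"model-a": "/models/model-a.safetensors"}   # stand-in for e.g. sd_models.checkpoints_list

    def create_item(name, index=None):
        info = registry.get(name)
        if info is None:          # removed or renamed since the listing started
            return None
        return {"name": name, "filename": info, "sort_keys": {"default": index}}

    def list_items():
        names = list(registry)    # snapshot to protect against concurrent modification
        for index, name in enumerate(names):
            item = create_item(name, index)
            if item is not None:
                yield item

    print(list(list_items()))
 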
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 4cedf085..c96c4fa3 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -13,7 +13,10 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
shared.reload_hypernetworks()
def create_item(self, name, index=None, enable_filter=True):
- full_path = shared.hypernetworks[name]
+ full_path = shared.hypernetworks.get(name)
+ if full_path is None:
+ return
+
path, ext = os.path.splitext(full_path)
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
shorthash = sha256[0:10] if sha256 else None
@@ -31,8 +34,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
}
def list_items(self):
- for index, name in enumerate(shared.hypernetworks):
- yield self.create_item(name, index)
+ # instantiate a list to protect against concurrent modification
+ names = list(shared.hypernetworks)
+ for index, name in enumerate(names):
+ item = self.create_item(name, index)
+ if item is not None:
+ yield item
def allowed_directories_for_previews(self):
return [shared.cmd_opts.hypernetwork_dir]
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index 55ef0ea7..1b334fda 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -14,6 +14,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True):
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
+ if embedding is None:
+ return
path, ext = os.path.splitext(embedding.filename)
return {
@@ -29,8 +31,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
}
def list_items(self):
- for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
- yield self.create_item(name, index)
+ # instantiate a list to protect against concurrent modification
+ names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
+ for index, name in enumerate(names):
+ item = self.create_item(name, index)
+ if item is not None:
+ yield item
def allowed_directories_for_previews(self):
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index bfec140c..87aeb6f3 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -5,7 +5,7 @@ import os.path
import gradio as gr
-from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks
+from modules import infotext, images, sysinfo, errors, ui_extra_networks
class UserMetadataEditor:
@@ -134,7 +134,7 @@ class UserMetadataEditor:
basename, ext = os.path.splitext(filename)
with open(basename + '.json', "w", encoding="utf8") as file:
- json.dump(metadata, file, indent=4)
+ json.dump(metadata, file, indent=4, ensure_ascii=False)
def save_user_metadata(self, name, desc, notes):
user_metadata = self.get_user_metadata(name)
@@ -181,7 +181,7 @@ class UserMetadataEditor:
index = len(gallery) - 1 if index >= len(gallery) else index
img_info = gallery[index if index >= 0 else 0]
- image = generation_parameters_copypaste.image_from_url_text(img_info)
+ image = infotext.image_from_url_text(img_info)
geninfo, items = images.read_info_from_image(image)
images.save_image_with_geninfo(image, geninfo, item["local_preview"])
diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py
index 0d368f8b..a86c368e 100644
--- a/modules/ui_gradio_extensions.py
+++ b/modules/ui_gradio_extensions.py
@@ -1,17 +1,12 @@
import os
import gradio as gr
-from modules import localization, shared, scripts
-from modules.paths import script_path, data_path, cwd
+from modules import localization, shared, scripts, util
+from modules.paths import script_path, data_path
def webpath(fn):
- if fn.startswith(cwd):
- web_path = os.path.relpath(fn, cwd)
- else:
- web_path = os.path.abspath(fn)
-
- return f'file={web_path}?{os.path.getmtime(fn)}'
+ return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}'
def javascript_html():
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py
index eb20ff25..693ff75c 100644
--- a/modules/ui_loadsave.py
+++ b/modules/ui_loadsave.py
@@ -141,10 +141,10 @@ class UiLoadsave:
def write_to_file(self, current_ui_settings):
with open(self.filename, "w", encoding="utf8") as file:
- json.dump(current_ui_settings, file, indent=4)
+ json.dump(current_ui_settings, file, indent=4, ensure_ascii=False)
def dump_defaults(self):
- """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
+ """saves default values to a file unless the file is present and there was an error loading default values at start"""
if self.error_loading and os.path.exists(self.filename):
return
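 
The ensure_ascii=False added to the JSON writers here and in the config-state and user-metadata code above keeps non-ASCII text readable in the saved files instead of being written as \uXXXX escapes. A quick illustration:

    import json

    data = {"description": "Лора для аниме"}        # metadata containing non-ASCII text

    print(json.dumps(data))                         # {"description": "\u041b\u043e\u0440\u0430 ..."}
    print(json.dumps(data, ensure_ascii=False))     # {"description": "Лора для аниме"}
 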
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index 802e1ce7..b74a1532 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -1,9 +1,10 @@
import gradio as gr
-from modules import scripts, shared, ui_common, postprocessing, call_queue
-import modules.generation_parameters_copypaste as parameters_copypaste
+from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow
+import modules.infotext as parameters_copypaste
def create_ui():
+ dummy_component = gr.Label(visible=False)
tab_index = gr.State(value=0)
with gr.Row(equal_height=False, variant='compact'):
@@ -20,11 +21,13 @@ def create_ui():
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
- submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
script_inputs = scripts.scripts_postproc.setup_ui()
with gr.Column():
+ toprow = ui_toprow.Toprow(is_compact=True, is_img2img=False, id_part="extras")
+ toprow.create_inline_toprow_image()
+ submit = toprow.submit
+
result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
@@ -32,8 +35,10 @@ def create_ui():
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
submit.click(
- fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+ fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']),
+ _js="submit_extras",
inputs=[
+ dummy_component,
tab_index,
extras_image,
image_batch,
@@ -45,8 +50,9 @@ def create_ui():
outputs=[
result_images,
html_info_x,
- html_info,
- ]
+ html_log,
+ ],
+ show_progress=False,
)
parameters_copypaste.add_paste_fields("extras", extras_image, None)
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
index 3bcf092f..0d74c23f 100644
--- a/modules/ui_prompt_styles.py
+++ b/modules/ui_prompt_styles.py
@@ -68,10 +68,10 @@ class UiPromptStyles:
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
with gr.Row():
- self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+ self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"])
with gr.Row():
- self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+ self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"])
with gr.Row():
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py
new file mode 100644
index 00000000..1abc9117
--- /dev/null
+++ b/modules/ui_toprow.py
@@ -0,0 +1,149 @@
+import gradio as gr
+
+from modules import shared, ui_prompt_styles
+import modules.images
+
+from modules.ui_components import ToolButton
+
+
+class Toprow:
+ """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
+
+ prompt = None
+ prompt_img = None
+ negative_prompt = None
+
+ button_interrogate = None
+ button_deepbooru = None
+
+ interrupt = None
+ skip = None
+ submit = None
+
+ paste = None
+ clear_prompt_button = None
+ apply_styles = None
+ restore_progress_button = None
+
+ token_counter = None
+ token_button = None
+ negative_token_counter = None
+ negative_token_button = None
+
+ ui_styles = None
+
+ submit_box = None
+
+ def __init__(self, is_img2img, is_compact=False, id_part=None):
+ if id_part is None:
+ id_part = "img2img" if is_img2img else "txt2img"
+
+ self.id_part = id_part
+ self.is_img2img = is_img2img
+ self.is_compact = is_compact
+
+ if not is_compact:
+ with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+ self.create_classic_toprow()
+ else:
+ self.create_submit_box()
+
+ def create_classic_toprow(self):
+ self.create_prompts()
+
+ with gr.Column(scale=1, elem_id=f"{self.id_part}_actions_column"):
+ self.create_submit_box()
+
+ self.create_tools_row()
+
+ self.create_styles_ui()
+
+ def create_inline_toprow_prompts(self):
+ if not self.is_compact:
+ return
+
+ self.create_prompts()
+
+ with gr.Row(elem_classes=["toprow-compact-stylerow"]):
+ with gr.Column(elem_classes=["toprow-compact-tools"]):
+ self.create_tools_row()
+ with gr.Column():
+ self.create_styles_ui()
+
+ def create_inline_toprow_image(self):
+ if not self.is_compact:
+ return
+
+ self.submit_box.render()
+
+ def create_prompts(self):
+ with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
+ with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
+ self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
+ self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
+
+ with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
+ self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
+
+ self.prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[self.prompt_img],
+ outputs=[self.prompt, self.prompt_img],
+ show_progress=False,
+ )
+
+ def create_submit_box(self):
+ with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box:
+ self.submit_box = submit_box
+
+ self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt")
+ self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip")
+ self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary')
+
+ self.skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
+
+ def interrupt_function():
+ if shared.state.job_count > 1 and shared.opts.interrupt_after_current:
+ shared.state.stop_generating()
+ else:
+ shared.state.interrupt()
+
+ self.interrupt.click(
+ fn=interrupt_function,
+ inputs=[],
+ outputs=[],
+ )
+
+ def create_tools_row(self):
+ with gr.Row(elem_id=f"{self.id_part}_tools"):
+ from modules.ui import paste_symbol, clear_prompt_symbol, restore_progress_symbol
+
+ self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.")
+ self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{self.id_part}_clear_prompt", tooltip="Clear prompt")
+ self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{self.id_part}_style_apply", tooltip="Apply all selected styles to prompts.")
+
+ if self.is_img2img:
+ self.button_interrogate = ToolButton('📎', tooltip='Interrogate CLIP - use CLIP neural network to create a text describing the image, and put it into the prompt field', elem_id="interrogate")
+ self.button_deepbooru = ToolButton('📦', tooltip='Interrogate DeepBooru - use DeepBooru neural network to create a text describing the image, and put it into the prompt field', elem_id="deepbooru")
+
+ self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress")
+
+ self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"])
+ self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button")
+ self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"])
+ self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button")
+
+ self.clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[self.prompt, self.negative_prompt],
+ outputs=[self.prompt, self.negative_prompt],
+ )
+
+ def create_styles_ui(self):
+ self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt)
+ self.ui_styles.setup_apply_button(self.apply_styles)
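 
The compact Toprow builds its generate box with render=False and only places it when create_inline_toprow_image() calls submit_box.render(), which is how the extras tab gets the Generate/Interrupt/Skip buttons inside its results column. A minimal, self-contained illustration of that deferred-render pattern in gradio (labels are illustrative):

    import gradio as gr

    with gr.Blocks() as demo:
        with gr.Row(render=False) as deferred_box:   # built now, not placed yet
            gr.Button("Generate", variant="primary")

        gr.Markdown("...the rest of the tab...")
        deferred_box.render()                        # the row finally appears here

    # demo.launch()
 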
diff --git a/modules/upscaler.py b/modules/upscaler.py
index e682bbaa..3aee69db 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -57,6 +57,9 @@ class Upscaler:
dest_h = int((img.height * scale) // 8 * 8)
for _ in range(3):
+ if img.width >= dest_w and img.height >= dest_h:
+ break
+
shape = (img.width, img.height)
img = self.do_upscale(img, selected_model)
@@ -64,9 +67,6 @@ class Upscaler:
if shape == (img.width, img.height):
break
- if img.width >= dest_w and img.height >= dest_h:
- break
-
if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
@@ -98,6 +98,9 @@ class UpscalerData:
self.scale = scale
self.model = model
+ def __repr__(self):
+ return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"
+
class UpscalerNone(Upscaler):
name = "None"
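 
The reordered loop in Upscaler.upscale() above now checks the target size before invoking the model, so images that are already large enough are never run through do_upscale. A condensed, runnable sketch of the resulting control flow (do_upscale stands in for the actual model):

    from PIL import Image

    def upscale(img: Image.Image, scale: float, do_upscale) -> Image.Image:
        dest_w = int((img.width * scale) // 8 * 8)
        dest_h = int((img.height * scale) // 8 * 8)

        for _ in range(3):                                    # at most three model passes
            if img.width >= dest_w and img.height >= dest_h:
                break                                         # target reached: no model call
            shape = (img.width, img.height)
            img = do_upscale(img)
            if shape == (img.width, img.height):
                break                                         # the model made no progress

        if img.width != dest_w or img.height != dest_h:
            img = img.resize((dest_w, dest_h), resample=Image.Resampling.LANCZOS)
        return img

    double = lambda im: im.resize((im.width * 2, im.height * 2))  # stand-in for a 2x model
    print(upscale(Image.new("RGB", (200, 120)), 4, double).size)  # (800, 480)
 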
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
new file mode 100644
index 00000000..f5cb92d5
--- /dev/null
+++ b/modules/upscaler_utils.py
@@ -0,0 +1,140 @@
+import logging
+from typing import Callable
+
+import numpy as np
+import torch
+import tqdm
+from PIL import Image
+
+from modules import images, shared, torch_utils
+
+logger = logging.getLogger(__name__)
+
+
+def upscale_without_tiling(model, img: Image.Image):
+ img = np.array(img)
+ img = img[:, :, ::-1]
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
+ img = torch.from_numpy(img).float()
+
+ param = torch_utils.get_param(model)
+ img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
+
+ with torch.no_grad():
+ output = model(img)
+
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output = 255. * np.moveaxis(output, 0, 2)
+ output = output.astype(np.uint8)
+ output = output[:, :, ::-1]
+ return Image.fromarray(output, 'RGB')
+
+
+def upscale_with_model(
+ model: Callable[[torch.Tensor], torch.Tensor],
+ img: Image.Image,
+ *,
+ tile_size: int,
+ tile_overlap: int = 0,
+ desc="tiled upscale",
+) -> Image.Image:
+ if tile_size <= 0:
+ logger.debug("Upscaling %s without tiling", img)
+ output = upscale_without_tiling(model, img)
+ logger.debug("=> %s", output)
+ return output
+
+ grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
+ newtiles = []
+
+ with tqdm.tqdm(total=grid.tile_count, desc=desc) as p:
+ for y, h, row in grid.tiles:
+ newrow = []
+ for x, w, tile in row:
+ logger.debug("Tile (%d, %d) %s...", x, y, tile)
+ output = upscale_without_tiling(model, tile)
+ scale_factor = output.width // tile.width
+ logger.debug("=> %s (scale factor %s)", output, scale_factor)
+ newrow.append([x * scale_factor, w * scale_factor, output])
+ p.update(1)
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+ newgrid = images.Grid(
+ newtiles,
+ tile_w=grid.tile_w * scale_factor,
+ tile_h=grid.tile_h * scale_factor,
+ image_w=grid.image_w * scale_factor,
+ image_h=grid.image_h * scale_factor,
+ overlap=grid.overlap * scale_factor,
+ )
+ return images.combine_grid(newgrid)
+
+
+def tiled_upscale_2(
+ img,
+ model,
+ *,
+ tile_size: int,
+ tile_overlap: int,
+ scale: int,
+ device,
+ desc="Tiled upscale",
+):
+ # Alternative implementation of `upscale_with_model` originally used by
+ # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
+ # weighting are done in PyTorch space, as opposed to `images.Grid`, which does
+ # them in Pillow space without weighting.
+ b, c, h, w = img.size()
+ tile_size = min(tile_size, h, w)
+
+ if tile_size <= 0:
+ logger.debug("Upscaling %s without tiling", img.shape)
+ return model(img)
+
+ stride = tile_size - tile_overlap
+ h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
+ w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
+ result = torch.zeros(
+ b,
+ c,
+ h * scale,
+ w * scale,
+ device=device,
+ ).type_as(img)
+ weights = torch.zeros_like(result)
+ logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
+ with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
+ for h_idx in h_idx_list:
+ if shared.state.interrupted or shared.state.skipped:
+ break
+
+ for w_idx in w_idx_list:
+ if shared.state.interrupted or shared.state.skipped:
+ break
+
+ in_patch = img[
+ ...,
+ h_idx : h_idx + tile_size,
+ w_idx : w_idx + tile_size,
+ ]
+ out_patch = model(in_patch)
+
+ result[
+ ...,
+ h_idx * scale : (h_idx + tile_size) * scale,
+ w_idx * scale : (w_idx + tile_size) * scale,
+ ].add_(out_patch)
+
+ out_patch_mask = torch.ones_like(out_patch)
+
+ weights[
+ ...,
+ h_idx * scale : (h_idx + tile_size) * scale,
+ w_idx * scale : (w_idx + tile_size) * scale,
+ ].add_(out_patch_mask)
+
+ pbar.update(1)
+
+ output = result.div_(weights)
+
+ return output
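 
In tiled_upscale_2 every upscaled patch is added into result while a matching ones-mask is added into weights, so pixels covered by overlapping tiles are averaged when result is divided by weights at the end. A tiny one-dimensional illustration of that accumulation:

    import torch

    result = torch.zeros(6)
    weights = torch.zeros(6)

    for value, start in zip((1.0, 2.0), (0, 2)):      # two "patches" of length 4, stride 2
        result[start:start + 4] += value              # pretend patch `value` is the model output
        weights[start:start + 4] += 1                 # coverage count for each sample

    print(result / weights)   # tensor([1.0000, 1.0000, 1.5000, 1.5000, 2.0000, 2.0000]) -- overlap averaged
 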
diff --git a/modules/util.py b/modules/util.py
index 60afc067..4861bcb0 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -2,7 +2,7 @@ import os
import re
from modules import shared
-from modules.paths_internal import script_path
+from modules.paths_internal import script_path, cwd
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
@@ -56,3 +56,13 @@ def ldm_print(*args, **kwargs):
return
print(*args, **kwargs)
+
+
+def truncate_path(target_path, base_path=cwd):
+ abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path)
+ try:
+ if os.path.commonpath([abs_target, abs_base]) == abs_base:
+ return os.path.relpath(abs_target, abs_base)
+ except ValueError:
+ pass
+ return abs_target
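 
truncate_path returns a path relative to base_path only when the target actually lives under it, and falls back to the absolute path otherwise (the ValueError branch covers cases such as different drives on Windows); webpath then appends the file's mtime as a cache-busting query string. A small usage sketch with illustrative POSIX paths, using os.getcwd() as the default in place of the webui's cwd constant:

    import os

    def truncate_path(target_path, base_path=os.getcwd()):
        abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path)
        try:
            if os.path.commonpath([abs_target, abs_base]) == abs_base:
                return os.path.relpath(abs_target, abs_base)
        except ValueError:
            pass
        return abs_target

    print(truncate_path("/webui/extensions/foo/style.css", base_path="/webui"))   # extensions/foo/style.css
    print(truncate_path("/elsewhere/style.css", base_path="/webui"))              # /elsewhere/style.css
 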
diff --git a/modules/xlmr.py b/modules/xlmr.py
index a407a3ca..319771b7 100644
--- a/modules/xlmr.py
+++ b/modules/xlmr.py
@@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
+from modules import torch_utils
+
+
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init()
def encode(self,c):
- device = next(self.parameters()).device
+ device = torch_utils.get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
index a727e865..f6055504 100644
--- a/modules/xlmr_m18.py
+++ b/modules/xlmr_m18.py
@@ -4,6 +4,8 @@ import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
+from modules import torch_utils
+
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@@ -68,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init()
def encode(self,c):
- device = next(self.parameters()).device
+ device = torch_utils.get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
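 
Both XLM-R encoders now obtain the target device from torch_utils.get_param(self) instead of calling next(self.parameters()) inline. A minimal stand-in for such a helper (the real one may also handle wrapped models), showing the device/dtype use it serves:

    import torch

    def get_param(model: torch.nn.Module) -> torch.nn.Parameter:
        for param in model.parameters():
            return param                  # first parameter determines device and dtype
        raise ValueError("model has no parameters")

    model = torch.nn.Linear(4, 2)
    param = get_param(model)
    tokens = torch.zeros(1, 4).to(device=param.device, dtype=param.dtype)
    print(tokens.device, tokens.dtype)
 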
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
new file mode 100644
index 00000000..f7687a66
--- /dev/null
+++ b/modules/xpu_specific.py
@@ -0,0 +1,124 @@
+from modules import shared
+from modules.sd_hijack_utils import CondFunc
+
+has_ipex = False
+try:
+ import torch
+ import intel_extension_for_pytorch as ipex # noqa: F401
+ has_ipex = True
+except Exception:
+ pass
+
+
+def check_for_xpu():
+ return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()
+
+
+def get_xpu_device_string():
+ if shared.cmd_opts.device_id is not None:
+ return f"xpu:{shared.cmd_opts.device_id}"
+ return "xpu"
+
+
+def torch_xpu_gc():
+ with torch.xpu.device(get_xpu_device_string()):
+ torch.xpu.empty_cache()
+
+
+has_xpu = check_for_xpu()
+
+
+# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
+# Here we implement a slicing algorithm to split large batch size into smaller chunks,
+# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
+# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
+# which is the best trade-off between VRAM usage and performance.
+ARC_SINGLE_ALLOCATION_LIMIT = {}
+orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
+def torch_xpu_scaled_dot_product_attention(
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
+):
+ # cast to same dtype first
+ key = key.to(query.dtype)
+ value = value.to(query.dtype)
+
+ N = query.shape[:-2] # Batch size
+ L = query.size(-2) # Target sequence length
+ E = query.size(-1) # Embedding dimension of the query and key
+ S = key.size(-2) # Source sequence length
+ Ev = value.size(-1) # Embedding dimension of the value
+
+ total_batch_size = torch.numel(torch.empty(N))
+ device_id = query.device.index
+ if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
+ ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+ batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
+
+ if total_batch_size <= batch_size_limit:
+ return orig_sdp_attn_func(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ *args, **kwargs
+ )
+
+ query = torch.reshape(query, (-1, L, E))
+ key = torch.reshape(key, (-1, S, E))
+ value = torch.reshape(value, (-1, S, Ev))
+ if attn_mask is not None:
+ attn_mask = attn_mask.view(-1, L, S)
+ chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
+ outputs = []
+ for i in range(chunk_count):
+ attn_mask_chunk = (
+ None
+ if attn_mask is None
+ else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
+ )
+ chunk_output = orig_sdp_attn_func(
+ query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ attn_mask_chunk,
+ dropout_p,
+ is_causal,
+ *args, **kwargs
+ )
+ outputs.append(chunk_output)
+ result = torch.cat(outputs, dim=0)
+ return torch.reshape(result, (*N, L, Ev))
+
+
+if has_xpu:
+ # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
+ CondFunc('torch.Generator',
+ lambda orig_func, device=None: torch.xpu.Generator(device),
+ lambda orig_func, device=None: device is not None and device.type == "xpu")
+
+ # W/A for some OPs that could not handle different input dtypes
+ CondFunc('torch.nn.functional.layer_norm',
+ lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
+ orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
+ lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
+ weight is not None and input.dtype != weight.data.dtype)
+ CondFunc('torch.nn.modules.GroupNorm.forward',
+ lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+ lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+ CondFunc('torch.nn.modules.linear.Linear.forward',
+ lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+ lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+ CondFunc('torch.nn.modules.conv.Conv2d.forward',
+ lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
+ lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+ CondFunc('torch.bmm',
+ lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
+ lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
+ CondFunc('torch.cat',
+ lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
+ lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
+ CondFunc('torch.nn.functional.scaled_dot_product_attention',
+ lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
+ lambda orig_func, query, *args, **kwargs: query.is_xpu)
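 
The slicing limit works out as follows for an illustrative fp16 attention call on a 16 GB Arc card: the per-device cap is min(total_memory // 8, 4 GiB) = 2 GiB, and the flattened batch is split so that no chunk's L x S attention matrix exceeds that many bytes. The shapes below are assumptions chosen for the arithmetic:

    GiB = 1024 ** 3
    total_memory = 16 * GiB
    allocation_limit = min(total_memory // 8, 4 * GiB)       # 2 GiB for a 16 GB card

    L = S = 4096                                             # target/source sequence lengths
    element_size = 2                                         # fp16 bytes per element
    batch_size_limit = max(1, allocation_limit // (L * S * element_size))
    print(batch_size_limit)                                  # 64 flattened batch entries per chunk

    total_batch_size = 2 * 40                                # e.g. batch 2 with 40 attention heads flattened
    chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
    print(chunk_count)                                       # 2 chunks, each at most 64 entries
 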
diff --git a/pyproject.toml b/pyproject.toml
index 80541a8f..d03036e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,7 @@ exclude = [
ignore = [
"E501", # Line too long
+ "E721", # Do not compare types, use `isinstance`
"E731", # Do not assign a `lambda` expression, use a `def`
"I001", # Import block is un-sorted or un-formatted
diff --git a/requirements.txt b/requirements.txt
index 80b43845..731a1be7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,12 +2,11 @@ GitPython
Pillow
accelerate
-basicsr
blendmodes
clean-fid
einops
+facexlib
fastapi>=0.90.1
-gfpgan
gradio==3.41.2
inflection
jsonmerge
@@ -20,13 +19,11 @@ open-clip-torch
piexif
psutil
pytorch_lightning
-realesrgan
requests
resize-right
safetensors
scikit-image>=0.19
-timm
tomesd
torch
torchdiffeq
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 7d27f2be..2a922f28 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,31 +1,30 @@
GitPython==3.1.32
Pillow==9.5.0
accelerate==0.21.0
-basicsr==1.4.2
blendmodes==2022
clean-fid==0.1.35
einops==0.4.1
+facexlib==0.3.0
fastapi==0.94.0
-gfpgan==1.3.8
gradio==3.41.2
httpcore==0.15
inflection==0.5.1
jsonmerge==1.8.0
kornia==0.6.7
lark==1.1.2
-numpy==1.23.5
+numpy==1.26.2
omegaconf==2.2.3
open-clip-torch==2.20.0
piexif==1.1.3
psutil==5.9.5
pytorch_lightning==1.9.4
-realesrgan==0.3.0
resize-right==0.0.2
safetensors==0.3.1
scikit-image==0.21.0
-timm==0.9.2
+spandrel==0.1.6
tomesd==0.1.3
torch
torchdiffeq==0.2.3
torchsde==0.2.6
transformers==4.30.2
+httpx==0.24.1
diff --git a/script.js b/script.js
index 5f6ee354..be1bc317 100644
--- a/script.js
+++ b/script.js
@@ -121,26 +121,56 @@ document.addEventListener("DOMContentLoaded", function() {
});
/**
- * Add a ctrl+enter as a shortcut to start a generation
+ * Add keyboard shortcuts:
+ * Ctrl+Enter to start/restart a generation
+ * Alt/Option+Enter to skip a generation
+ * Esc to interrupt a generation
*/
document.addEventListener('keydown', function(e) {
const isEnter = e.key === 'Enter' || e.keyCode === 13;
- const isModifierKey = e.metaKey || e.ctrlKey || e.altKey;
+ const isCtrlKey = e.metaKey || e.ctrlKey;
+ const isAltKey = e.altKey;
+ const isEsc = e.key === 'Escape';
- const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');
const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
+ const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');
+ const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]');
- if (isEnter && isModifierKey) {
+ if (isCtrlKey && isEnter) {
if (interruptButton.style.display === 'block') {
interruptButton.click();
- setTimeout(function() {
- generateButton.click();
- }, 500);
+ const callback = (mutationList) => {
+ for (const mutation of mutationList) {
+ if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
+ if (interruptButton.style.display === 'none') {
+ generateButton.click();
+ observer.disconnect();
+ }
+ }
+ }
+ };
+ const observer = new MutationObserver(callback);
+ observer.observe(interruptButton, {attributes: true});
} else {
generateButton.click();
}
e.preventDefault();
}
+
+ if (isAltKey && isEnter) {
+ skipButton.click();
+ e.preventDefault();
+ }
+
+ if (isEsc) {
+ const globalPopup = document.querySelector('.global-popup');
+ const lightboxModal = document.querySelector('#lightboxModal');
+ if (!globalPopup || globalPopup.style.display === 'none') {
+ if (document.activeElement === lightboxModal) return;
+ interruptButton.click();
+ e.preventDefault();
+ }
+ }
});
/**
diff --git a/scripts/loopback.py b/scripts/loopback.py
index 2d5feaf9..800ee882 100644
--- a/scripts/loopback.py
+++ b/scripts/loopback.py
@@ -95,7 +95,7 @@ class Script(scripts.Script):
processed = processing.process_images(p)
# Generation cancelled.
- if state.interrupted:
+ if state.interrupted or state.stopping_generation:
break
if initial_seed is None:
@@ -122,8 +122,8 @@ class Script(scripts.Script):
p.inpainting_fill = original_inpainting_fill
- if state.interrupted:
- break
+ if state.interrupted or state.stopping_generation:
+ break
if len(history) > 1:
grid = images.image_grid(history, rows=1)
diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py
new file mode 100644
index 00000000..5592a898
--- /dev/null
+++ b/scripts/postprocessing_caption.py
@@ -0,0 +1,30 @@
+from modules import scripts_postprocessing, ui_components, deepbooru, shared
+import gradio as gr
+
+
+class ScriptPostprocessingCaption(scripts_postprocessing.ScriptPostprocessing):
+ name = "Caption"
+ order = 4040
+
+ def ui(self):
+ with ui_components.InputAccordion(False, label="Caption") as enable:
+ option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False)
+
+ return {
+ "enable": enable,
+ "option": option,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
+ if not enable:
+ return
+
+ captions = [pp.caption]
+
+ if "Deepbooru" in option:
+ captions.append(deepbooru.model.tag(pp.image))
+
+ if "BLIP" in option:
+ captions.append(shared.interrogator.interrogate(pp.image.convert("RGB")))
+
+ pp.caption = ", ".join([x for x in captions if x])
diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py
index a7d80d40..e1e156dd 100644
--- a/scripts/postprocessing_codeformer.py
+++ b/scripts/postprocessing_codeformer.py
@@ -1,28 +1,28 @@
from PIL import Image
import numpy as np
-from modules import scripts_postprocessing, codeformer_model
+from modules import scripts_postprocessing, codeformer_model, ui_components
import gradio as gr
-from modules.ui_components import FormRow
-
class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
name = "CodeFormer"
order = 3000
def ui(self):
- with FormRow():
- codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
- codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
+ with ui_components.InputAccordion(False, label="CodeFormer") as enable:
+ with gr.Row():
+ codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_codeformer_visibility")
+ codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
return {
+ "enable": enable,
"codeformer_visibility": codeformer_visibility,
"codeformer_weight": codeformer_weight,
}
- def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
- if codeformer_visibility == 0:
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codeformer_visibility, codeformer_weight):
+ if codeformer_visibility == 0 or not enable:
return
restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
diff --git a/scripts/postprocessing_create_flipped_copies.py b/scripts/postprocessing_create_flipped_copies.py
new file mode 100644
index 00000000..b673003b
--- /dev/null
+++ b/scripts/postprocessing_create_flipped_copies.py
@@ -0,0 +1,32 @@
+from PIL import ImageOps, Image
+
+from modules import scripts_postprocessing, ui_components
+import gradio as gr
+
+
+class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing):
+ name = "Create flipped copies"
+ order = 4030
+
+ def ui(self):
+ with ui_components.InputAccordion(False, label="Create flipped copies") as enable:
+ with gr.Row():
+ option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False)
+
+ return {
+ "enable": enable,
+ "option": option,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
+ if not enable:
+ return
+
+ if "Horizontal" in option:
+ pp.extra_images.append(ImageOps.mirror(pp.image))
+
+ if "Vertical" in option:
+ pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM))
+
+ if "Both" in option:
+ pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT))
diff --git a/scripts/postprocessing_focal_crop.py b/scripts/postprocessing_focal_crop.py
new file mode 100644
index 00000000..cff1dbc5
--- /dev/null
+++ b/scripts/postprocessing_focal_crop.py
@@ -0,0 +1,54 @@
+
+from modules import scripts_postprocessing, ui_components, errors
+import gradio as gr
+
+from modules.textual_inversion import autocrop
+
+
+class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing):
+ name = "Auto focal point crop"
+ order = 4010
+
+ def ui(self):
+ with ui_components.InputAccordion(False, label="Auto focal point crop") as enable:
+ face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight")
+ entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight")
+ edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight")
+ debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
+
+ return {
+ "enable": enable,
+ "face_weight": face_weight,
+ "entropy_weight": entropy_weight,
+ "edges_weight": edges_weight,
+ "debug": debug,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug):
+ if not enable:
+ return
+
+ if not pp.shared.target_width or not pp.shared.target_height:
+ return
+
+ dnn_model_path = None
+ try:
+ dnn_model_path = autocrop.download_and_cache_models()
+ except Exception:
+ errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True)
+
+ autocrop_settings = autocrop.Settings(
+ crop_width=pp.shared.target_width,
+ crop_height=pp.shared.target_height,
+ face_points_weight=face_weight,
+ entropy_points_weight=entropy_weight,
+ corner_points_weight=edges_weight,
+ annotate_image=debug,
+ dnn_model_path=dnn_model_path,
+ )
+
+ result, *others = autocrop.crop_image(pp.image, autocrop_settings)
+
+ pp.image = result
+ pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others]
+
diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py
index d854f3f7..6e756605 100644
--- a/scripts/postprocessing_gfpgan.py
+++ b/scripts/postprocessing_gfpgan.py
@@ -1,26 +1,25 @@
from PIL import Image
import numpy as np
-from modules import scripts_postprocessing, gfpgan_model
+from modules import scripts_postprocessing, gfpgan_model, ui_components
import gradio as gr
-from modules.ui_components import FormRow
-
class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
name = "GFPGAN"
order = 2000
def ui(self):
- with FormRow():
- gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
+ with ui_components.InputAccordion(False, label="GFPGAN") as enable:
+ gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_gfpgan_visibility")
return {
+ "enable": enable,
"gfpgan_visibility": gfpgan_visibility,
}
- def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
- if gfpgan_visibility == 0:
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_visibility):
+ if gfpgan_visibility == 0 or not enable:
return
restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
diff --git a/scripts/postprocessing_split_oversized.py b/scripts/postprocessing_split_oversized.py
new file mode 100644
index 00000000..c4a03160
--- /dev/null
+++ b/scripts/postprocessing_split_oversized.py
@@ -0,0 +1,71 @@
+import math
+
+from modules import scripts_postprocessing, ui_components
+import gradio as gr
+
+
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
+
+
+class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing):
+ name = "Split oversized images"
+ order = 4000
+
+ def ui(self):
+ with ui_components.InputAccordion(False, label="Split oversized images") as enable:
+ with gr.Row():
+ split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold")
+ overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio")
+
+ return {
+ "enable": enable,
+ "split_threshold": split_threshold,
+ "overlap_ratio": overlap_ratio,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio):
+ if not enable:
+ return
+
+ width = pp.shared.target_width
+ height = pp.shared.target_height
+
+ if not width or not height:
+ return
+
+ if pp.image.height > pp.image.width:
+ ratio = (pp.image.width * height) / (pp.image.height * width)
+ inverse_xy = False
+ else:
+ ratio = (pp.image.height * width) / (pp.image.width * height)
+ inverse_xy = True
+
+ if ratio >= 1.0 and ratio > split_threshold:
+ return
+
+ result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio)
+
+ pp.image = result
+ pp.extra_images = [pp.create_copy(x) for x in others]
+
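 
Worked numbers for split_pic: a 512x1536 portrait against a 512x512 target with the default 0.2 overlap passes the aspect-ratio gate (ratio 0.33, so the image is split) and is cut into four 512x512 strips with evenly spaced tops:

    import math

    from_w, from_h = 512, 1536
    to_w, to_h = 512, 512
    overlap_ratio = 0.2

    h = from_h * to_w // from_w                                       # 1536: width already matches the target
    split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
    y_step = (h - to_h) / (split_count - 1)

    print(split_count)                                                # 4
    print([int(y_step * i) for i in range(split_count)])              # [0, 341, 682, 1024] crop tops
 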
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index eb42a29e..ed709688 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -81,6 +81,14 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
return image
+ def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+ if upscale_mode == 1:
+ pp.shared.target_width = upscale_to_width
+ pp.shared.target_height = upscale_to_height
+ else:
+ pp.shared.target_width = int(pp.image.width * upscale_by)
+ pp.shared.target_height = int(pp.image.height * upscale_by)
+
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
if upscaler_1_name == "None":
upscaler_1_name = None
@@ -126,6 +134,10 @@ class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
"upscaler_name": upscaler_name,
}
+ def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+ pp.shared.target_width = int(pp.image.width * upscale_by)
+ pp.shared.target_height = int(pp.image.height * upscale_by)
+
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
if upscaler_name is None or upscaler_name == "None":
return
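 
The new process_firstpass hooks let the upscale scripts publish the final target size on pp.shared before any process() runs, so later scripts such as auto focal point crop and split oversized images can read target_width/target_height. A simplified, hypothetical runner illustrating that two-pass flow (the real orchestration lives in scripts_postprocessing and is not shown here):

    class Shared:
        target_width = None
        target_height = None

    class PP:                      # stand-in for scripts_postprocessing.PostprocessedImage
        def __init__(self, width, height):
            self.width, self.height = width, height
            self.shared = Shared()

    class UpscaleSimple:
        def process_firstpass(self, pp, upscale_by=2.0):
            pp.shared.target_width = int(pp.width * upscale_by)
            pp.shared.target_height = int(pp.height * upscale_by)

    class FocalCrop:
        def process(self, pp):
            if not pp.shared.target_width or not pp.shared.target_height:
                return "skipped: no target size"
            return f"cropping towards {pp.shared.target_width}x{pp.shared.target_height}"

    pp = PP(512, 768)
    scripts = [UpscaleSimple(), FocalCrop()]
    for s in scripts:                           # first pass: publish target sizes
        getattr(s, "process_firstpass", lambda pp: None)(pp)
    for s in scripts:                           # second pass: actual processing
        if hasattr(s, "process"):
            print(s.process(pp))                # cropping towards 1024x1536
 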
diff --git a/scripts/processing_autosized_crop.py b/scripts/processing_autosized_crop.py
new file mode 100644
index 00000000..7e674989
--- /dev/null
+++ b/scripts/processing_autosized_crop.py
@@ -0,0 +1,64 @@
+from PIL import Image
+
+from modules import scripts_postprocessing, ui_components
+import gradio as gr
+
+
+def center_crop(image: Image, w: int, h: int):
+ iw, ih = image.size
+ if ih / h < iw / w:
+ sw = w * ih / h
+ box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
+ else:
+ sh = h * iw / w
+ box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
+ return image.resize((w, h), Image.Resampling.LANCZOS, box)
+
+
+def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
+ iw, ih = image.size
+ err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h))
+ wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64)
+ if minarea <= w * h <= maxarea and err(w, h) <= threshold),
+ key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1],
+ default=None
+ )
+ return wh and center_crop(image, *wh)
+
+
+class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing):
+ name = "Auto-sized crop"
+ order = 4020
+
+ def ui(self):
+ with ui_components.InputAccordion(False, label="Auto-sized crop") as enable:
+ gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
+ with gr.Row():
+ mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim")
+ maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim")
+ with gr.Row():
+ minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea")
+ maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea")
+ with gr.Row():
+ objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective")
+ threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold")
+
+ return {
+ "enable": enable,
+ "mindim": mindim,
+ "maxdim": maxdim,
+ "minarea": minarea,
+ "maxarea": maxarea,
+ "objective": objective,
+ "threshold": threshold,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold):
+ if not enable:
+ return
+
+ cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold)
+ if cropped is not None:
+ pp.image = cropped
+ else:
+ print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)")
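
multicrop_pic above packs its candidate search into a single max() over a generator. Written out, it enumerates every 64-pixel-aligned (w, h) pair inside the dimension and area bounds, discards pairs whose aspect-ratio error exceeds the threshold, and keeps either the largest area (tie-broken by lower error) or the lowest error (tie-broken by larger area). An equivalent, unrolled sketch for readability (illustrative only; behaviour mirrors the code above):

    # Sketch: the same candidate search as multicrop_pic, unrolled.
    def pick_crop_size(iw, ih, mindim, maxdim, minarea, maxarea, objective, threshold):
        def error(w, h):
            # 1 - min(r, 1/r), where r compares the source and candidate aspect ratios
            r = (iw / ih) / (w / h)
            return 1 - (r if r < 1 else 1 / r)

        best_key, best_wh = None, None
        for w in range(mindim, maxdim + 1, 64):
            for h in range(mindim, maxdim + 1, 64):
                if not (minarea <= w * h <= maxarea) or error(w, h) > threshold:
                    continue
                if objective == 'Maximize area':
                    key = (w * h, -error(w, h))
                else:
                    key = (-error(w, h), w * h)
                if best_key is None or key > best_key:
                    best_key, best_wh = key, (w, h)
        return best_wh  # None if no candidate satisfies the bounds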
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index ca73b2a5..a4a2f24d 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -114,6 +114,7 @@ class Script(scripts.Script):
def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
+ prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start")
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
@@ -124,9 +125,9 @@ class Script(scripts.Script):
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
- return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
+ return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]
- def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
+ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):
lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]
p.do_not_save_grid = True
@@ -167,6 +168,18 @@ class Script(scripts.Script):
else:
setattr(copy_p, k, v)
+ if args.get("prompt") and p.prompt:
+ if prompt_position == "start":
+ copy_p.prompt = args.get("prompt") + " " + p.prompt
+ else:
+ copy_p.prompt = p.prompt + " " + args.get("prompt")
+
+ if args.get("negative_prompt") and p.negative_prompt:
+ if prompt_position == "start":
+ copy_p.negative_prompt = args.get("negative_prompt") + " " + p.negative_prompt
+ else:
+ copy_p.negative_prompt = p.negative_prompt + " " + args.get("negative_prompt")
+
proc = process_images(copy_p)
images += proc.images
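
The new prompt_position option above controls whether each line from the prompt file is placed before or after the prompt already entered in the UI; the two are simply joined with a space. For example (illustrative values):

    # Sketch: how a file line combines with the UI prompt, per the logic above.
    line_prompt = "a red fox"      # prompt taken from the list of prompt inputs
    ui_prompt = "oil painting"     # p.prompt from the main prompt box

    at_start = line_prompt + " " + ui_prompt   # "a red fox oil painting"
    at_end = ui_prompt + " " + line_prompt     # "oil painting a red fox"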
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 0dc255bc..2f385ebf 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -270,6 +270,7 @@ axis_options = [
AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),
AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')),
AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]),
+ AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]),
]
@@ -437,13 +438,16 @@ class Script(scripts.Script):
with gr.Column():
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+ with gr.Row():
+ vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_x"), tooltip="Use different seeds for images along X axis.")
+ vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_y"), tooltip="Use different seeds for images along Y axis.")
+ vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_z"), tooltip="Use different seeds for images along Z axis.")
with gr.Column():
include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
+ csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
with gr.Column():
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
- with gr.Column():
- csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
with gr.Row(variant="compact", elem_id="swap_axes"):
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
@@ -475,6 +479,8 @@ class Script(scripts.Script):
fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])
def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
+ axis_type = axis_type or 0 # if axis type is None, set it to 0
+
choices = self.current_axis_options[axis_type].choices
has_choices = choices is not None
@@ -522,9 +528,11 @@ class Script(scripts.Script):
(z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)),
)
- return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode]
+ return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode]
+
+ def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode):
+ x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axis type is None, set it to 0
- def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
@@ -688,7 +696,7 @@ class Script(scripts.Script):
grid_infotext = [None] * (1 + len(zs))
def cell(x, y, z, ix, iy, iz):
- if shared.state.interrupted:
+ if shared.state.interrupted or state.stopping_generation:
return Processed(p, [], p.seed, "")
pc = copy(p)
@@ -697,6 +705,16 @@ class Script(scripts.Script):
y_opt.apply(pc, y, ys)
z_opt.apply(pc, z, zs)
+ xdim = len(xs) if vary_seeds_x else 1
+ ydim = len(ys) if vary_seeds_y else 1
+
+ if vary_seeds_x:
+ pc.seed += ix
+ if vary_seeds_y:
+ pc.seed += iy * xdim
+ if vary_seeds_z:
+ pc.seed += iz * xdim * ydim
+
try:
res = process_images(pc)
except Exception as e:
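
The vary-seeds checkboxes above give each grid cell its own seed by adding the cell's row-major index over the enabled axes to the base seed. A small sketch of the offset computation (illustrative names, mirroring cell() above):

    # Sketch: per-cell seed offset, as applied to pc.seed in cell().
    def seed_offset(ix, iy, iz, xs_len, ys_len, vary_x, vary_y, vary_z):
        xdim = xs_len if vary_x else 1
        ydim = ys_len if vary_y else 1
        offset = 0
        if vary_x:
            offset += ix
        if vary_y:
            offset += iy * xdim
        if vary_z:
            offset += iz * xdim * ydim
        return offset

    # e.g. a 3-wide X axis with X and Y varied: cell (ix=2, iy=1, iz=0) -> offset 2 + 1*3 = 5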
diff --git a/style.css b/style.css
index 115626cd..6d4c8a0d 100644
--- a/style.css
+++ b/style.css
@@ -204,6 +204,11 @@ div.block.gradio-accordion {
padding: 8px 8px;
}
+input[type="checkbox"].input-accordion-checkbox{
+ vertical-align: sub;
+ margin-right: 0.5em;
+}
+
/* txt2img/img2img specific */
@@ -291,6 +296,13 @@ div.block.gradio-accordion {
min-height: 4.5em;
}
+#txt2img_generate, #img2img_generate {
+ min-height: 4.5em;
+}
+.generate-box-compact #txt2img_generate, .generate-box-compact #img2img_generate {
+ min-height: 3em;
+}
+
@media screen and (min-width: 2500px) {
#txt2img_gallery, #img2img_gallery {
min-height: 768px;
@@ -398,6 +410,15 @@ div#extras_scale_to_tab div.form{
min-width: 0.5em;
}
+div.toprow-compact-stylerow{
+ margin: 0.5em 0;
+}
+
+div.toprow-compact-tools{
+ min-width: fit-content !important;
+ max-width: fit-content;
+}
+
/* settings */
#quicksettings {
align-items: end;
@@ -441,6 +462,15 @@ div#extras_scale_to_tab div.form{
padding: 4px;
}
+#settings > div.tab-nav .settings-category{
+ display: block;
+ margin: 1em 0 0.25em 0;
+ font-weight: bold;
+ text-decoration: underline;
+ cursor: default;
+ user-select: none;
+}
+
#settings_result{
height: 1.4em;
margin: 0 1.2em;
@@ -520,7 +550,8 @@ table.popup-table .link{
height: 20px;
background: #b4c0cc;
border-radius: 3px !important;
- top: -20px;
+ top: -14px;
+ left: 0px;
width: 100%;
}
@@ -615,6 +646,8 @@ table.popup-table .link{
margin: auto;
padding: 2em;
z-index: 1001;
+ max-height: 90%;
+ max-width: 90%;
}
/* fullpage image viewer */
@@ -646,7 +679,7 @@ table.popup-table .link{
transition: 0.2s ease background-color;
}
.modalControls:hover {
- background-color:rgba(0,0,0,0.9);
+ background-color:rgba(0,0,0, var(--sd-webui-modal-lightbox-toolbar-opacity));
}
.modalClose {
margin-left: auto;
@@ -716,6 +749,22 @@ table.popup-table .link{
display: none;
}
+@media (pointer: fine) {
+ .modalPrev:hover,
+ .modalNext:hover,
+ .modalControls:hover ~ .modalPrev,
+ .modalControls:hover ~ .modalNext,
+ .modalControls:hover .cursor {
+ opacity: 1;
+ }
+
+ .modalPrev,
+ .modalNext,
+ .modalControls .cursor {
+ opacity: var(--sd-webui-modal-lightbox-icon-opacity);
+ }
+}
+
/* context menu (ie for the generate button) */
#context-menu{
@@ -818,6 +867,18 @@ footer {
/* extra networks UI */
+.extra-page > div.gap{
+ gap: 0;
+}
+
+.extra-page-prompts{
+ margin-bottom: 0;
+}
+
+.extra-page-prompts.extra-page-prompts-active{
+ margin-bottom: 1em;
+}
+
.extra-network-cards{
height: calc(100vh - 24rem);
overflow: clip scroll;
diff --git a/test/conftest.py b/test/conftest.py
index 31a5d9ea..e4fc5678 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,10 +1,16 @@
+import base64
import os
import pytest
-import base64
-
test_files_path = os.path.dirname(__file__) + "/test_files"
+test_outputs_path = os.path.dirname(__file__) + "/test_outputs"
+
+
+def pytest_configure(config):
+ # We don't want to fail on Py.test command line arguments being
+ # parsed by webui:
+ os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
def file_to_base64(filename):
@@ -23,3 +29,8 @@ def img2img_basic_image_base64() -> str:
@pytest.fixture(scope="session") # session so we don't read this over and over
def mask_basic_image_base64() -> str:
return file_to_base64(os.path.join(test_files_path, "mask_basic.png"))
+
+
+@pytest.fixture(scope="session")
+def initialize() -> None:
+ import webui # noqa: F401
diff --git a/test/test_face_restorers.py b/test/test_face_restorers.py
new file mode 100644
index 00000000..7760d51b
--- /dev/null
+++ b/test/test_face_restorers.py
@@ -0,0 +1,29 @@
+import os
+from test.conftest import test_files_path, test_outputs_path
+
+import numpy as np
+import pytest
+from PIL import Image
+
+
+@pytest.mark.usefixtures("initialize")
+@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"])
+def test_face_restorers(restorer_name):
+ from modules import shared
+
+ if restorer_name == "gfpgan":
+ from modules import gfpgan_model
+ gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path)
+ restorer = gfpgan_model.gfpgan_fix_faces
+ elif restorer_name == "codeformer":
+ from modules import codeformer_model
+ codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path)
+ restorer = codeformer_model.codeformer.restore
+ else:
+ raise NotImplementedError("...")
+ img = Image.open(os.path.join(test_files_path, "two-faces.jpg"))
+ np_img = np.array(img, dtype=np.uint8)
+ fixed_image = restorer(np_img)
+ assert fixed_image.shape == np_img.shape
+ assert not np.allclose(fixed_image, np_img) # should have visibly changed
+ Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png"))
diff --git a/test/test_files/two-faces.jpg b/test/test_files/two-faces.jpg
new file mode 100644
index 00000000..c9d1b010
--- /dev/null
+++ b/test/test_files/two-faces.jpg
Binary files differ
diff --git a/test/test_outputs/.gitkeep b/test/test_outputs/.gitkeep
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/test_outputs/.gitkeep
diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py
new file mode 100644
index 00000000..23ccb93a
--- /dev/null
+++ b/test/test_torch_utils.py
@@ -0,0 +1,19 @@
+import types
+
+import pytest
+import torch
+
+from modules import torch_utils
+
+
+@pytest.mark.parametrize("wrapped", [True, False])
+def test_get_param(wrapped):
+ mod = torch.nn.Linear(1, 1)
+ cpu = torch.device("cpu")
+ mod.to(dtype=torch.float16, device=cpu)
+ if wrapped:
+ # more or less how spandrel wraps a thing
+ mod = types.SimpleNamespace(model=mod)
+ p = torch_utils.get_param(mod)
+ assert p.dtype == torch.float16
+ assert p.device == cpu
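
The test above exercises modules.torch_utils.get_param, whose implementation is not part of this diff. Purely as an assumption consistent with the test (unwrap a spandrel-style `.model` attribute, then return the first parameter), a plausible sketch might look like:

    # Assumed sketch only; the real modules.torch_utils.get_param may differ.
    import torch

    def get_first_param(model) -> torch.nn.Parameter:
        # Unwrap wrappers that keep the torch module in a `.model` attribute,
        # as the test's SimpleNamespace(model=mod) does.
        if hasattr(model, "model") and hasattr(model.model, "parameters"):
            model = model.model
        for param in model.parameters():
            return param
        raise ValueError("no parameters found in model")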
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index 24bc5c42..db7e8b1a 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -11,7 +11,7 @@ fi
export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
-export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
+export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0"
export PYTORCH_ENABLE_MPS_FALLBACK=1
####################################################################
diff --git a/webui.py b/webui.py
index 9ed20b30..2c417168 100644
--- a/webui.py
+++ b/webui.py
@@ -39,7 +39,7 @@ def api_only():
print(f"Startup time: {startup_timer.summary()}.")
api.launch(
- server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
+ server_name=initialize_util.gradio_server_name(),
port=cmd_opts.port if cmd_opts.port else 7861,
root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
)
diff --git a/webui.sh b/webui.sh
index 08911469..38258ef6 100755
--- a/webui.sh
+++ b/webui.sh
@@ -89,7 +89,7 @@ delimiter="################################################################"
printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n"
-printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
+printf "\e[1m\e[34mTested on Debian 11 (Bullseye), Fedora 34+ and openSUSE Leap 15.4 or newer.\e[0m"
printf "\n%s\n" "${delimiter}"
# Do not run as root
@@ -133,7 +133,7 @@ case "$gpu_info" in
if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]]
then
# Navi users will still use torch 1.13 because 2.0 does not seem to work.
- export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2"
+ export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6"
else
printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m"
exit 1
@@ -143,8 +143,7 @@ case "$gpu_info" in
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;;
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
- export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/test/rocm5.6"
- # Navi 3 needs at least 5.5 which is only on the torch 2.1.0 release candidates right now
+ export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7"
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}"
@@ -223,13 +222,30 @@ fi
# Try using TCMalloc on Linux
prepare_tcmalloc() {
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
- TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
- if [[ ! -z "${TCMALLOC}" ]]; then
- echo "Using TCMalloc: ${TCMALLOC}"
- export LD_PRELOAD="${TCMALLOC}"
- else
- printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n"
- fi
+ # Define an array of TCMalloc library name patterns
+ TCMALLOC_LIBS=("libtcmalloc(_minimal|)\.so\.\d" "libtcmalloc\.so\.\d")
+
+ # Traverse the array
+ for lib in "${TCMALLOC_LIBS[@]}"
+ do
+ # Determine which TCMalloc library, if any, is available
+ TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)"
+ TC_INFO=(${TCMALLOC//=>/})
+ if [[ ! -z "${TC_INFO}" ]]; then
+ echo "Using TCMalloc: ${TC_INFO}"
+ # Determine whether the library is linked against libpthread, to avoid undefined symbol: pthread_key_create
+ if ldd ${TC_INFO[2]} | grep -q 'libpthread'; then
+ echo "$TC_INFO is linked with libpthread,execute LD_PRELOAD=${TC_INFO}"
+ export LD_PRELOAD="${TC_INFO}"
+ break
+ else
+ echo "$TC_INFO is not linked with libpthreadand will trigger undefined symbol: ptthread_Key_Create error"
+ fi
+ else
+ printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n"
+ fi
+ done
+
fi
}