aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-01-18 23:23:09 +0300
committerGitHub <noreply@github.com>2023-01-18 23:23:09 +0300
commit43fd6eaab8231f68c314b6d4fa41e2ca19582310 (patch)
tree3d054acfbfbd01c858912ce88db68b538b227e51
parentb186d44dcd0df9d127a663b297334a5bd8258b58 (diff)
parentd906f87043d809e6d4d8de3c9926e184169b330f (diff)
Merge pull request #6851 from ddPn08/master
Add `--vae-dir` argument
-rw-r--r--modules/sd_vae.py7
-rw-r--r--modules/shared.py1
2 files changed, 8 insertions, 0 deletions
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index b2af2ce7..da1bf15c 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -72,6 +72,13 @@ def refresh_vae_list():
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
]
+ if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
+ paths += [
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
+ ]
+
candidates = []
for path in paths:
candidates += glob.iglob(path, recursive=True)
diff --git a/modules/shared.py b/modules/shared.py
index ddb97f99..77e5e91c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -26,6 +26,7 @@ parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
+parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")