From 8e0d16e746759b1f9b4bf1b5abfc30f3d985415e Mon Sep 17 00:00:00 2001 From: Zhang Hua Date: Sat, 11 Mar 2023 12:22:59 +0800 Subject: modules/sd_vae_approx.py: fix VAE-approx path --- modules/sd_vae_approx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_vae_approx.py') diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 0027343a..e2f00468 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -35,8 +35,11 @@ def model(): global sd_vae_approx_model if sd_vae_approx_model is None: + model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt") sd_vae_approx_model = VAEApprox() - sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None)) + if not os.path.exists(model_path): + model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt") + sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) sd_vae_approx_model.eval() sd_vae_approx_model.to(devices.device, devices.dtype) -- cgit v1.2.1