aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-09-30 09:01:12 +0300
committerGitHub <noreply@github.com>2023-09-30 09:01:12 +0300
commite309583f295f5da18170d1428d5ebbef12d3a207 (patch)
tree85a01e80d205bf82dd5fd1fc4f7166131266cb43 /modules/sd_models.py
parent7ce1f3a142574db55d2959054ece2bcf472d8970 (diff)
parentd9d94141dcfc1a84e98370bc137ffd888509b65e (diff)
Merge pull request #13276 from woweenie/patch-1
patch DDPM.register_betas so that users can put given_betas in model yaml
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py15
1 files changed, 14 insertions, 1 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index eedb38c6..e3755253 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -7,7 +7,7 @@ import threading
import torch
import re
import safetensors.torch
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, ListConfig
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
@@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules.timer import Timer
import tomesd
+import numpy as np
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -132,6 +133,7 @@ def setup_model():
os.makedirs(model_path, exist_ok=True)
enable_midas_autodownload()
+ patch_given_betas()
def checkpoint_tiles(use_short=False):
@@ -455,6 +457,17 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+def patch_given_betas():
+ original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
+ def patched_register_schedule(*args, **kwargs):
+ if args[1] is not None and isinstance(args[1], ListConfig):
+ modified_args = list(args) # Convert args tuple to a list
+ modified_args[1] = np.array(args[1]) # Modify the desired element
+ args = tuple(modified_args) # Convert the list back to a tuple
+ original_register_schedule(*args, **kwargs)
+ ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule
+
+
def repair_config(sd_config):
if not hasattr(sd_config.model.params, "use_ema"):