aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-09-30 09:11:31 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-09-30 09:11:31 +0300
commit87b50397a6da273fe0160016a209e4eb0cbf4a89 (patch)
tree1b41831d022258e7a6c2d40aa0921b415616bba0 /modules/sd_models.py
parente309583f295f5da18170d1428d5ebbef12d3a207 (diff)
add missing import, simplify code, use patches module for #13276
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py19
1 files changed, 12 insertions, 7 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e3755253..5ef7aa13 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
import tomesd
import numpy as np
@@ -130,6 +130,8 @@ except Exception:
def setup_model():
+ """called once at startup to do various one-time tasks related to SD models"""
+
os.makedirs(model_path, exist_ok=True)
enable_midas_autodownload()
@@ -458,14 +460,17 @@ def enable_midas_autodownload():
def patch_given_betas():
- original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
+ import ldm.models.diffusion.ddpm
+
def patched_register_schedule(*args, **kwargs):
- if args[1] is not None and isinstance(args[1], ListConfig):
- modified_args = list(args) # Convert args tuple to a list
- modified_args[1] = np.array(args[1]) # Modify the desired element
- args = tuple(modified_args) # Convert the list back to a tuple
+ """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
+
+ if isinstance(args[1], ListConfig):
+ args = (args[0], np.array(args[1]), *args[2:])
+
original_register_schedule(*args, **kwargs)
- ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule
+
+ original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
def repair_config(sd_config):