aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py53
1 files changed, 33 insertions, 20 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 930d0bee..841402e8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,22 +1,22 @@
import collections
import os.path
import sys
-import gc
import threading
import torch
import re
import safetensors.torch
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, ListConfig
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
import tomesd
+import numpy as np
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -49,11 +49,12 @@ class CheckpointInfo:
def __init__(self, filename):
self.filename = filename
abspath = os.path.abspath(filename)
+ abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+ if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
+ name = abspath.replace(abs_ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
@@ -129,9 +130,12 @@ except Exception:
def setup_model():
+ """called once at startup to do various one-time tasks related to SD models"""
+
os.makedirs(model_path, exist_ok=True)
enable_midas_autodownload()
+ patch_given_betas()
def checkpoint_tiles(use_short=False):
@@ -309,6 +313,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if checkpoint_info in checkpoints_loaded:
# use checkpoint cache
print(f"Loading weights [{sd_model_hash}] from cache")
+ # move to end as latest
+ checkpoints_loaded.move_to_end(checkpoint_info)
return checkpoints_loaded[checkpoint_info]
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
@@ -346,16 +352,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
-
+ model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
- model.load_state_dict(state_dict, strict=False)
- timer.record("apply weights to model")
+ if model.is_ssd:
+ sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = state_dict
+ checkpoints_loaded[checkpoint_info] = state_dict.copy()
+
+ model.load_state_dict(state_dict, strict=False)
+ timer.record("apply weights to model")
del state_dict
@@ -453,6 +462,20 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+def patch_given_betas():
+ import ldm.models.diffusion.ddpm
+
+ def patched_register_schedule(*args, **kwargs):
+ """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
+
+ if isinstance(args[1], ListConfig):
+ args = (args[0], np.array(args[1]), *args[2:])
+
+ original_register_schedule(*args, **kwargs)
+
+ original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
+
+
def repair_config(sd_config):
if not hasattr(sd_config.model.params, "use_ema"):
@@ -777,17 +800,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
- timer = Timer()
-
- if model_data.sd_model:
- model_data.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
- model_data.sd_model = None
- sd_model = None
- gc.collect()
- devices.torch_gc()
-
- print(f"Unloaded weights {timer.summary()}.")
+ send_model_to_cpu(sd_model or shared.sd_model)
return sd_model