aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-12 23:52:43 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-07-12 23:52:43 +0300
commitda464a3fb39ecc6ea7b22fe87271194480d8501c (patch)
treefd67d92762d0490d9d4784aaae3f2a3c2f31c6ca /modules/sd_models.py
parentaf081211ee93622473ee575de30fed2fd8263c09 (diff)
SDXL support
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8d639583..e4aae597 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -411,6 +411,7 @@ def repair_config(sd_config):
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
class SdModelData:
@@ -445,6 +446,15 @@ class SdModelData:
model_data = SdModelData()
+def get_empty_cond(sd_model):
+ if hasattr(sd_model, 'conditioner'):
+ d = sd_model.get_learned_conditioning([""])
+ return d['crossattn']
+ else:
+ return sd_model.cond_stage_model([""])
+
+
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -465,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
- clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+ clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict
timer.record("find config")
@@ -517,7 +527,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("scripts callbacks")
with devices.autocast(), torch.no_grad():
- sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+ sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
timer.record("calculate empty prompt")