author     AUTOMATIC1111 <16777216c@gmail.com>  2023-07-14 09:16:01 +0300
committer  AUTOMATIC1111 <16777216c@gmail.com>  2023-07-14 09:16:01 +0300
commit     6d8dcdefa07d5f8f7e528046b0facdcc51185e60
tree       c5298147907e890dc5e3094a9713f8e9a67c889e
parent     dc3906185656dae75fcefe96625b1dcd0d31579c

initial SDXL refiner support
-rw-r--r--  modules/sd_hijack.py         18
-rw-r--r--  modules/sd_models.py          3
-rw-r--r--  modules/sd_models_config.py   3
-rw-r--r--  modules/sd_models_xl.py      57
-rw-r--r--  modules/shared.py             9

5 files changed, 71 insertions, 19 deletions
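
Summary: this commit adds everything needed to load and run the SDXL refiner checkpoint. sd_models_config.py learns to recognize the refiner's state dict and select sd_xl_refiner.yaml; sd_models.py gains the matching text-encoder weight marker; sd_hijack.py generalizes the conditioner hijack to work with one text encoder (the refiner) as well as two (the base); sd_models_xl.py adds the refiner's aesthetic-score conditioning and teaches GeneralConditioner to stand in for an SD1.5-style cond_stage_model; shared.py moves the SDXL settings into their own section and adds two aesthetic-score options.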
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 647cdfbe..2b274c18 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -180,21 +180,29 @@ class StableDiffusionModelHijack:
     def hijack(self, m):
         conditioner = getattr(m, 'conditioner', None)
         if conditioner:
+            text_cond_models = []
+
             for i in range(len(conditioner.embedders)):
                 embedder = conditioner.embedders[i]
                 typename = type(embedder).__name__
                 if typename == 'FrozenOpenCLIPEmbedder':
                     embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
-                    m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
-                    conditioner.embedders[i] = m.cond_stage_model
+                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
                 if typename == 'FrozenCLIPEmbedder':
-                    model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+                    model_embeddings = embedder.transformer.text_model.embeddings
                     model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
-                    m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
-                    conditioner.embedders[i] = m.cond_stage_model
+                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
                 if typename == 'FrozenOpenCLIPEmbedder2':
                     embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                     conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
+
+            if len(text_cond_models) == 1:
+                m.cond_stage_model = text_cond_models[0]
+            else:
+                m.cond_stage_model = conditioner
 
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
             model_embeddings = m.cond_stage_model.roberta.embeddings
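
The structural change in this hunk: instead of assuming exactly one text encoder, the hijack collects every wrapped text-conditioning embedder into text_cond_models and only exposes a single one as m.cond_stage_model when there is exactly one (the refiner case); the SDXL base model, which carries two text encoders, exposes the whole conditioner instead. A minimal sketch of that dispatch rule, using stand-in classes rather than webui code:

    from types import SimpleNamespace

    class TextEmbedder:  # stand-in for the wrapped CLIP/OpenCLIP text encoders
        pass

    def pick_cond_stage_model(conditioner):
        # mirror the loop above: collect only the text-conditioning embedders
        text_cond_models = [e for e in conditioner.embedders if isinstance(e, TextEmbedder)]
        if len(text_cond_models) == 1:
            return text_cond_models[0]  # refiner: acts like SD1.5's cond_stage_model
        return conditioner              # base: the GeneralConditioner itself stands in

    refiner_like = SimpleNamespace(embedders=[TextEmbedder(), object()])
    base_like = SimpleNamespace(embedders=[TextEmbedder(), TextEmbedder(), object()])
    assert pick_cond_stage_model(refiner_like) is refiner_like.embedders[0]
    assert pick_cond_stage_model(base_like) is base_like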
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 07702175..267f4d8e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -414,6 +414,7 @@ def repair_config(sd_config):
 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
 sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
+sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
 
 
 class SdModelData:
@@ -477,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
-    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict
+    clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict])
 
     timer.record("find config")
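
Why these particular keys: in the base checkpoint the OpenCLIP text encoder sits at conditioner.embedders.1 (index 0 is the CLIP ViT-L encoder, whose weights live under transformer.text_model rather than model.ln_final), while the refiner's sole OpenCLIP encoder sits at index 0, so each key is unique to its architecture. A hedged sketch of the probe (the tuple and function names here are illustrative, not webui API):

    MARKER_KEYS = (
        'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight',  # SD1
        'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight',         # SD2
        'conditioner.embedders.1.model.ln_final.weight',                              # SDXL base
        'conditioner.embedders.0.model.ln_final.weight',                              # SDXL refiner
    )

    def checkpoint_ships_text_encoder(state_dict: dict) -> bool:
        # a checkpoint containing any marker brings its own text encoder weights
        return any(key in state_dict for key in MARKER_KEYS)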
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 04c09ab0..8266fa39 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
 config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -72,6 +73,8 @@ def guess_model_config_from_state_dict(sd, filename):
 
     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
         return config_sdxl
+    if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
+        return config_sdxl_refiner
     elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
         return config_depth_model
     elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
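
The same marker keys drive config selection. The embedders.1 check runs first, and a base checkpoint keeps its CLIP ViT-L encoder (no model.ln_final key) at index 0, so the two branches cannot both match. A small self-check of the rule with fake state dicts (key presence is all that matters; this is our illustration, not webui code):

    def guess(sd: dict) -> str:
        if 'conditioner.embedders.1.model.ln_final.weight' in sd:
            return 'sd_xl_base.yaml'
        if 'conditioner.embedders.0.model.ln_final.weight' in sd:
            return 'sd_xl_refiner.yaml'
        return 'unknown'

    assert guess({'conditioner.embedders.1.model.ln_final.weight': 0}) == 'sd_xl_base.yaml'
    assert guess({'conditioner.embedders.0.model.ln_final.weight': 0}) == 'sd_xl_refiner.yaml'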
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index a7240dc0..01320c7a 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -14,15 +14,20 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
     width = getattr(self, 'target_width', 1024)
     height = getattr(self, 'target_height', 1024)
+    is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
+    aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
+
+    devices_args = dict(device=devices.device, dtype=devices.dtype)
 
     sdxl_conds = {
         "txt": batch,
-        "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype),
-        "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype),
-        "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype),
+        "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+        "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
+        "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+        "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
     }
 
-    force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch)
+    force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
     c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
 
     return c
@@ -35,25 +40,55 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
 def get_first_stage_encoding(self, x):  # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
     return x
 
+
+sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
+sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
+sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
+
+
+def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
+    res = []
+
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
+        encoded = embedder.encode_embedding_init_text(init_text, nvpt)
+        res.append(encoded)
+
+    return torch.cat(res, dim=1)
+
+
+def process_texts(self, texts):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
+        return embedder.process_texts(texts)
+
+
+def get_target_prompt_token_count(self, token_count):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
+        return embedder.get_target_prompt_token_count(token_count)
+
+
+# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
+sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
+sgm.modules.GeneralConditioner.process_texts = process_texts
+sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
+
+
 def extend_sdxl(model):
+    """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
+
     dtype = next(model.model.diffusion_model.parameters()).dtype
     model.model.diffusion_model.dtype = dtype
     model.model.conditioning_key = 'crossattn'
-
-    model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0]
-    model.cond_stage_key = model.cond_stage_model.input_key
+    model.cond_stage_key = 'txt'
+    # model.cond_stage_model will be set in sd_hijack
 
     model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
 
     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
     model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+    model.conditioner.wrapped = torch.nn.Module()
 
 
-sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
-sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
-sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
-
 sgm.modules.attention.print = lambda *args: None
 sgm.modules.diffusionmodules.model.print = lambda *args: None
 sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
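
Two things happen in this file. First, get_learned_conditioning now feeds the refiner an aesthetic_score conditioning tensor: the high score for the prompt, the low score for the negative prompt, one scalar per prompt in the batch. The flag is read off the batch object itself, so the sketch below assumes a small list subclass that can carry the attribute (webui passes its own conditioning list type; PromptBatch is illustrative only):

    import torch

    class PromptBatch(list):
        is_negative_prompt = False

    neg = PromptBatch([""])
    neg.is_negative_prompt = True

    # one aesthetic score per prompt; 2.5 / 6.0 are the new option defaults
    score = 2.5 if neg.is_negative_prompt else 6.0
    cond = torch.tensor([score]).repeat(len(neg), 1)
    print(cond.shape)  # torch.Size([1, 1])

Second, the three GeneralConditioner methods added above (encode_embedding_init_text, process_texts, get_target_prompt_token_count) simply delegate to whichever embedder implements them, which is what lets the conditioner as a whole be assigned as m.cond_stage_model in sd_hijack.py.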
diff --git a/modules/shared.py b/modules/shared.py
index 71afd94f..234ede0d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -428,8 +428,13 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
-    "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"),
-    "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"),
+}))
+
+options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
+    "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
+    "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
+    "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
+    "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
 }))
 
 options_templates.update(options_section(('optimizations', "Optimizations"), {
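
Note that moving sdxl_crop_top and sdxl_crop_left into the new "Stable Diffusion XL" section only changes where they appear in the settings UI; options are stored by name, so previously saved values carry over unchanged.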