author     AUTOMATIC <16777216c@gmail.com>  2022-12-31 18:06:35 +0300
committer  AUTOMATIC <16777216c@gmail.com>  2022-12-31 18:06:35 +0300
commit     f34c7341720fb2059992926c9f9ae6ff25f7385b
tree       be719a629f8754c206d891b1850f0b5eaf186e2e
parent     3f401cdb644066fd43abf6642d2e53be53c73668
alt-diffusion integration
-rw-r--r--  configs/alt-diffusion-inference.yaml (renamed from configs/altdiffusion/ad-inference.yaml)  0
-rw-r--r--  configs/v1-inference.yaml (renamed from v1-inference.yaml)  0
-rw-r--r--  modules/sd_hijack.py  18
-rw-r--r--  modules/sd_hijack_clip.py  14
-rw-r--r--  modules/sd_hijack_xlmr.py  34
-rw-r--r--  modules/shared.py  10
6 files changed, 50 insertions, 26 deletions
diff --git a/configs/altdiffusion/ad-inference.yaml b/configs/alt-diffusion-inference.yaml
index cfbee72d..cfbee72d 100644
--- a/configs/altdiffusion/ad-inference.yaml
+++ b/configs/alt-diffusion-inference.yaml
diff --git a/v1-inference.yaml b/configs/v1-inference.yaml
index d4effe56..d4effe56 100644
--- a/v1-inference.yaml
+++ b/configs/v1-inference.yaml
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index bce23b03..edcbaf52 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
from modules.sd_hijack_optimizations import invokeAI_mps_available
@@ -68,6 +68,7 @@ def fix_checkpoint():
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
+
class StableDiffusionModelHijack:
    fixes = None
    comments = []
@@ -79,21 +80,22 @@ class StableDiffusionModelHijack:
    def hijack(self, m):
-        if shared.text_model_name == "XLMR-Large":
+        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
-            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
-
+            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
+
        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
-            apply_optimizations()
+
        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
-            apply_optimizations()
-
+
+        apply_optimizations()
+
        self.clip = m.cond_stage_model

        fix_checkpoint()
@@ -109,7 +111,7 @@ class StableDiffusionModelHijack:
    def undo_hijack(self, m):
-        if shared.text_model_name == "XLMR-Large":
+        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
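
The hijack/undo_hijack pair now picks its wrapper from the concrete class of m.cond_stage_model instead of the global shared.text_model_name flag, and apply_optimizations() moves out of the individual branches so it runs once whichever encoder was wrapped. Below is a minimal sketch of that wrap-by-type pattern; every name in it is a made-up stand-in, not one of the project's classes.

    # illustrative stand-ins only (ClipEncoder, XlmrEncoder, Wrapper, Model are hypothetical)
    class ClipEncoder:
        pass

    class XlmrEncoder:
        pass

    class Wrapper:
        def __init__(self, wrapped):
            self.wrapped = wrapped          # keep a handle so the hijack can be undone

    class Model:
        def __init__(self, encoder):
            self.cond_stage_model = encoder

    def hijack(m):
        # dispatch on the concrete encoder type, mirroring StableDiffusionModelHijack.hijack
        if type(m.cond_stage_model) == XlmrEncoder:
            m.cond_stage_model = Wrapper(m.cond_stage_model)
        elif type(m.cond_stage_model) == ClipEncoder:
            m.cond_stage_model = Wrapper(m.cond_stage_model)
        # optimizations would be applied here, once, after whichever branch ran

    def undo_hijack(m):
        m.cond_stage_model = m.cond_stage_model.wrapped      # unwrapping restores the original encoder

    m = Model(XlmrEncoder())
    hijack(m)
    undo_hijack(m)
    assert type(m.cond_stage_model) == XlmrEncoder
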
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 9ea6e1ce..6ec50cca 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -4,7 +4,6 @@ import torch
from modules import prompt_parser, devices
from modules.shared import opts
-import modules.shared as shared
def get_target_prompt_token_count(token_count):
    return math.ceil(max(token_count, 1) / 75) * 75
@@ -177,9 +176,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count

    def forward(self, text):
-        if shared.text_model_name == "XLMR-Large":
-            return self.wrapped.encode(text)
-
        use_old = opts.use_old_emphasis_implementation
        if use_old:
            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
@@ -257,13 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer
-        if shared.text_model_name == "XLMR-Large":
-            self.comma_token = None
-        else :
-            self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+
+        vocab = self.tokenizer.get_vocab()
+
+        self.comma_token = vocab.get(',</w>', None)
        self.token_mults = {}
-        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
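
The behavioural point of the last hunk is the comma-token lookup: the removed code indexed [0] into a filtered list, which raises IndexError when ',</w>' is absent from the vocabulary (as with the XLM-R tokenizer used by alt-diffusion), which is why the XLMR-Large special case existed; dict.get() simply yields None instead. A small illustration with made-up vocabularies (real ones come from tokenizer.get_vocab()):

    clip_vocab = {',</w>': 267, '(</w>': 7, ')</w>': 8}   # CLIP-style BPE marks word ends with </w>
    xlmr_vocab = {',': 4, '(': 5, ')': 6}                 # XLM-R has a plain comma token instead

    # new behaviour: a missing key just yields None, no special case needed
    assert clip_vocab.get(',</w>', None) == 267
    assert xlmr_vocab.get(',</w>', None) is None

    # old behaviour: indexing the filtered list blew up when the key was absent
    try:
        [v for k, v in xlmr_vocab.items() if k == ',</w>'][0]
    except IndexError:
        pass
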
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
new file mode 100644
index 00000000..4ac51c38
--- /dev/null
+++ b/modules/sd_hijack_xlmr.py
@@ -0,0 +1,34 @@
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+
+class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+        self.id_start = wrapped.config.bos_token_id
+        self.id_end = wrapped.config.eos_token_id
+        self.id_pad = wrapped.config.pad_token_id
+
+        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma
+
+    def encode_with_transformers(self, tokens):
+        # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
+        # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
+        # layer to work with - you have to use the last
+
+        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
+        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
+        z = features['projection_state']
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        embedding_layer = self.wrapped.roberta.embeddings
+        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+
+        return embedded
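
In encode_with_transformers the attention mask is derived directly from the padding id, so padded positions are ignored by the XLM-R encoder, and the returned projection_state is the trained 1024-to-768 projection rather than a selectable hidden layer (hence no CLIP Skip). A small sketch of just the mask construction; the pad id and token values are illustrative, not real vocabulary ids.

    import torch

    id_pad = 1                                          # illustrative pad token id
    tokens = torch.tensor([[0, 2354, 88, 2, 1, 1]])     # one prompt, padded to length 6

    attention_mask = (tokens != id_pad).to(device=tokens.device, dtype=torch.int64)
    print(attention_mask)                               # tensor([[1, 1, 1, 1, 0, 0]])
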
diff --git a/modules/shared.py b/modules/shared.py
index 2b31e717..715b9169 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -23,7 +23,7 @@ demo = None
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
+parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
@@ -108,14 +108,6 @@ restricted_opts = {
"outdir_txt2img_grids",
"outdir_save",
}
-from omegaconf import OmegaConf
-config = OmegaConf.load(f"{cmd_opts.config}")
-# XLMR-Large
-try:
-    text_model_name = config.model.params.cond_stage_config.params.name
-
-except :
-    text_model_name = "stable_diffusion"
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
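
With encoder detection moved into sd_hijack (keyed on the cond_stage_model type), shared.py no longer needs to load the YAML with OmegaConf at import time just to read cond_stage_config.params.name, and the --config default follows the file move into configs/. A minimal sketch of the relocated default, with script_path standing in for the webui's own variable:

    import argparse
    import os

    script_path = os.path.dirname(os.path.abspath(__file__))   # stand-in for the webui's script_path

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str,
                        default=os.path.join(script_path, "configs/v1-inference.yaml"),
                        help="path to config which constructs model")

    args = parser.parse_args([])
    print(args.config)   # ends with configs/v1-inference.yaml unless overridden on the command line
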