author     Zac Liu <liuguang@baai.ac.cn>      2022-11-30 11:14:04 +0800
committer  GitHub <noreply@github.com>        2022-11-30 11:14:04 +0800
commit     a39a57cb1f5964d9af2b541f7b352576adeeac0f (patch)
tree       ebae98ea40ecc5b34497424bee19310e9fac4068 /modules
parent     4b3c5bc24bffdf429c463a465763b3077fe55eb8 (diff)
parent     0831ab476c626eb796b609acf8771177692bfab7 (diff)
Merge pull request #1 from 920232796/master
Add AltDiffusion
Diffstat (limited to 'modules')
-rw-r--r--  modules/devices.py     4
-rw-r--r--  modules/sd_hijack.py  14
-rw-r--r--  modules/shared.py      6
3 files changed, 17 insertions(+), 7 deletions(-)
diff --git a/modules/devices.py b/modules/devices.py
index f00079c6..e69c1fe3 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -38,8 +38,8 @@ def get_optimal_device():
     if torch.cuda.is_available():
         return torch.device(get_cuda_device_string())
 
-    if has_mps():
-        return torch.device("mps")
+    # if has_mps():
+    #     return torch.device("mps")
 
     return cpu
 
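For reference, a minimal sketch of the device-selection order that results from this hunk, assuming only the helpers visible in the context lines; the helper name pick_device below is illustrative, not part of the commit.

    import torch

    def pick_device():
        # Sketch of the patched fallback order: CUDA -> CPU, MPS skipped.
        if torch.cuda.is_available():
            return torch.device("cuda")
        # The branch commented out in the hunk would have returned
        # torch.device("mps") on Apple silicon; it no longer runs.
        return torch.device("cpu")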
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index b824b5bf..3ec3f98a 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -81,17 +81,23 @@ class StableDiffusionModelHijack:
     embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
 
     def hijack(self, m):
+
         if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
             model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
             m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+            apply_optimizations()
         elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
             m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
             m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
-
+            apply_optimizations()
+        elif shared.text_model_name == "XLMR-Large":
+            model_embeddings = m.cond_stage_model.roberta.embeddings
+            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
+            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
         self.clip = m.cond_stage_model
-
-        apply_optimizations()
+
         fix_checkpoint()
 
     def flatten(el):
@@ -132,8 +138,8 @@ class StableDiffusionModelHijack:
 
     def tokenize(self, text):
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
-        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
+        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
 
 
 class EmbeddingsWithFixes(torch.nn.Module):
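To make the new XLMR-Large branch easier to follow, here is a self-contained sketch of the wrapping pattern it uses: an embedding layer is handed to a wrapper so selected token vectors can be overridden later, for example by textual-inversion embeddings. The class below is an illustrative stand-in, not the repository's EmbeddingsWithFixes (defined further down in sd_hijack.py), and the fixes mapping is a hypothetical name.

    import torch

    class EmbeddingsWithFixesSketch(torch.nn.Module):
        """Illustrative wrapper: swap in replacement vectors for chosen token ids."""
        def __init__(self, wrapped: torch.nn.Embedding):
            super().__init__()
            self.wrapped = wrapped   # the original token/word embedding layer
            self.fixes = {}          # hypothetical: token id -> replacement vector

        def forward(self, input_ids):
            out = self.wrapped(input_ids)
            for token_id, vector in self.fixes.items():
                out[input_ids == token_id] = vector.to(out)
            return out

    # Usage sketch: wrap a RoBERTa-style word embedding table.
    emb = torch.nn.Embedding(250002, 1024)   # XLM-R vocab size / hidden width, for illustration
    wrapped = EmbeddingsWithFixesSketch(emb)
    ids = torch.tensor([[0, 5, 7]])
    print(wrapped(ids).shape)                # torch.Size([1, 3, 1024])

The XLMR-Large branch applies the same idea to m.cond_stage_model.roberta.embeddings: word_embeddings is wrapped and the wrapper is assigned to a token_embedding attribute on the embeddings module, after which the whole cond_stage_model is wrapped in FrozenCLIPEmbedderWithCustomWords just like the CLIP path.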
diff --git a/modules/shared.py b/modules/shared.py
index c36ee211..1408dee3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -22,7 +22,7 @@ demo = None
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
 parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
+parser.add_argument("--config", type=str, default="configs/altdiffusion/ad-inference.yaml", help="path to config which constructs model",)
 parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
 parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
 parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
@@ -108,6 +108,10 @@ restricted_opts = {
     "outdir_txt2img_grids",
     "outdir_save",
 }
+from omegaconf import OmegaConf
+config = OmegaConf.load(f"{cmd_opts.config}")
+# XLMR-Large
+text_model_name = config.model.params.cond_stage_config.params.name
 
 cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
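A short illustration of what the new shared.py lines do: OmegaConf loads the YAML passed via --config, and sd_hijack.py then branches on config.model.params.cond_stage_config.params.name. The config below is a hypothetical stand-in that only mimics the nesting being read; it is not the contents of configs/altdiffusion/ad-inference.yaml.

    from omegaconf import OmegaConf

    # Hypothetical config shaped like the lookup above expects. As YAML it
    # would look roughly like:
    #   model:
    #     params:
    #       cond_stage_config:
    #         params:
    #           name: XLMR-Large
    config = OmegaConf.create(
        {"model": {"params": {"cond_stage_config": {"params": {"name": "XLMR-Large"}}}}}
    )

    text_model_name = config.model.params.cond_stage_config.params.name
    print(text_model_name)  # "XLMR-Large" -> selects the new XLMR branch in hijack()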