aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJay Smith <jayvsmith@gmail.com>2022-12-08 18:14:35 -0600
committerJay Smith <jayvsmith@gmail.com>2022-12-08 20:50:08 -0600
commit1ed4f0e22807f3afef925210182cbbee51f0cb2c (patch)
tree5ca841b153e2b0f8b733a178c6c44c721f30a256
parent44c46f0ed395967cd3830dd481a2db759fda5b3b (diff)
Depth2img model support
-rw-r--r--README.md1
-rw-r--r--modules/processing.py38
-rw-r--r--modules/sd_models.py46
3 files changed, 81 insertions, 4 deletions
diff --git a/README.md b/README.md
index 8a4ffade..55990581 100644
--- a/README.md
+++ b/README.md
@@ -135,6 +135,7 @@ The documentation was moved from this README over to the project's [wiki](https:
- SwinIR - https://github.com/JingyunLiang/SwinIR
- Swin2SR - https://github.com/mv-lab/swin2sr
- LDSR - https://github.com/Hafiidz/latent-diffusion
+- MiDaS - https://github.com/isl-org/MiDaS
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
diff --git a/modules/processing.py b/modules/processing.py
index 3d2c4dc9..0417ffc5 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -21,7 +21,10 @@ import modules.face_restoration
import modules.images as images
import modules.styles
import logging
+from ldm.data.util import AddMiDaS
+from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
+from einops import repeat, rearrange
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -150,11 +153,26 @@ class StableDiffusionProcessing():
return image_conditioning
- def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
- if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+ def depth2img_image_conditioning(self, source_image):
+ # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
+ transformer = AddMiDaS(model_type="dpt_hybrid")
+ transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
+ midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
+ midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
+
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+ conditioning = torch.nn.functional.interpolate(
+ self.sd_model.depth_model(midas_in),
+ size=conditioning_image.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ (depth_min, depth_max) = torch.aminmax(conditioning)
+ conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
+ return conditioning
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
@@ -191,6 +209,18 @@ class StableDiffusionProcessing():
return image_conditioning
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
+ # identify itself with a field common to all models. The conditioning_key is also hybrid.
+ if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
+ return self.depth2img_image_conditioning(source_image)
+
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
+ # Dummy zero conditioning if we're not using inpainting or depth model.
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 283cf1cd..139952ba 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -7,6 +7,9 @@ import torch
import re
import safetensors.torch
from omegaconf import OmegaConf
+from os import mkdir
+from urllib import request
+import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
@@ -36,6 +39,7 @@ def setup_model():
os.makedirs(model_path)
list_models()
+ enable_midas_autodownload()
def checkpoint_tiles():
@@ -227,6 +231,48 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
sd_vae.load_vae(model, vae_file)
+def enable_midas_autodownload():
+ """
+ Gives the ldm.modules.midas.api.load_model function automatic downloading.
+
+ When the 512-depth-ema model, and other future models like it, is loaded,
+ it calls midas.api.load_model to load the associated midas depth model.
+ This function applies a wrapper to download the model to the correct
+ location automatically.
+ """
+
+ midas_path = os.path.join(models_path, 'midas')
+
+ # stable-diffusion-stability-ai hard-codes the midas model path to
+ # a location that differs from where other scripts using this model look.
+ # HACK: Overriding the path here.
+ for k, v in midas.api.ISL_PATHS.items():
+ file_name = os.path.basename(v)
+ midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
+
+ midas_urls = {
+ "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
+ "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
+ "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
+ }
+
+ midas.api.load_model_inner = midas.api.load_model
+
+ def load_model_wrapper(model_type):
+ path = midas.api.ISL_PATHS[model_type]
+ if not os.path.exists(path):
+ if not os.path.exists(midas_path):
+ mkdir(midas_path)
+
+ print(f"Downloading midas model weights for {model_type} to {path}")
+ request.urlretrieve(midas_urls[model_type], path)
+ print(f"{model_type} downloaded")
+
+ return midas.api.load_model_inner(model_type)
+
+ midas.api.load_model = load_model_wrapper
+
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()