aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-01-27 10:19:43 -0500
committerbrkirch <brkirch@users.noreply.github.com>2023-01-28 04:16:25 -0500
commitada17dbd7c4c68a4e559848d2e6f2a7799722806 (patch)
treeced66b899aba64a4e5d7b66a3bc8cdb796e0cf16 /modules
parentc4b9b07db6272768428fa8efeb7d7a9f22eca0b1 (diff)
Refactor conditional casting, fix upscalers
Diffstat (limited to 'modules')
-rw-r--r--modules/devices.py8
-rw-r--r--modules/processing.py15
-rw-r--r--modules/realesrgan_model.py2
-rw-r--r--modules/sd_hijack.py2
-rw-r--r--modules/sd_hijack_unet.py8
5 files changed, 25 insertions, 10 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 6b36622c..0100e4af 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -83,6 +83,14 @@ dtype_unet = torch.float16
unet_needs_upcast = False
+def cond_cast_unet(input):
+ return input.to(dtype_unet) if unet_needs_upcast else input
+
+
+def cond_cast_float(input):
+ return input.float() if unet_needs_upcast else input
+
+
def randn(seed, shape):
torch.manual_seed(seed)
if device.type == 'mps':
diff --git a/modules/processing.py b/modules/processing.py
index 92894d67..a397702b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -172,8 +172,7 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
- conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
@@ -217,7 +216,7 @@ class StableDiffusionProcessing:
)
# Encode the new masked image using first stage of network.
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -228,16 +227,18 @@ class StableDiffusionProcessing:
return image_conditioning
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ source_image = devices.cond_cast_float(source_image)
+
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
# identify itself with a field common to all models. The conditioning_key is also hybrid.
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
- return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
+ return self.depth2img_image_conditioning(source_image)
if self.sd_model.cond_stage_key == "edit":
return self.edit_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -417,7 +418,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae):
- x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
+ x = model.decode_first_stage(x)
return x
@@ -1001,7 +1002,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = torch.from_numpy(batch_images)
image = 2. * image - 1.
- image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
+ image = image.to(shared.device)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 47f70251..aad4a629 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler):
scale=info.scale,
model_path=info.local_data_path,
model=info.model(),
- half=not cmd_opts.no_half,
+ half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 531790f3..8fc91882 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -171,7 +171,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
- emb = embedding.vec.to(devices.dtype_unet) if devices.unet_needs_upcast else embedding.vec
+ emb = devices.cond_cast_unet(embedding.vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index a6ee577c..45cf2b18 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -55,8 +55,14 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
-CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.1"):
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
+
+first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
+first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)