aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-01-26 14:48:43 +0300
committerGitHub <noreply@github.com>2023-01-26 14:48:43 +0300
commit645f4e7ef8c9d59deea7091a22373b2da2b780f2 (patch)
tree4a78e30ec8cafd14ecbb677485d2d91282d465ad /modules
parent6cff4401824299a983c8e13424018efc347b4a2b (diff)
parent10421f93c3f7f7ce88cb40391b46d4e6664eff74 (diff)
Merge pull request #7234 from brkirch/fix-full-previews
Fix full previews and--no-half-vae to work correctly with --upcast-sampling
Diffstat (limited to 'modules')
-rw-r--r--modules/processing.py8
-rw-r--r--modules/sd_hijack_utils.py2
2 files changed, 5 insertions, 5 deletions
diff --git a/modules/processing.py b/modules/processing.py
index cb41288a..92894d67 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -172,7 +172,7 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
@@ -217,7 +217,7 @@ class StableDiffusionProcessing:
)
# Encode the new masked image using first stage of network.
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -417,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae):
- x = model.decode_first_stage(x)
+ x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
return x
@@ -1001,7 +1001,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = torch.from_numpy(batch_images)
image = 2. * image - 1.
- image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
+ image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py
index f81b169a..f8684475 100644
--- a/modules/sd_hijack_utils.py
+++ b/modules/sd_hijack_utils.py
@@ -5,7 +5,7 @@ class CondFunc:
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
- for i in range(len(func_path)-2, -1, -1):
+ for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break