From 59544321aa019d71d220b1da1eec703aa44fa8eb Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 11 Sep 2023 21:17:28 +0300 Subject: initial work on sd_unet for SDXL --- modules/sd_unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_unet.py') diff --git a/modules/sd_unet.py b/modules/sd_unet.py index 5525cfbc..6a7bc9e2 100644 --- a/modules/sd_unet.py +++ b/modules/sd_unet.py @@ -1,11 +1,11 @@ import torch.nn -import ldm.modules.diffusionmodules.openaimodel from modules import script_callbacks, shared, devices unet_options = [] current_unet_option = None current_unet = None +original_forward = None def list_unets(): @@ -88,5 +88,5 @@ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): if current_unet is not None: return current_unet.forward(x, timesteps, context, *args, **kwargs) - return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs) + return original_forward(self, x, timesteps, context, *args, **kwargs) -- cgit v1.2.1 From ac02216e540cd581f9169c6c791e55721e3117b0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 2 Dec 2023 19:35:47 +0300 Subject: alternate implementation for unet forward replacement that does not depend on hijack being applied --- modules/sd_unet.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'modules/sd_unet.py') diff --git a/modules/sd_unet.py b/modules/sd_unet.py index 6a7bc9e2..a771849c 100644 --- a/modules/sd_unet.py +++ b/modules/sd_unet.py @@ -5,8 +5,7 @@ from modules import script_callbacks, shared, devices unet_options = [] current_unet_option = None current_unet = None -original_forward = None - +original_forward = None # not used, only left temporarily for compatibility def list_unets(): new_unets = script_callbacks.list_unets_callback() @@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module): pass -def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): - if current_unet is not None: - return current_unet.forward(x, timesteps, context, *args, **kwargs) +def create_unet_forward(original_forward): + def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): + if current_unet is not None: + return current_unet.forward(x, timesteps, context, *args, **kwargs) + + return original_forward(self, x, timesteps, context, *args, **kwargs) - return original_forward(self, x, timesteps, context, *args, **kwargs) + return UNetModel_forward -- cgit v1.2.1