diff options
Diffstat (limited to 'modules/sd_hijack_checkpoint.py')
-rw-r--r-- | modules/sd_hijack_checkpoint.py | 38 |
1 files changed, 37 insertions, 1 deletions
diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py index 5712972f..2604d969 100644 --- a/modules/sd_hijack_checkpoint.py +++ b/modules/sd_hijack_checkpoint.py @@ -1,10 +1,46 @@ from torch.utils.checkpoint import checkpoint +import ldm.modules.attention +import ldm.modules.diffusionmodules.openaimodel + + def BasicTransformerBlock_forward(self, x, context=None): return checkpoint(self._forward, x, context) + def AttentionBlock_forward(self, x): return checkpoint(self._forward, x) + def ResBlock_forward(self, x, emb): - return checkpoint(self._forward, x, emb)
\ No newline at end of file + return checkpoint(self._forward, x, emb) + + +stored = [] + + +def add(): + if len(stored) != 0: + return + + stored.extend([ + ldm.modules.attention.BasicTransformerBlock.forward, + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward, + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward + ]) + + ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward + + +def remove(): + if len(stored) == 0: + return + + ldm.modules.attention.BasicTransformerBlock.forward = stored[0] + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1] + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2] + + stored.clear() + |