From bd68e35de3b7cf7547ed97d8bdf60147402133cc Mon Sep 17 00:00:00 2001 From: flamelaw Date: Sun, 20 Nov 2022 12:35:26 +0900 Subject: Gradient accumulation, autocast fix, new latent sampling method, etc --- modules/sd_hijack_checkpoint.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 modules/sd_hijack_checkpoint.py (limited to 'modules/sd_hijack_checkpoint.py') diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py new file mode 100644 index 00000000..5712972f --- /dev/null +++ b/modules/sd_hijack_checkpoint.py @@ -0,0 +1,10 @@ +from torch.utils.checkpoint import checkpoint + +def BasicTransformerBlock_forward(self, x, context=None): + return checkpoint(self._forward, x, context) + +def AttentionBlock_forward(self, x): + return checkpoint(self._forward, x) + +def ResBlock_forward(self, x, emb): + return checkpoint(self._forward, x, emb) \ No newline at end of file -- cgit v1.2.1