"""Activation-checkpointed replacement ``forward`` functions.

Each function receives a module instance as ``self`` and runs that object's
``_forward`` method through ``torch.utils.checkpoint.checkpoint``.
Intermediate activations are therefore recomputed during the backward pass
rather than stored.

NOTE(review): judging by the names, these are meant to be assigned as
``forward`` on the like-named model block classes. Confirm at the patching
site.
"""

import inspect

from torch.utils.checkpoint import checkpoint


def _checkpoint_kwargs():
    """Return the extra keyword arguments to pass to ``checkpoint``.

    PyTorch >= 1.11 accepts ``use_reentrant``. Recent releases warn on every
    call when it is omitted and have announced that omitting it will become
    an error. Passing ``use_reentrant=True`` keeps the historical (reentrant)
    default exactly, so results are unchanged.

    Releases that do not know the keyword get nothing extra. They reject
    unknown keywords, and reentrant is their only mode anyway.
    """
    try:
        params = inspect.signature(checkpoint).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: fall back to the original call shape.
        return {}
    return {"use_reentrant": True} if "use_reentrant" in params else {}


# Computed once at import time; the installed torch cannot change mid-process.
_CHECKPOINT_KWARGS = _checkpoint_kwargs()


def BasicTransformerBlock_forward(self, x, context=None):
    """Run ``self._forward(x, context)`` under activation checkpointing.

    Args:
        self: object exposing a ``_forward(x, context)`` method.
        x: first input, passed through to ``_forward``.
        context: optional second input, passed through as-is
            (``None`` by default).

    Returns:
        Whatever ``self._forward(x, context)`` returns.
    """
    return checkpoint(self._forward, x, context, **_CHECKPOINT_KWARGS)


def AttentionBlock_forward(self, x):
    """Run ``self._forward(x)`` under activation checkpointing.

    Args:
        self: object exposing a ``_forward(x)`` method.
        x: input, passed through to ``_forward``.

    Returns:
        Whatever ``self._forward(x)`` returns.
    """
    return checkpoint(self._forward, x, **_CHECKPOINT_KWARGS)


def ResBlock_forward(self, x, emb):
    """Run ``self._forward(x, emb)`` under activation checkpointing.

    Args:
        self: object exposing a ``_forward(x, emb)`` method.
        x: first input, passed through to ``_forward``.
        emb: second input, passed through to ``_forward``.

    Returns:
        Whatever ``self._forward(x, emb)`` returns.
    """
    return checkpoint(self._forward, x, emb, **_CHECKPOINT_KWARGS)