# path: root/modules/sd_hijack_checkpoint.py
# blob: 5712972f118093a81158dddc279b4db0a338fca2 (plain)
from torch.utils.checkpoint import checkpoint

def BasicTransformerBlock_forward(self, x, context=None):
    """Run ``self._forward(x, context)`` through ``torch.utils.checkpoint``.

    Written with a ``self`` parameter so it can stand in for a module's
    ``forward`` method (presumably patched onto the transformer block class
    elsewhere -- not visible in this file; confirm against the caller).

    Args:
        x: First positional input handed to ``self._forward``.
        context: Second positional input handed to ``self._forward``;
            defaults to ``None`` and is forwarded as-is either way.

    Returns:
        Whatever ``checkpoint`` returns for the wrapped ``_forward`` call.
    """
    # Resolve the bound inner forward once, then let checkpoint invoke it.
    inner_forward = self._forward
    return checkpoint(inner_forward, x, context)

def AttentionBlock_forward(self, x):
    """Run ``self._forward(x)`` through ``torch.utils.checkpoint``.

    Written with a ``self`` parameter so it can stand in for a module's
    ``forward`` method (presumably patched onto the attention block class
    elsewhere -- not visible in this file; confirm against the caller).

    Args:
        x: The single positional input handed to ``self._forward``.

    Returns:
        Whatever ``checkpoint`` returns for the wrapped ``_forward`` call.
    """
    # Resolve the bound inner forward once, then let checkpoint invoke it.
    inner_forward = self._forward
    return checkpoint(inner_forward, x)

def ResBlock_forward(self, x, emb):
    """Run ``self._forward(x, emb)`` through ``torch.utils.checkpoint``.

    Written with a ``self`` parameter so it can stand in for a module's
    ``forward`` method (presumably patched onto the residual block class
    elsewhere -- not visible in this file; confirm against the caller).

    Args:
        x: First positional input handed to ``self._forward``.
        emb: Second positional input handed to ``self._forward``.

    Returns:
        Whatever ``checkpoint`` returns for the wrapped ``_forward`` call.
    """
    # Resolve the bound inner forward once, then let checkpoint invoke it.
    inner_forward = self._forward
    return checkpoint(inner_forward, x, emb)