aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack_checkpoint.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-18 23:04:24 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-01-18 23:04:24 +0300
commit924e222004ab54273806c5f2ca7a0e7cfa76ad83 (patch)
tree153a08105ee2bc87df43a8a1423df96d25a8e19b /modules/sd_hijack_checkpoint.py
parent889b851a5260ce869a3286ad15d17d1bbb1da0a7 (diff)
add option to show/hide warnings
removed hiding warnings from LDSR fixed/reworked few places that produced warnings
Diffstat (limited to 'modules/sd_hijack_checkpoint.py')
-rw-r--r--modules/sd_hijack_checkpoint.py38
1 file changed, 37 insertions, 1 deletion
diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py
index 5712972f..2604d969 100644
--- a/modules/sd_hijack_checkpoint.py
+++ b/modules/sd_hijack_checkpoint.py
@@ -1,10 +1,46 @@
from torch.utils.checkpoint import checkpoint
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.openaimodel
+
+
def BasicTransformerBlock_forward(self, x, context=None):
    """Replacement forward that routes the block's real `_forward` through
    gradient checkpointing, trading compute for activation memory."""
    inner = self._forward
    return checkpoint(inner, x, context)
+
def AttentionBlock_forward(self, x):
    """Replacement forward that wraps the block's real `_forward` in
    gradient checkpointing to reduce activation memory during training."""
    inner = self._forward
    return checkpoint(inner, x)
+
def ResBlock_forward(self, x, emb):
    """Replacement forward that evaluates the block's real `_forward`
    under gradient checkpointing (activations recomputed in backward)."""
    inner = self._forward
    return checkpoint(inner, x, emb)
+
+
# Originals of the three patched forward methods, saved by add() so that
# remove() can restore them; empty means the hijack is not installed.
stored = []
+
+
def add():
    """Install checkpointed forwards on the ldm attention/ResBlock classes.

    Saves the original methods into `stored` first so remove() can undo the
    patch; a no-op when the hijack is already active.
    """
    if stored:
        # Already patched — don't overwrite the saved originals.
        return

    stored.append(ldm.modules.attention.BasicTransformerBlock.forward)
    stored.append(ldm.modules.diffusionmodules.openaimodel.ResBlock.forward)
    stored.append(ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward)

    ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
+
+
def remove():
    """Restore the original ldm forward methods saved by add().

    A no-op when the hijack is not installed (i.e. `stored` is empty).
    """
    if not stored:
        return

    # Restore in the same order add() saved them.
    original_btb, original_resblock, original_attn = stored

    ldm.modules.attention.BasicTransformerBlock.forward = original_btb
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = original_resblock
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = original_attn

    stored.clear()