aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorJairo Correa <jn.j41r0@gmail.com>2022-09-28 22:14:13 -0300
committerJairo Correa <jn.j41r0@gmail.com>2022-09-28 22:14:13 -0300
commitc938679de7b87b4f14894d9f57fe0f40dd6e3c06 (patch)
tree31119c9dc5d04648bf873069c690b4fd53dc3805 /modules
parent041d2aefc082c2883aa7e28ee3e4a990b3be9758 (diff)
Fix memory leak and reduce memory usage
Diffstat (limited to 'modules')
-rw-r--r--modules/codeformer_model.py6
-rw-r--r--modules/devices.py3
-rw-r--r--modules/extras.py2
-rw-r--r--modules/gfpgan_model.py11
-rw-r--r--modules/processing.py33
5 files changed, 39 insertions, 16 deletions
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 8fbdea24..2177291a 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -89,7 +89,7 @@ def setup_codeformer():
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
- torch.cuda.empty_cache()
+ devices.torch_gc()
except Exception as error:
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
@@ -106,7 +106,9 @@ def setup_codeformer():
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
if shared.opts.face_restoration_unload:
- self.net.to(devices.cpu)
+ self.net = None
+ self.face_helper = None
+ devices.torch_gc()
return restored_img
diff --git a/modules/devices.py b/modules/devices.py
index 07bb2339..df63dd88 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,4 +1,5 @@
import torch
+import gc
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
from modules import errors
@@ -17,8 +18,8 @@ def get_optimal_device():
return cpu
-
def torch_gc():
+ gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
diff --git a/modules/extras.py b/modules/extras.py
index 9a825530..38b86167 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -98,6 +98,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
outputs.append(image)
+ devices.torch_gc()
+
return outputs, plaintext_to_html(info), ''
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 44c5dc6c..b1288f0c 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -49,6 +49,7 @@ def gfpgan():
def gfpgan_fix_faces(np_image):
+ global loaded_gfpgan_model
model = gfpgan()
np_image_bgr = np_image[:, :, ::-1]
@@ -56,7 +57,9 @@ def gfpgan_fix_faces(np_image):
np_image = gfpgan_output_bgr[:, :, ::-1]
if shared.opts.face_restoration_unload:
- model.gfpgan.to(devices.cpu)
+ del model
+ loaded_gfpgan_model = None
+ devices.torch_gc()
return np_image
@@ -83,11 +86,7 @@ def setup_gfpgan():
return "GFPGAN"
def restore(self, np_image):
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- return np_image
+ return gfpgan_fix_faces(np_image)
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
diff --git a/modules/processing.py b/modules/processing.py
index 4ecdfcd2..de5cda79 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -12,7 +12,7 @@ import cv2
from skimage import exposure
import modules.sd_hijack
-from modules import devices, prompt_parser, masking
+from modules import devices, prompt_parser, masking, lowvram
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
@@ -335,7 +335,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1:
state.job_count = p.n_iter
- for n in range(p.n_iter):
+ for n in range(p.n_iter):
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
if state.interrupted:
break
@@ -368,22 +369,32 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ del samples_ddim
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+
+ devices.torch_gc()
+
if opts.filter_nsfw:
import modules.safety as safety
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
- for i, x_sample in enumerate(x_samples_ddim):
+ for i, x_sample in enumerate(x_samples_ddim):
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
- if p.restore_faces:
+ if p.restore_faces:
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
- devices.torch_gc()
-
x_sample = modules.face_restoration.restore_faces(x_sample)
+ devices.torch_gc()
+
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
image = Image.fromarray(x_sample)
if p.color_corrections is not None and i < len(p.color_corrections):
@@ -411,8 +422,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
infotexts.append(infotext(n, i))
output_images.append(image)
- state.nextjob()
+ del x_samples_ddim
+ devices.torch_gc()
+
+ state.nextjob()
+
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
p.color_corrections = None
index_of_first_image = 0
@@ -648,4 +664,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
+ del x
+ devices.torch_gc()
+
return samples