aboutsummaryrefslogtreecommitdiff
path: root/modules/gfpgan_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/gfpgan_model.py')
-rw-r--r--modules/gfpgan_model.py58
1 files changed, 40 insertions, 18 deletions
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index ffb6960d..2bf8a1ee 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,24 +1,23 @@
import os
import sys
import traceback
-from glob import glob
-from modules import shared, devices
-from modules.shared import cmd_opts
-from modules.paths import script_path
+import facexlib
+import gfpgan
+
import modules.face_restoration
from modules import shared, devices, modelloader
from modules.paths import models_path
model_dir = "GFPGAN"
-cmd_dir = None
+user_path = None
model_path = os.path.join(models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-
+have_gfpgan = False
loaded_gfpgan_model = None
-def gfpgan():
+def gfpgann():
global loaded_gfpgan_model
global model_path
if loaded_gfpgan_model is not None:
@@ -28,14 +27,16 @@ def gfpgan():
if gfpgan_constructor is None:
return None
- models = modelloader.load_models(model_path, model_url, cmd_dir)
- if len(models) != 0:
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
+ if len(models) == 1 and "http" in models[0]:
+ model_file = models[0]
+ elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
model_file = latest_file
else:
print("Unable to load gfpgan model!")
return None
- model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2,
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2,
bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
@@ -44,11 +45,12 @@ def gfpgan():
def gfpgan_fix_faces(np_image):
- model = gfpgan()
+ model = gfpgann()
if model is None:
return np_image
np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+ cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False,
+ only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
if shared.opts.face_restoration_unload:
@@ -57,7 +59,6 @@ def gfpgan_fix_faces(np_image):
return np_image
-have_gfpgan = False
gfpgan_constructor = None
@@ -67,14 +68,33 @@ def setup_model(dirname):
os.makedirs(model_path)
try:
- from modules.gfpgan_model_arch import GFPGANerr
- global cmd_dir
+ from gfpgan import GFPGANer
+ from facexlib import detection, parsing
+ global user_path
global have_gfpgan
global gfpgan_constructor
- cmd_dir = dirname
+ load_file_from_url_orig = gfpgan.utils.load_file_from_url
+ facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
+ facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
+
+ def my_load_file_from_url(**kwargs):
+ print("Setting model_dir to " + model_path)
+ return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
+
+ def facex_load_file_from_url(**kwargs):
+ return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+ def facex_load_file_from_url2(**kwargs):
+ return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+ gfpgan.utils.load_file_from_url = my_load_file_from_url
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
+ user_path = dirname
+ print("Have gfpgan should be true?")
have_gfpgan = True
- gfpgan_constructor = GFPGANerr
+ gfpgan_constructor = GFPGANer
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
def name(self):
@@ -82,7 +102,9 @@ def setup_model(dirname):
def restore(self, np_image):
np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+ cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False,
+ only_center_face=False,
+ paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image