aboutsummaryrefslogtreecommitdiff
path: root/modules/gfpgan_model.py
blob: 01d668ecdaff759b15a6d0d29fccbc92fa0c8156 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os

import facexlib
import gfpgan

import modules.face_restoration
from modules import paths, shared, devices, modelloader, errors

# Name of the GFPGAN sub-directory, joined onto paths.models_path below.
model_dir = "GFPGAN"
# Set by setup_model() from its `dirname` argument; gfpgann() passes it to
# modelloader.load_models() as the third positional argument.
user_path = None
# Directory handed to modelloader.load_models(); setup_model() creates it.
model_path = os.path.join(paths.models_path, model_dir)
# Model file (local path or URL) last selected by gfpgann(); None until then.
model_file_path = None
# Default checkpoint URL handed to modelloader.load_models().
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
# Flipped to True by setup_model() once the gfpgan/facexlib hooks are installed.
have_gfpgan = False
# Restorer instance cached by gfpgann() after the first successful construction.
loaded_gfpgan_model = None


def gfpgann():
    """Return the shared GFPGAN restorer, constructing it on first use.

    On later calls the cached instance is returned after its ``gfpgan``
    network has been moved to ``devices.device_gfpgan``.

    Model selection, based on what ``modelloader.load_models`` returns:

    * a single entry starting with ``"http"`` is handed to the constructor
      unchanged;
    * otherwise, among entries whose basename contains ``"GFPGAN"``, the one
      with the greatest ``os.path.getctime`` is used.

    Returns:
        The restorer instance, or ``None`` when ``setup_model`` has not
        installed a constructor or when no usable model file is found.
    """
    global loaded_gfpgan_model
    global model_file_path

    if loaded_gfpgan_model is not None:
        loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
        return loaded_gfpgan_model

    if gfpgan_constructor is None:
        return None

    models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])

    if len(models) == 1 and models[0].startswith("http"):
        model_file = models[0]
    else:
        gfp_models = [item for item in models if 'GFPGAN' in os.path.basename(item)]
        # Covers both "nothing found" and "only unrelated .pth files found".
        # The latter used to reach max() with an empty list and raise
        # ValueError instead of reporting the problem.
        if not gfp_models:
            print("Unable to load gfpgan model!")
            return None
        model_file = max(gfp_models, key=os.path.getctime)

    # Only set the attribute when this facexlib build exposes it.
    if hasattr(facexlib.detection.retinaface, 'device'):
        facexlib.detection.retinaface.device = devices.device_gfpgan
    # Published before construction: the loader hook installed by
    # setup_model() reads this global when it is called.
    model_file_path = model_file
    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
    loaded_gfpgan_model = model

    return model


def send_model_to(model, device):
    """Move the restorer's sub-modules to ``device``.

    Calls ``.to(device)`` on ``model.gfpgan`` first, then on the
    ``face_det`` and ``face_parse`` attributes of ``model.face_helper``,
    in that order.
    """
    model.gfpgan.to(device)
    for part_name in ("face_det", "face_parse"):
        getattr(model.face_helper, part_name).to(device)


def gfpgan_fix_faces(np_image):
    """Run GFPGAN face restoration on ``np_image`` and return the result.

    If ``gfpgann()`` yields no restorer, the input is returned untouched.
    Otherwise the restorer is moved to ``devices.device_gfpgan``, the image
    is passed through ``enhance`` with its last axis reversed (the local
    names suggest an RGB <-> BGR swap), and the output is reversed back.
    When ``shared.opts.face_restoration_unload`` is truthy the restorer is
    moved to ``devices.cpu`` afterwards.
    """
    restorer = gfpgann()
    if restorer is None:
        return np_image

    send_model_to(restorer, devices.device_gfpgan)

    bgr_input = np_image[:, :, ::-1]
    # enhance() returns three values; only the last one is used here.
    _, _, bgr_output = restorer.enhance(
        bgr_input,
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    restored = bgr_output[:, :, ::-1]

    restorer.face_helper.clean_all()

    if shared.opts.face_restoration_unload:
        send_model_to(restorer, devices.cpu)

    return restored


# Assigned the GFPGANer class by setup_model(); while it is None, gfpgann()
# returns None instead of building a restorer.
gfpgan_constructor = None


def setup_model(dirname):
    """Install the GFPGAN integration and register it as a face restorer.

    Creates ``model_path``, imports gfpgan/facexlib, replaces their
    ``load_file_from_url`` functions with wrappers that override the
    target-directory keyword arguments, records ``dirname`` in ``user_path``,
    sets ``have_gfpgan`` and ``gfpgan_constructor``, and appends a
    ``FaceRestorerGFPGAN`` instance to ``shared.face_restorers``.

    Args:
        dirname: Directory used as facexlib's ``save_dir``; when ``None``,
            ``model_path`` is used instead. Also stored in ``user_path``.

    Any exception raised along the way is passed to ``errors.report`` and
    not re-raised.
    """
    global user_path, have_gfpgan, gfpgan_constructor, model_file_path

    try:
        os.makedirs(model_path, exist_ok=True)
        from gfpgan import GFPGANer
        from facexlib import detection, parsing  # noqa: F401

        facexlib_path = model_path if dirname is None else dirname

        # Capture every original loader before installing any replacement.
        gfpgan_loader = gfpgan.utils.load_file_from_url
        detection_loader = facexlib.detection.load_file_from_url
        parsing_loader = facexlib.parsing.load_file_from_url

        def my_load_file_from_url(**kwargs):
            # model_file_path is read at call time, i.e. the value gfpgann()
            # stored just before invoking the constructor.
            # NOTE(review): that value is a model file path or URL, yet it is
            # forwarded as `model_dir` - confirm this is intended.
            return gfpgan_loader(**{**kwargs, "model_dir": model_file_path})

        def redirect_facexlib(original_loader):
            # Wrap a facexlib loader so it always saves into facexlib_path.
            def facex_load_file_from_url(**kwargs):
                return original_loader(**{**kwargs, "save_dir": facexlib_path, "model_dir": None})

            return facex_load_file_from_url

        gfpgan.utils.load_file_from_url = my_load_file_from_url
        facexlib.detection.load_file_from_url = redirect_facexlib(detection_loader)
        facexlib.parsing.load_file_from_url = redirect_facexlib(parsing_loader)

        user_path = dirname
        have_gfpgan = True
        gfpgan_constructor = GFPGANer

        class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
            def name(self):
                return "GFPGAN"

            def restore(self, np_image):
                return gfpgan_fix_faces(np_image)

        shared.face_restorers.append(FaceRestorerGFPGAN())
    except Exception:
        errors.report("Error setting up GFPGAN", exc_info=True)