aboutsummaryrefslogtreecommitdiff
path: root/scripts/postprocessing_focal_crop.py
blob: d3baf29878a79c3f06a7ca5f06fc0b8695de8741 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

from modules import scripts_postprocessing, ui_components, errors
import gradio as gr

from modules.textual_inversion import autocrop


class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing):
    """Postprocessing script that passes the image through ``autocrop.crop_image``.

    The crop size comes from ``pp.shared.target_width`` / ``pp.shared.target_height``
    and the three slider weights are forwarded to ``autocrop.Settings``.
    """

    name = "Auto focal point crop"
    order = 4000

    def ui(self):
        """Create the accordion and its controls.

        Returns a dict whose keys match the keyword parameters of ``process`` and
        whose values are the UI components created here.
        """

        def weight_slider(label, value, elem_id):
            # All three weight sliders share the same 0..1 range and step.
            return gr.Slider(label=label, value=value, minimum=0.0, maximum=1.0, step=0.05, elem_id=elem_id)

        with ui_components.InputAccordion(False, label="Auto focal point crop") as enable:
            controls = {"enable": enable}
            controls["face_weight"] = weight_slider(
                "Focal point face weight", 0.9, "postprocess_focal_crop_face_weight"
            )
            controls["entropy_weight"] = weight_slider(
                "Focal point entropy weight", 0.15, "postprocess_focal_crop_entropy_weight"
            )
            controls["edges_weight"] = weight_slider(
                "Focal point edges weight", 0.5, "postprocess_focal_crop_edges_weight"
            )
            controls["debug"] = gr.Checkbox(
                label="Create debug image", elem_id="train_process_focal_crop_debug"
            )

        return controls

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug):
        """Replace ``pp.image`` with the first result of ``autocrop.crop_image``.

        Does nothing when ``enable`` is falsy or when either shared target
        dimension is falsy. Any additional results returned by ``crop_image``
        are stored in ``pp.extra_images`` as copies tagged ``focal-crop-debug``
        with further processing disabled.
        """
        if not enable:
            return

        shared = pp.shared
        # Both target dimensions are needed to define the crop size.
        if not (shared.target_width and shared.target_height):
            return

        try:
            model_path = autocrop.download_and_cache_models()
        except Exception:
            # Best effort: report the failure and continue without a DNN model path.
            errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True)
            model_path = None

        settings_kwargs = {
            "crop_width": shared.target_width,
            "crop_height": shared.target_height,
            "face_points_weight": face_weight,
            "entropy_points_weight": entropy_weight,
            "corner_points_weight": edges_weight,
            "annotate_image": debug,
            "dnn_model_path": model_path,
        }
        settings = autocrop.Settings(**settings_kwargs)

        cropped, *debug_images = autocrop.crop_image(pp.image, settings)
        pp.image = cropped

        extras = []
        for image in debug_images:
            extras.append(pp.create_copy(image, nametags=["focal-crop-debug"], disable_processing=True))
        pp.extra_images = extras