aboutsummaryrefslogtreecommitdiff
path: root/scripts/postprocessing_caption.py
blob: 5592a89870e278f212cfc0102c9fda736dbed08d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from modules import scripts_postprocessing, ui_components, deepbooru, shared
import gradio as gr


class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing):
    """Postprocessing step that adds automatically generated captions to an image.

    The user enables the step and picks one or both captioners (the Deepbooru
    tagger and/or the BLIP interrogator). Their outputs are combined with any
    caption the image already carries into one comma-separated string.
    """

    # NOTE: the "Ception" spelling in the class name is intentionally left alone;
    # renaming the class could break anything that refers to it by name.
    name = "Caption"
    # presumably the position of this step among postprocessing scripts — confirm
    order = 4040

    def ui(self):
        """Build the collapsible "Caption" section and its captioner selector.

        Returns:
            A dict mapping argument names to Gradio components. ``enable`` is the
            accordion's on/off toggle. ``option`` is the checkbox group of
            captioners, with Deepbooru pre-selected.
        """
        with ui_components.InputAccordion(False, label="Caption") as enable:
            # The checkbox group must be created inside the accordion context
            # so that it is rendered within that section.
            option = gr.CheckboxGroup(
                value=["Deepbooru"],
                choices=["Deepbooru", "BLIP"],
                show_label=False,
            )

        return dict(enable=enable, option=option)

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
        """Run the selected captioners and merge their output into ``pp.caption``.

        Args:
            pp: The image being postprocessed. ``pp.image`` is read and
                ``pp.caption`` is overwritten with the merged result.
            enable: Accordion toggle; when falsy the image is left untouched.
            option: Collection of selected captioner names; membership of
                ``"Deepbooru"`` and ``"BLIP"`` is checked.
        """
        if not enable:
            return

        # Start from the existing caption so that generated text is appended
        # after it rather than replacing it.
        parts = [pp.caption]

        if "Deepbooru" in option:
            parts.append(deepbooru.model.tag(pp.image))

        if "BLIP" in option:
            # The image is converted to RGB before interrogation (presumably the
            # interrogator expects 3-channel input — confirm).
            rgb_image = pp.image.convert("RGB")
            parts.append(shared.interrogator.interrogate(rgb_image))

        # filter(None, ...) drops falsy entries (e.g. a missing/empty starting
        # caption) so that no stray ", " separators appear in the result.
        pp.caption = ", ".join(filter(None, parts))