aboutsummaryrefslogtreecommitdiff
path: root/modules/aesthetic_clip.py
blob: ccb35c73fd49c93f85edcd8d8b7fd285059c461b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import itertools
import os
from pathlib import Path
import html
import gc

import gradio as gr
import torch
from PIL import Image
from modules import shared
from modules.shared import device
from transformers import CLIPModel, CLIPProcessor

from tqdm.auto import tqdm


def get_all_images_in_folder(folder):
    """Return the paths of the image files found directly inside *folder*.

    Entries are visited in the order ``os.listdir`` reports them. An entry is
    kept when ``os.path.isfile`` accepts its joined path and its name passes
    ``check_is_valid_image_file``. Subdirectories are not descended into.
    """
    image_paths = []
    for entry in os.listdir(folder):
        full_path = os.path.join(folder, entry)
        if not os.path.isfile(full_path):
            continue
        if check_is_valid_image_file(entry):
            image_paths.append(full_path)
    return image_paths


def check_is_valid_image_file(filename):
    """Return True when *filename* ends with one of the accepted image suffixes.

    The name is lower-cased before the comparison, so the check is
    case-insensitive. Only the name is inspected; the file is never opened.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp")
    lowered = filename.lower()
    return any(lowered.endswith(suffix) for suffix in image_suffixes)


def batched(dataset, total, n=1):
    """Yield lists of up to *n* consecutive items taken from *dataset* by index.

    Indices ``0 .. total - 1`` are walked in steps of *n*; each item is fetched
    with ``dataset.__getitem__(index)``. The final list is shorter when *total*
    is not a multiple of *n*.
    """
    for start in range(0, total, n):
        stop = min(start + n, total)
        batch = []
        for index in range(start, stop):
            batch.append(dataset.__getitem__(index))
        yield batch


def iter_to_batched(iterable, n=1):
    """Yield tuples of up to *n* consecutive items drawn from *iterable*.

    The input is consumed lazily. Iteration stops at the first empty chunk,
    so the last tuple yielded may hold fewer than *n* items.
    """
    source = iter(iterable)
    # Two-argument iter(): keep calling the lambda until it returns the
    # sentinel, i.e. an empty tuple.
    yield from iter(lambda: tuple(itertools.islice(source, n)), ())


def _load_images(paths):
    """Load every image in *paths* into memory and close its underlying file."""
    images = []
    for path in paths:
        # Image.open is lazy and keeps the file handle open. Copying inside the
        # context manager forces the pixel data to be read, after which the
        # handle is released instead of leaking until garbage collection.
        with Image.open(path) as img:
            images.append(img.copy())
    return images


def generate_imgs_embd(name, folder, batch_size):
    """Average the CLIP image features of the images in *folder* and save them.

    The CLIP model and processor are loaded from the ``name_or_path`` of the
    currently loaded SD model's ``clipModel``. Images are processed in groups
    of *batch_size*; the per-image features are concatenated and averaged into
    a single row, which is saved as ``<name>.pt`` inside
    ``shared.cmd_opts.aesthetic_embeddings_dir``.

    If ``shared.state.interrupted`` is set during the run, processing stops and
    the average of the batches completed so far is saved. If no batch was
    processed at all (no valid images, or interrupted before the first batch),
    nothing is saved and the returned message says so.

    Args:
        name: Name of the embedding; used as the output file stem.
        folder: Directory whose image files are embedded (non-recursive).
        batch_size: Number of images passed to the processor/model per step.

    Returns:
        A 3-tuple ``(dropdown_update, message, "")`` where ``dropdown_update``
        is a ``gr.Dropdown.update`` listing the known aesthetic embeddings.
    """
    model_path = shared.sd_model.cond_stage_model.clipModel.name_or_path
    model = CLIPModel.from_pretrained(model_path).to(device)
    saved_path = None

    try:
        processor = CLIPProcessor.from_pretrained(model_path)

        embs = []
        with torch.no_grad():
            for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
                              desc=f"Generating embeddings for {name}"):
                if shared.state.interrupted:
                    break
                inputs = processor(images=_load_images(paths), return_tensors="pt").to(device)
                # .cpu() already yields a tensor detached from the device, so
                # no extra clone is needed before dropping the device inputs.
                embs.append(model.get_image_features(**inputs).cpu())
                del inputs

        # torch.cat raises on an empty list, so only reduce and save when at
        # least one batch was actually embedded.
        if embs:
            mean_emb = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)

            # The generated embedding will be located here
            saved_path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
            torch.save(mean_emb, saved_path)
            del mean_emb
        del embs
    finally:
        # Release the model even when loading or embedding fails, so a broken
        # image does not leave the CLIP weights resident on the device.
        model.cpu()
        model = processor = None
        gc.collect()
        torch.cuda.empty_cache()

    # The path is HTML-escaped, so the message is presumably rendered as HTML;
    # the user-supplied name and folder are escaped for the same reason.
    if saved_path is None:
        res = f"""
        No embedding was generated for {html.escape(name)}:
        no images were processed from {html.escape(folder)}
        """
    else:
        res = f"""
        Done generating embedding for {html.escape(name)}!
        Aesthetic embedding saved to {html.escape(saved_path)}
        """
        shared.update_aesthetic_embeddings()

    return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
                              value="None"), res, ""