aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/dataset.py
blob: f61f40d30d1c5ce54a1d8da542f5504d6c296b27 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

import random
import tqdm
from modules import devices, shared
import re

# Matches tag-like tokens: one ASCII letter followed by at least one more
# character that is a word character (letter, digit, underscore) or a
# parenthesis.  A lone letter therefore never matches.  Applied with
# ``findall`` to an image's file name (extension stripped) to collect the
# words substituted for ``[filewords]`` in prompt templates.
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")


class PersonalizedBase(Dataset):
    """Dataset of pre-encoded image latents paired with templated prompts.

    Every entry of ``data_root`` that can be opened as an image is resized,
    encoded once with ``model``'s first stage and kept on the CPU.  Prompts
    are generated on demand by filling a randomly chosen template line with
    the placeholder token and the tags parsed from the image's file name.
    """

    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
        """Read the prompt templates and encode every image in ``data_root``.

        Args:
            data_root: Directory whose entries are tried as images.
            width: Width every image is resized to before encoding.
            height: Height every image is resized to before encoding.
            repeats: How many times each loaded image counts towards
                ``len(self)``.
            flip_p: Probability passed to ``RandomHorizontalFlip``.  The
                transform is stored as ``self.flip``; this class does not
                apply it itself.
            placeholder_token: Text substituted for ``[name]`` in templates.
            model: Object providing ``encode_first_stage`` and
                ``get_first_stage_encoding``; used to encode each image.
            device: Device the image tensor is moved to before encoding.
            template_file: Path of a text file with one prompt template per
                line.
            include_cond: When true, the output of
                ``shared.sd_model.cond_stage_model`` for one generated prompt
                is stored alongside each image; otherwise ``None`` is stored.

        Raises:
            AssertionError: If ``data_root`` is not given, or if none of its
                entries could be loaded as an image.
        """
        # Validate the cheap precondition before touching the file system.
        assert data_root, 'dataset directory not specified'

        self.placeholder_token = placeholder_token

        self.width = width
        self.height = height
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        with open(template_file, "r") as file:
            self.lines = [line.strip() for line in file]

        # Only needed (and only dereferenced) when conditioning is requested.
        cond_model = shared.sd_model.cond_stage_model if include_cond else None

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            try:
                # The context manager releases the file handle as soon as the
                # pixel data has been converted.
                with Image.open(path) as source:
                    image = source.convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
            except Exception:
                # Best effort: entries that cannot be read as images are
                # skipped instead of aborting dataset preparation.
                continue

            filename = os.path.basename(path)
            filename_tokens = re_tag.findall(os.path.splitext(filename)[0])

            # Map 8-bit pixel values from [0, 255] to [-1.0, 1.0].
            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            # Move the last axis (index 2) of the image array to the front.
            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
            torchdata = torch.moveaxis(torchdata, 2, 0)

            encoded = model.encode_first_stage(torchdata.unsqueeze(dim=0))
            init_latent = model.get_first_stage_encoding(encoded).squeeze()
            init_latent = init_latent.to(devices.cpu)

            if include_cond:
                text = self.create_text(filename_tokens)
                cond = cond_model([text]).to(devices.cpu)
            else:
                cond = None

            self.dataset.append((init_latent, filename_tokens, cond))

        # Without this check an empty dataset is created silently and the
        # first __getitem__ call fails with an opaque ZeroDivisionError.
        assert self.dataset, 'no images have been found in the dataset directory'

        self.length = len(self.dataset) * repeats

        # Each image index appears ``repeats`` times: 0, 1, ..., n-1, 0, 1, ...
        self.initial_indexes = np.arange(self.length) % len(self.dataset)
        self.indexes = None
        self.shuffle()

    def shuffle(self):
        """Store a fresh random permutation of ``initial_indexes`` in ``indexes``."""
        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]

    def create_text(self, filename_tokens):
        """Build a prompt from a randomly chosen template line.

        ``[name]`` is replaced by the placeholder token and ``[filewords]``
        by ``filename_tokens`` joined with single spaces.
        """
        text = random.choice(self.lines)
        text = text.replace("[name]", self.placeholder_token)
        text = text.replace("[filewords]", ' '.join(filename_tokens))
        return text

    def __len__(self):
        """Return the number of loaded images multiplied by ``repeats``."""
        return self.length

    def __getitem__(self, i):
        """Return ``(latent, prompt, cond)`` for position ``i``.

        ``i`` wraps around modulo the number of indexes, so no position is
        out of range.  The order is reshuffled whenever ``i`` is a multiple
        of the number of loaded images, and a new prompt is generated on
        every call.
        """
        if i % len(self.dataset) == 0:
            self.shuffle()

        index = self.indexes[i % len(self.indexes)]
        latent, filename_tokens, cond = self.dataset[index]

        return latent, self.create_text(filename_tokens), cond