aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/dataset.py
blob: 7c44ea5be5a51d86fb990bf74306f067a7c9f1a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

import random
import tqdm
from modules import devices
import re

# Extracts prompt words from an image file name: each match is an ASCII letter
# followed by one or more word characters (letters, digits, underscore) or
# parentheses, so single-character fragments and runs that start with a digit
# are not matched from their first character.
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")


class PersonalizedBase(Dataset):
    """Dataset of pre-encoded image latents paired with templated prompt text.

    On construction every readable image in ``data_root`` is resized, scaled
    to the [-1, 1] range, encoded with ``model`` and kept on the CPU together
    with the tokens extracted from its file name. Each item returned by
    ``__getitem__`` is a ``(latent, text)`` pair where ``text`` is a randomly
    chosen template line with ``[name]`` and ``[filewords]`` substituted.
    """

    def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
        """Load the prompt templates and encode every image in ``data_root``.

        Args:
            data_root: Directory whose entries are opened as images. Entries
                that cannot be opened or decoded as images are skipped.
            size: Stored on the instance; not otherwise used by this class.
            repeats: How many times each image is counted in ``len(self)``.
            flip_p: Probability given to the ``RandomHorizontalFlip``
                transform stored as ``self.flip``.
            placeholder_token: Text substituted for ``[name]`` in templates.
            width: Width the images are resized to before encoding.
            height: Height the images are resized to before encoding.
            model: Object providing ``encode_first_stage`` and
                ``get_first_stage_encoding``; used to produce the latents.
            device: Device the image tensor is moved to before encoding.
            template_file: Path of a text file with one prompt template per
                line.

        Raises:
            AssertionError: If ``data_root`` is empty/None, or if no entry of
                ``data_root`` could be loaded as an image.
        """
        self.placeholder_token = placeholder_token

        self.size = size
        self.width = width
        self.height = height
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        with open(template_file, "r") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        assert data_root, 'dataset directory not specified'

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
        print("Preparing dataset...")
        skipped = 0
        for path in tqdm.tqdm(self.image_paths):
            # The directory may contain sub-directories or non-image files
            # (captions, thumbnails caches, ...); skip those instead of
            # aborting the whole preparation step.
            try:
                with Image.open(path) as source:
                    image = source.convert('RGB')
                    image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
            except OSError:
                skipped += 1
                continue

            filename = os.path.basename(path)
            filename_tokens = os.path.splitext(filename)[0]
            filename_tokens = re_tag.findall(filename_tokens)

            # Map 8-bit channel values 0..255 onto -1.0..1.0.
            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            # Move the channel axis to the front (H, W, C) -> (C, H, W).
            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
            torchdata = torch.moveaxis(torchdata, 2, 0)

            init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
            # Keep the cached latents on the CPU.
            init_latent = init_latent.to(devices.cpu)

            self.dataset.append((init_latent, filename_tokens))

        if skipped:
            print(f"Skipped {skipped} file(s) in {data_root} that could not be read as images")

        # Without this check an empty dataset would only surface later as a
        # ZeroDivisionError in __getitem__.
        assert len(self.dataset) > 0, 'no images have been found in the dataset directory'

        self.length = len(self.dataset) * repeats

        self.initial_indexes = np.arange(self.length) % len(self.dataset)
        self.indexes = None
        self.shuffle()

    def shuffle(self):
        """Store a fresh random permutation of ``initial_indexes`` in ``indexes``."""
        # Convert explicitly so the NumPy array is indexed by a NumPy array
        # rather than relying on implicit tensor-to-array conversion.
        permutation = torch.randperm(self.initial_indexes.shape[0]).numpy()
        self.indexes = self.initial_indexes[permutation]

    def __len__(self):
        """Return the number of images multiplied by ``repeats``."""
        return self.length

    def __getitem__(self, i):
        """Return the ``(latent, text)`` pair for position ``i``.

        The index order is reshuffled whenever ``i`` is a multiple of the
        number of distinct images.
        """
        if i % len(self.dataset) == 0:
            self.shuffle()

        index = self.indexes[i % len(self.indexes)]
        x, filename_tokens = self.dataset[index]

        text = random.choice(self.lines)
        text = text.replace("[name]", self.placeholder_token)
        text = text.replace("[filewords]", ' '.join(filename_tokens))

        return x, text