aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r--modules/textual_inversion/dataset.py58
1 files changed, 44 insertions, 14 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index d31963d4..af9fbcf2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
- def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
+ def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
self.filename = filename
self.filename_text = filename_text
+ self.weight = weight
self.latent_dist = latent_dist
self.latent_sample = latent_sample
self.cond = cond
@@ -30,7 +31,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
@@ -56,10 +57,16 @@ class PersonalizedBase(Dataset):
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
+ alpha_channel = None
if shared.state.interrupted:
raise Exception("interrupted")
try:
- image = Image.open(path).convert('RGB')
+ image = Image.open(path)
+ #Currently does not work for single-color transparency
+ #We would need to read image.info['transparency'] for that
+ if use_weight and 'A' in image.getbands():
+ alpha_channel = image.getchannel('A')
+ image = image.convert('RGB')
if not varsize:
image = image.resize((width, height), PIL.Image.BICUBIC)
except Exception:
@@ -87,17 +94,35 @@ class PersonalizedBase(Dataset):
with devices.autocast():
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
- if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
- latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
- latent_sampling_method = "once"
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
- elif latent_sampling_method == "deterministic":
- # Works only for DiagonalGaussianDistribution
- latent_dist.std = 0
- latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
- elif latent_sampling_method == "random":
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
+ #Perform latent sampling, even for random sampling.
+ #We need the sample dimensions for the weights
+ if latent_sampling_method == "deterministic":
+ if isinstance(latent_dist, DiagonalGaussianDistribution):
+ # Works only for DiagonalGaussianDistribution
+ latent_dist.std = 0
+ else:
+ latent_sampling_method = "once"
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+
+ if use_weight and alpha_channel is not None:
+ channels, *latent_size = latent_sample.shape
+ weight_img = alpha_channel.resize(latent_size)
+ npweight = np.array(weight_img).astype(np.float32)
+ #Repeat for every channel in the latent sample
+ weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
+ #Normalize the weight to a minimum of 0 and a mean of 1 so that the loss stays comparable to the default.
+ weight -= weight.min()
+ weight /= weight.mean()
+ elif use_weight:
+ #If an image does not have an alpha channel, add an all-ones weight map anyway so we can stack it later
+ weight = torch.ones(latent_sample.shape)
+ else:
+ weight = None
+
+ if latent_sampling_method == "random":
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
+ else:
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
@@ -110,6 +135,7 @@ class PersonalizedBase(Dataset):
del torchdata
del latent_dist
del latent_sample
+ del weight
self.length = len(self.dataset)
self.groups = list(groups.values())
@@ -195,6 +221,10 @@ class BatchLoader:
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+ if all(entry.weight is not None for entry in data):
+ self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
+ else:
+ self.weight = None
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)