aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/dataset.py
diff options
context:
space:
mode:
authorShondoit <shondoit@gmail.com>2023-01-12 16:29:00 +0100
committerShondoit <shondoit@gmail.com>2023-02-15 10:03:59 +0100
commitedb10092de516dda5271130ed53628387780a859 (patch)
treed7f78d9a5b238b3836f9811327a4160b59ef14eb /modules/textual_inversion/dataset.py
parentbc50936745e1a349afdc28cf1540109ba20bc71a (diff)
Add ability to choose using weighted loss or not
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r--modules/textual_inversion/dataset.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index f4ce4552..1568b2b8 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -31,7 +31,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
@@ -64,7 +64,7 @@ class PersonalizedBase(Dataset):
image = Image.open(path)
#Currently does not work for single color transparency
#We would need to read image.info['transparency'] for that
- if 'A' in image.getbands():
+ if use_weight and 'A' in image.getbands():
alpha_channel = image.getchannel('A')
image = image.convert('RGB')
if not varsize:
@@ -104,7 +104,7 @@ class PersonalizedBase(Dataset):
latent_sampling_method = "once"
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
- if alpha_channel is not None:
+ if use_weight and alpha_channel is not None:
channels, *latent_size = latent_sample.shape
weight_img = alpha_channel.resize(latent_size)
npweight = np.array(weight_img).astype(np.float32)
@@ -113,9 +113,11 @@ class PersonalizedBase(Dataset):
#Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
weight -= weight.min()
weight /= weight.mean()
- else:
+ elif use_weight:
#If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
weight = torch.ones([channels] + latent_size)
+ else:
+ weight = None
if latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
@@ -219,7 +221,10 @@ class BatchLoader:
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
- self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
+ if all(entry.weight is not None for entry in data):
+ self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
+ else:
+ self.weight = None
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)