From 21642000b33a3069e3408ea1a50239006176badb Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Thu, 12 Jan 2023 15:29:19 +0100
Subject: Add PNG alpha channel as weight maps to data entries

---
 modules/textual_inversion/dataset.py | 51 +++++++++++++++++++++++++++---------
 1 file changed, 38 insertions(+), 13 deletions(-)

(limited to 'modules/textual_inversion')

diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index d31963d4..f4ce4552 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
 
 
 class DatasetEntry:
-    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
+    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
         self.filename = filename
         self.filename_text = filename_text
+        self.weight = weight
         self.latent_dist = latent_dist
         self.latent_sample = latent_sample
         self.cond = cond
@@ -56,10 +57,16 @@ class PersonalizedBase(Dataset):
 
         print("Preparing dataset...")
         for path in tqdm.tqdm(self.image_paths):
+            alpha_channel = None
             if shared.state.interrupted:
                 raise Exception("interrupted")
             try:
-                image = Image.open(path).convert('RGB')
+                image = Image.open(path)
+                #Currently does not work for single color transparency
+                #We would need to read image.info['transparency'] for that
+                if 'A' in image.getbands():
+                    alpha_channel = image.getchannel('A')
+                image = image.convert('RGB')
                 if not varsize:
                     image = image.resize((width, height), PIL.Image.BICUBIC)
             except Exception:
@@ -87,17 +94,33 @@ class PersonalizedBase(Dataset):
             with devices.autocast():
                 latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
 
-            if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
-                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
-                latent_sampling_method = "once"
-                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
-            elif latent_sampling_method == "deterministic":
-                # Works only for DiagonalGaussianDistribution
-                latent_dist.std = 0
-                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
-                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
-            elif latent_sampling_method == "random":
-                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
+            #Perform latent sampling, even for random sampling.
+            #We need the sample dimensions for the weights
+            if latent_sampling_method == "deterministic":
+                if isinstance(latent_dist, DiagonalGaussianDistribution):
+                    # Works only for DiagonalGaussianDistribution
+                    latent_dist.std = 0
+                else:
+                    latent_sampling_method = "once"
+            latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+
+            if alpha_channel is not None:
+                channels, *latent_size = latent_sample.shape
+                weight_img = alpha_channel.resize(latent_size)
+                npweight = np.array(weight_img).astype(np.float32)
+                #Repeat for every channel in the latent sample
+                weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
+                #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
+                weight -= weight.min()
+                weight /= weight.mean()
+            else:
+                #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
+                weight = torch.ones([channels] + latent_size)
+
+            if latent_sampling_method == "random":
+                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
+            else:
+                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
 
             if not (self.tag_drop_out != 0 or self.shuffle_tags):
                 entry.cond_text = self.create_text(filename_text)
@@ -110,6 +133,7 @@ class PersonalizedBase(Dataset):
             del torchdata
             del latent_dist
             del latent_sample
+            del weight
 
         self.length = len(self.dataset)
         self.groups = list(groups.values())
@@ -195,6 +219,7 @@ class BatchLoader:
         self.cond_text = [entry.cond_text for entry in data]
         self.cond = [entry.cond for entry in data]
         self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+        self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
         #self.emb_index = [entry.emb_index for entry in data]
         #print(self.latent_sample.device)
--
cgit v1.2.1

From bc50936745e1a349afdc28cf1540109ba20bc71a Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Thu, 12 Jan 2023 15:34:11 +0100
Subject: Call weighted_forward during training

---
 modules/textual_inversion/textual_inversion.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

(limited to 'modules/textual_inversion')

diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index a1a406c2..8853c868 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -480,6 +480,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 
             with devices.autocast():
                 x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                w = batch.weight.to(devices.device, non_blocking=pin_memory)
                 c = shared.sd_model.cond_stage_model(batch.cond_text)
 
                 if is_training_inpainting_model:
@@ -490,7 +491,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 else:
                     cond = c
 
-                loss = shared.sd_model(x, cond)[0] / gradient_step
+                loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
                 del x
 
                 _loss_step += loss.item()
--
cgit v1.2.1

From edb10092de516dda5271130ed53628387780a859 Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Thu, 12 Jan 2023 16:29:00 +0100
Subject: Add ability to choose using weighted loss or not

---
 modules/textual_inversion/dataset.py           | 15 ++++++++++-----
 modules/textual_inversion/textual_inversion.py | 13 +++++++++----
 2 files changed, 19 insertions(+), 9 deletions(-)

(limited to 'modules/textual_inversion')

diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index f4ce4552..1568b2b8 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -31,7 +31,7 @@ class DatasetEntry:
 
 
 class PersonalizedBase(Dataset):
-    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
+    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
         re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
 
         self.placeholder_token = placeholder_token
@@ -64,7 +64,7 @@ class PersonalizedBase(Dataset):
                 image = Image.open(path)
                 #Currently does not work for single color transparency
                 #We would need to read image.info['transparency'] for that
-                if 'A' in image.getbands():
+                if use_weight and 'A' in image.getbands():
                     alpha_channel = image.getchannel('A')
                 image = image.convert('RGB')
                 if not varsize:
@@ -104,7 +104,7 @@ class PersonalizedBase(Dataset):
                     latent_sampling_method = "once"
             latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
 
-            if alpha_channel is not None:
+            if use_weight and alpha_channel is not None:
                 channels, *latent_size = latent_sample.shape
                 weight_img = alpha_channel.resize(latent_size)
                 npweight = np.array(weight_img).astype(np.float32)
@@ -113,9 +113,11 @@ class PersonalizedBase(Dataset):
                 #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
                 weight -= weight.min()
                 weight /= weight.mean()
-            else:
+            elif use_weight:
                 #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
                 weight = torch.ones([channels] + latent_size)
+            else:
+                weight = None
 
             if latent_sampling_method == "random":
                 entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
@@ -219,7 +221,10 @@ class BatchLoader:
         self.cond_text = [entry.cond_text for entry in data]
         self.cond = [entry.cond for entry in data]
         self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
-        self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
+        if all(entry.weight is not None for entry in data):
+            self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
+        else:
+            self.weight = None
         #self.emb_index = [entry.emb_index for entry in data]
         #print(self.latent_sample.device)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 8853c868..c63c7d1d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -351,7 +351,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
         assert log_directory, "Log directory is empty"
 
 
-def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
     template_file = textual_inversion_templates.get(template_filename, None)
@@ -410,7 +410,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 
     pin_memory = shared.opts.pin_memory
 
-    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
+    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
 
     if shared.opts.save_training_settings_to_txt:
         save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
@@ -480,7 +480,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 
             with devices.autocast():
                 x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
-                w = batch.weight.to(devices.device, non_blocking=pin_memory)
+                if use_weight:
+                    w = batch.weight.to(devices.device, non_blocking=pin_memory)
                 c = shared.sd_model.cond_stage_model(batch.cond_text)
 
                 if is_training_inpainting_model:
@@ -491,7 +492,11 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 else:
                     cond = c
 
-                loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
+                if use_weight:
+                    loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
+                    del w
+                else:
+                    loss = shared.sd_model.forward(x, cond)[0] / gradient_step
                 del x
 
                 _loss_step += loss.item()
--
cgit v1.2.1

From 11183b4d905d14c6a0164a4d13675b89b1bf4ceb Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 19 Feb 2023 12:44:56 +0300
Subject: fix for #6700

---
 modules/textual_inversion/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules/textual_inversion')

diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 1568b2b8..af9fbcf2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -115,7 +115,7 @@ class PersonalizedBase(Dataset):
                 weight /= weight.mean()
             elif use_weight:
                 #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
-                weight = torch.ones([channels] + latent_size)
+                weight = torch.ones(latent_sample.shape)
             else:
                 weight = None
 
--
cgit v1.2.1
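
Note on weighted_forward: the method called in the training loop above is added outside modules/textual_inversion, so its definition is filtered out of this log. The snippet below is only a minimal sketch of the weighting idea, not the repository's implementation; weighted_loss is a hypothetical name, and it assumes an element-wise reconstruction loss combined with the min-0/mean-1 weight maps built in dataset.py:

    import torch

    def weighted_loss(pred, target, weight=None):
        # Element-wise squared error between model prediction and target.
        loss = (pred - target) ** 2
        if weight is not None:
            # Scale each latent element by its weight map. Because dataset.py
            # normalizes the map to a mean of 1, the weighted mean stays on
            # the same scale as the default, unweighted loss.
            loss = loss * weight
        return loss.mean()

Images without an alpha channel receive a constant weight map of ones (or None once use_weight is introduced and disabled), so their contribution is identical to the unweighted default; transparent regions of an image contribute proportionally less to the loss.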