From 52cc83d36b7663a77b79fd2258d2ca871af73e55 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Wed, 30 Nov 2022 14:56:12 +0800 Subject: fix bugs Signed-off-by: zhaohu xing <920232796@qq.com> --- ldm/modules/losses/vqperceptual.py | 167 ------------------------------------- 1 file changed, 167 deletions(-) delete mode 100644 ldm/modules/losses/vqperceptual.py (limited to 'ldm/modules/losses/vqperceptual.py') diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py deleted file mode 100644 index f6998176..00000000 --- a/ldm/modules/losses/vqperceptual.py +++ /dev/null @@ -1,167 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F -from einops import repeat - -from taming.modules.discriminator.model import NLayerDiscriminator, weights_init -from taming.modules.losses.lpips import LPIPS -from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss - - -def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): - assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] - loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) - loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) - loss_real = (weights * loss_real).sum() / weights.sum() - loss_fake = (weights * loss_fake).sum() / weights.sum() - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss - -def adopt_weight(weight, global_step, threshold=0, value=0.): - if global_step < threshold: - weight = value - return weight - - -def measure_perplexity(predicted_indices, n_embed): - # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py - # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally - encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) - avg_probs = encodings.mean(0) - perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() - cluster_use = torch.sum(avg_probs > 0) - return perplexity, cluster_use - -def l1(x, y): - return torch.abs(x-y) - - -def l2(x, y): - return torch.pow((x-y), 2) - - -class VQLPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", - pixel_loss="l1"): - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - assert perceptual_loss in ["lpips", "clips", "dists"] - assert pixel_loss in ["l1", "l2"] - self.codebook_weight = codebook_weight - self.pixel_weight = pixelloss_weight - if perceptual_loss == "lpips": - print(f"{self.__class__.__name__}: Running with LPIPS.") - self.perceptual_loss = LPIPS().eval() - else: - raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") - self.perceptual_weight = perceptual_weight - - if pixel_loss == "l1": - self.pixel_loss = l1 - else: - self.pixel_loss = l2 - - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm, - ndf=disc_ndf - ).apply(weights_init) - self.discriminator_iter_start = disc_start - if disc_loss == "hinge": - self.disc_loss = hinge_d_loss - elif disc_loss == "vanilla": - self.disc_loss = vanilla_d_loss - else: - raise ValueError(f"Unknown GAN loss '{disc_loss}'.") - print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - self.n_classes = n_classes - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", predicted_indices=None): - if not exists(codebook_loss): - codebook_loss = torch.tensor([0.]).to(inputs.device) - #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) - rec_loss = rec_loss + self.perceptual_weight * p_loss - else: - p_loss = torch.tensor([0.0]) - - nll_loss = rec_loss - #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - nll_loss = torch.mean(nll_loss) - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) - g_loss = -torch.mean(logits_fake) - - try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) - except RuntimeError: - assert not self.training - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() - - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/quant_loss".format(split): codebook_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/p_loss".format(split): p_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - if predicted_indices is not None: - assert self.n_classes is not None - with torch.no_grad(): - perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) - log[f"{split}/perplexity"] = perplexity - log[f"{split}/cluster_usage"] = cluster_usage - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } - return d_loss, log -- cgit v1.2.1