# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from imaginaire.utils.distributed import master_only_print as print


@torch.jit.script
def fuse_math_min_mean_pos(x):
    r"""Fused min-mean operation for the hinge loss on positive (real)
    samples."""
    minval = torch.min(x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss


@torch.jit.script
def fuse_math_min_mean_neg(x):
    r"""Fused min-mean operation for the hinge loss on negative (fake)
    samples."""
    minval = torch.min(-x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss
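
# Note: the fused helpers above implement the standard hinge objective:
# for real samples, E[max(0, 1 - x)] = -E[min(x - 1, 0)], and for fake
# samples, E[max(0, 1 + x)] = -E[min(-x - 1, 0)]. The `x * 0` term is a
# TorchScript-friendly way to build a zero tensor with x's shape and dtype.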


class GANLoss(nn.Module):
    r"""GAN loss constructor.

    Args:
        gan_mode (str): Type of GAN loss. ``'hinge'``, ``'least_square'``,
            ``'non_saturated'``, ``'wasserstein'``.
        target_real_label (float): The desired output label for real images.
        target_fake_label (float): The desired output label for fake images.
        decay_k (float): The per-epoch decay factor for top-k training.
        min_k (float): The minimum fraction of samples to keep.
        separate_topk (bool): If ``True``, selects top-k for each sample
            separately, otherwise selects top-k among all samples.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
                 decay_k=1., min_k=1., separate_topk=False):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.gan_mode = gan_mode
        self.decay_k = decay_k
        self.min_k = min_k
        self.separate_topk = separate_topk
        self.register_buffer('k', torch.tensor(1.0))
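        # Storing k as a buffer (rather than a plain Python float) keeps the
        # current top-k ratio in the module's state_dict, so annealing
        # resumes correctly when a checkpoint is reloaded.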
        print('GAN mode: %s' % gan_mode)

    def forward(self, dis_output, t_real, dis_update=True, reduce=True):
        r"""GAN loss computation.

        Args:
            dis_output (tensor or list of tensors): Discriminator outputs.
            t_real (bool): If ``True``, uses the real label as target,
                otherwise uses the fake label as target.
            dis_update (bool): If ``True``, the loss will be used to update
                the discriminator, otherwise the generator.
            reduce (bool): If ``True``, when a list of discriminator outputs
                is provided, returns the average of all losses, otherwise
                returns a list of losses.
        Returns:
            loss (tensor): Loss value.
        """
        if isinstance(dis_output, list):
            # For multi-scale discriminators.
            # In this implementation, the loss is first averaged for each
            # scale (batch size and number of locations), then averaged
            # across scales, so that the gradient is not dominated by the
            # discriminator with the most output values (highest resolution).
            losses = []
            for dis_output_i in dis_output:
                assert isinstance(dis_output_i, torch.Tensor)
                losses.append(self.loss(dis_output_i, t_real, dis_update))
            if reduce:
                return torch.mean(torch.stack(losses))
            else:
                return losses
        else:
            return self.loss(dis_output, t_real, dis_update)

    def loss(self, dis_output, t_real, dis_update=True):
        r"""GAN loss computation.

        Args:
            dis_output (tensor): Discriminator outputs.
            t_real (bool): If ``True``, uses the real label as target,
                otherwise uses the fake label as target.
            dis_update (bool): Updating the discriminator or the generator.
        Returns:
            loss (tensor): Loss value.
        """
        if not dis_update:
            assert t_real, \
                "The target should be real when updating the generator."

        if not dis_update and self.k < 1:
            r"""
            Use top-k training:
            "Top-k Training of GANs: Improving GAN Performance by Throwing
            Away Bad Samples"
            Here, each sample may have multiple discriminator output values
            (patch discriminator). We could either select top-k for each
            sample separately (when ``self.separate_topk=True``), or collect
            values from all samples and then select top-k (default, when
            ``self.separate_topk=False``).
            """
            if self.separate_topk:
                dis_output = dis_output.view(dis_output.size(0), -1)
            else:
                dis_output = dis_output.view(-1)
            k = math.ceil(self.k * dis_output.size(-1))
            dis_output, _ = torch.topk(dis_output, k)
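            # e.g., with a 4x(30x30)-logit patch discriminator and k=0.5:
            # separate_topk=True keeps the top 450 logits per sample (shape
            # 4x450), while separate_topk=False keeps the top 1800 logits
            # pooled across the whole batch.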

        if self.gan_mode == 'non_saturated':
            target_tensor = self.get_target_tensor(dis_output, t_real)
            loss = F.binary_cross_entropy_with_logits(dis_output,
                                                      target_tensor)
        elif self.gan_mode == 'least_square':
            target_tensor = self.get_target_tensor(dis_output, t_real)
            loss = 0.5 * F.mse_loss(dis_output, target_tensor)
        elif self.gan_mode == 'hinge':
            if dis_update:
                if t_real:
                    loss = fuse_math_min_mean_pos(dis_output)
                else:
                    loss = fuse_math_min_mean_neg(dis_output)
            else:
                loss = -torch.mean(dis_output)
        elif self.gan_mode == 'wasserstein':
            if t_real:
                loss = -torch.mean(dis_output)
            else:
                loss = torch.mean(dis_output)
        elif self.gan_mode == 'softplus':
            target_tensor = self.get_target_tensor(dis_output, t_real)
            loss = F.binary_cross_entropy_with_logits(dis_output,
                                                      target_tensor)
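            # This matches the softplus formulation exactly: BCE-with-logits
            # against target 1 is softplus(-x), and against target 0 is
            # softplus(x), so for 0/1 labels this branch coincides with
            # 'non_saturated'.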
        else:
            raise ValueError('Unexpected gan_mode {}'.format(self.gan_mode))
        return loss

    def get_target_tensor(self, dis_output, t_real):
        r"""Return the target tensor for the binary cross-entropy loss
        computation.

        Args:
            dis_output (tensor): Discriminator outputs.
            t_real (bool): If ``True``, uses the real label as target,
                otherwise uses the fake label as target.
        Returns:
            target (tensor): Target tensor.
        """
        if t_real:
            if self.real_label_tensor is None:
                self.real_label_tensor = dis_output.new_tensor(self.real_label)
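                # new_tensor() lazily creates the cached label on the same
                # device/dtype as dis_output; expand_as below broadcasts it
                # to the output shape without copying.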
            return self.real_label_tensor.expand_as(dis_output)
        else:
            if self.fake_label_tensor is None:
                self.fake_label_tensor = dis_output.new_tensor(self.fake_label)
            return self.fake_label_tensor.expand_as(dis_output)

    def topk_anneal(self):
        r"""Anneal k after each epoch."""
        if self.decay_k < 1:
            # noinspection PyAttributeOutsideInit
            self.k.fill_(max(self.decay_k * self.k, self.min_k))
            print("Top-k training: update k to {}.".format(self.k))