# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility module to handle adversarial losses without requiring to mess up the main training loop.
"""

import typing as tp

import flashy
import torch
import torch.nn as nn
import torch.nn.functional as F


ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']


AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]


class AdversarialLoss(nn.Module):
    """Adversary training wrapper.

    Args:
        adversary (nn.Module): The adversary module, used to estimate the logits for fake and real samples.
            We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
            where the first item is a list of logits and the second item is a list of feature maps.
        optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
        loss (AdvLossType): Loss function for generator training.
        loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
        loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
        loss_feat (FeatLossType): Feature matching loss function for generator training.
        normalize (bool): Whether to normalize by number of sub-discriminators.

    Example of usage:
        adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
        for real in loader:
            noise = torch.randn(...)
            fake = model(noise)
            adv_loss.train_adv(fake, real)
            loss, _ = adv_loss(fake, real)
            loss.backward()
    """
    def __init__(self,
                 adversary: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 loss: AdvLossType,
                 loss_real: AdvLossType,
                 loss_fake: AdvLossType,
                 loss_feat: tp.Optional[FeatLossType] = None,
                 normalize: bool = True):
        super().__init__()
        self.adversary: nn.Module = adversary
        flashy.distrib.broadcast_model(self.adversary)
        self.optimizer = optimizer
        self.loss = loss
        self.loss_real = loss_real
        self.loss_fake = loss_fake
        self.loss_feat = loss_feat
        self.normalize = normalize

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Add the optimizer state dict inside our own.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'optimizer'] = self.optimizer.state_dict()
        return destination

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Load optimizer state.
        self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def get_adversary_pred(self, x):
        """Run adversary model, validating expected output format."""
        logits, fmaps = self.adversary(x)
        assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
            f'Expecting a list of tensors as logits but {type(logits)} found.'
        assert isinstance(fmaps, list), f'Expecting a list of feature maps but {type(fmaps)} found.'
        for fmap in fmaps:
            assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
                f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
        return logits, fmaps

    def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
        """Train the adversary with the given fake and real example.

        We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
        The first item is the list of logits and the second item is a list of feature maps for each sub-discriminator.

        This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
        and call the optimizer.
        """
        loss = torch.tensor(0., device=fake.device)
        all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
        all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
        n_sub_adversaries = len(all_logits_fake_is_fake)
        for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
            loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)

        if self.normalize:
            loss /= n_sub_adversaries

        self.optimizer.zero_grad()
        with flashy.distrib.eager_sync_model(self.adversary):
            loss.backward()
        self.optimizer.step()

        return loss

    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Return the loss for the generator, i.e. trying to fool the adversary,
        and feature matching loss if provided.
        """
        adv = torch.tensor(0., device=fake.device)
        feat = torch.tensor(0., device=fake.device)
        with flashy.utils.readonly(self.adversary):
            all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
            all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
            n_sub_adversaries = len(all_logits_fake_is_fake)
            for logit_fake_is_fake in all_logits_fake_is_fake:
                adv += self.loss(logit_fake_is_fake)
            if self.loss_feat:
                for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
                    feat += self.loss_feat(fmap_fake, fmap_real)

        if self.normalize:
            adv /= n_sub_adversaries
            feat /= n_sub_adversaries

        return adv, feat
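

# A minimal sketch of a full GAN step with feature matching, assuming a
# generator ``model`` and a weight ``feat_weight`` that are not part of this
# module (both are hypothetical names, for illustration only):
#
#     fake = model(noise)
#     d_loss = adv_loss.train_adv(fake, real)   # updates the adversary in-place
#     adv, feat = adv_loss(fake, real)          # generator-side losses, adversary frozen
#     g_loss = adv + feat_weight * feat
#     g_loss.backward()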


def get_adv_criterion(loss_type: str) -> tp.Callable:
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_loss
    elif loss_type == 'hinge':
        return hinge_loss
    elif loss_type == 'hinge2':
        return hinge2_loss
    raise ValueError('Unsupported loss')


def get_fake_criterion(loss_type: str) -> tp.Callable:
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_fake_loss
    elif loss_type in ['hinge', 'hinge2']:
        return hinge_fake_loss
    raise ValueError('Unsupported loss')


def get_real_criterion(loss_type: str) -> tp.Callable:
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_real_loss
    elif loss_type in ['hinge', 'hinge2']:
        return hinge_real_loss
    raise ValueError('Unsupported loss')
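

# A sketch of wiring the factory functions above into ``AdversarialLoss``.
# ``discriminator`` is a hypothetical module assumed to return the expected
# (logits, feature maps) tuple; everything else comes from this file:
#
#     optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
#     adv_loss = AdversarialLoss(
#         discriminator, optimizer,
#         loss=get_adv_criterion('hinge'),
#         loss_real=get_real_criterion('hinge'),
#         loss_fake=get_fake_criterion('hinge'),
#         loss_feat=FeatureMatchingLoss())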


def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
    # Least-squares discriminator loss on real samples: E[(x - 1)^2].
    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))


def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
    # Least-squares discriminator loss on fake samples: E[x^2].
    return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))


def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
    # Hinge discriminator loss on real samples: E[relu(1 - x)].
    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))


def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
    # Hinge discriminator loss on fake samples: E[relu(1 + x)].
    return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))


def mse_loss(x: torch.Tensor) -> torch.Tensor:
    # Least-squares generator loss: E[(x - 1)^2].
    if x.numel() == 0:
        return torch.tensor([0.0], device=x.device)
    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))


def hinge_loss(x: torch.Tensor) -> torch.Tensor:
    # Hinge generator loss: -E[x].
    if x.numel() == 0:
        return torch.tensor([0.0], device=x.device)
    return -x.mean()


def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
    # Saturating hinge generator loss: E[relu(1 - x)].
    if x.numel() == 0:
        # Keep the empty-input fallback on the same device as the input,
        # consistent with ``mse_loss`` and ``hinge_loss``.
        return torch.tensor([0.0], device=x.device)
    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
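

# Quick sanity check of the hinge conventions: the adversary is pushed to
# output logits above 1 on real samples and below -1 on fake samples.
#
#     >>> x = torch.tensor([2.0, -2.0])
#     >>> hinge_real_loss(x)   # only the -2.0 logit violates the margin: relu(1 - x).mean()
#     tensor(1.5000)
#     >>> hinge_fake_loss(x)   # only the 2.0 logit violates the margin: relu(1 + x).mean()
#     tensor(1.5000)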


class FeatureMatchingLoss(nn.Module):
    """Feature matching loss for adversarial training.

    Args:
        loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1Loss).
        normalize (bool): Whether to normalize the loss by number of feature maps.
    """
    def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
        super().__init__()
        self.loss = loss
        self.normalize = normalize

    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
        feat_scale = torch.tensor(0., device=fmap_fake[0].device)
        n_fmaps = 0
        for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
            assert feat_fake.shape == feat_real.shape
            n_fmaps += 1
            feat_loss += self.loss(feat_fake, feat_real)
            # Note: feat_scale is accumulated here but does not contribute to the returned loss.
            feat_scale += torch.mean(torch.abs(feat_real))

        if self.normalize:
            feat_loss /= n_fmaps

        return feat_loss
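

# A minimal usage sketch for ``FeatureMatchingLoss`` on its own, with two
# made-up feature maps per list (shapes are arbitrary, for illustration only):
#
#     feat_loss = FeatureMatchingLoss()
#     fmap_fake = [torch.randn(2, 8, 16), torch.randn(2, 4, 32)]
#     fmap_real = [torch.randn(2, 8, 16), torch.randn(2, 4, 32)]
#     loss = feat_loss(fmap_fake, fmap_real)   # scalar L1 distance, averaged over maps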