Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| import torch | |
| import torch.nn as nn | |
| class GaussianKLLoss(nn.Module): | |
| r"""Compute KL loss in VAE for Gaussian distributions""" | |
| def __init__(self): | |
| super(GaussianKLLoss, self).__init__() | |
| def forward(self, mu, logvar=None): | |
| r"""Compute loss | |
| Args: | |
| mu (tensor): mean | |
| logvar (tensor): logarithm of variance | |
| """ | |
| if logvar is None: | |
| logvar = torch.zeros_like(mu) | |
| return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) | |