Spaces:
Runtime error
Runtime error
File size: 720 Bytes
f670afc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn as nn
class GaussianKLLoss(nn.Module):
    r"""Compute KL loss in VAE for Gaussian distributions.

    For a diagonal Gaussian with mean ``mu`` and log-variance ``logvar``,
    the per-element KL divergence to a standard normal is
    ``-0.5 * (1 + logvar - mu^2 - exp(logvar))``. This module returns the
    mean of that quantity over every element of the input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, mu, logvar=None):
        r"""Compute loss

        Args:
            mu (tensor): mean
            logvar (tensor): logarithm of variance. When ``None``, a zero
                tensor shaped like ``mu`` is used (i.e. unit variance).

        Returns:
            (tensor): scalar loss, averaged over all elements.
        """
        if logvar is None:
            # log(1) == 0, so a missing logvar means unit variance and the
            # loss reduces to 0.5 * mean(mu^2).
            logvar = torch.zeros_like(mu)
        # torch.mean without a dim argument reduces over every element.
        kl_per_element = 1 + logvar - mu.pow(2) - logvar.exp()
        return -0.5 * torch.mean(kl_per_element)
|