491eded
1
2
3
4
5
6
7
8
9
10
11
from torch import nn

from . import functional as F

__all__ = ["KLLoss"]


class KLLoss(nn.Module):
    """``nn.Module`` wrapper around this package's ``functional.kl_loss``.

    Defines no parameters, buffers, or ``__init__`` of its own; each call is
    forwarded unchanged to ``F.kl_loss``.
    """

    def forward(self, x, y):
        """Return ``F.kl_loss(x, y)``.

        Args:
            x: First operand, passed through to ``F.kl_loss`` as-is.
            y: Second operand, passed through to ``F.kl_loss`` as-is.

        Returns:
            Whatever ``F.kl_loss(x, y)`` returns.
        """
        # NOTE(review): which argument is the prediction vs. the target, and
        # whether inputs are expected in log-space, is decided by
        # functional.kl_loss (not visible here) — confirm there before use.
        loss = F.kl_loss(x, y)
        return loss