# bndl's picture
# Upload 115 files
# 4f5540c
import torch
import numpy as np
def NIG_NLL(y, gamma, v, alpha, beta, reduce=True):
    """Negative log-likelihood under a Normal-Inverse-Gamma (NIG) distribution.

    Used for learning uncertainties with evidential regression. Elementwise it
    evaluates

        0.5 * log(pi / v)
        - alpha * log(Omega)
        + (alpha + 0.5) * log(v * (y - gamma)**2 + Omega)
        + lgamma(alpha) - lgamma(alpha + 0.5)

    with Omega = 2 * beta * (1 + v).

    Args:
        y: Regression targets.
        gamma: Predicted mean.
        v, alpha, beta: Remaining NIG parameters. The log terms need
            v > 0 and beta * (1 + v) > 0. Presumably the producing
            network enforces this; confirm with the caller.
        reduce: If True, return the mean over all elements. Otherwise
            return the elementwise NLL.

    NOTE(review): expects torch tensors (torch.log / torch.lgamma are
    applied directly) with broadcast-compatible shapes. Confirm with callers.
    """
    omega = 2 * beta * (1 + v)

    # Individual terms of the NIG NLL. They are combined below in the same
    # left-to-right order as the formula in the docstring.
    scale_term = 0.5 * torch.log(np.pi / v)
    omega_term = alpha * torch.log(omega)
    residual_term = (alpha + 0.5) * torch.log(v * (y - gamma) ** 2 + omega)

    per_element = (
        scale_term
        - omega_term
        + residual_term
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + 0.5)
    )

    if reduce:
        return torch.mean(per_element)
    return per_element
def NIG_reg(y, gamma, v, alpha, reduce = True, *_):
    """Evidence regulariser for evidential regression.

    Computes |y - gamma| * (2 * v + alpha) elementwise. Prediction error is
    penalised in proportion to the evidence term built from v and alpha.

    Args:
        y: Regression targets.
        gamma: Predicted mean.
        v, alpha: Evidence parameters of the NIG output.
        reduce: If True, return the mean over all elements. Otherwise
            return the elementwise penalty.
        *_: Any extra positional arguments (e.g. beta) are accepted and
            ignored.
    """
    evidence = 2 * v + alpha
    penalty = torch.abs(y - gamma) * evidence

    if reduce:
        return torch.mean(penalty)
    return penalty
def evidential_loss(y_true, output_dict, coef = 1.0, reduce = True):
    """Full evidential-regression loss: NIG NLL plus a weighted regulariser.

    Args:
        y_true: Regression targets.
        output_dict: Mapping that must contain the keys 'gamma', 'v',
            'alpha' and 'beta'. A missing key raises KeyError.
        coef: Weight applied to the regularisation term.
        reduce: Passed through to both terms. If True, each term is
            averaged. Otherwise elementwise values are summed.

    Returns:
        NIG_NLL(...) + coef * NIG_reg(...).
    """
    # Read the four NIG parameters in a fixed key order.
    gamma, v, alpha, beta = (
        output_dict[key] for key in ('gamma', 'v', 'alpha', 'beta')
    )

    nll_term = NIG_NLL(y_true, gamma, v, alpha, beta, reduce)
    reg_term = NIG_reg(y_true, gamma, v, alpha, reduce = reduce)
    return nll_term + coef * reg_term