|
import torch |
|
import numpy as np |
|
|
|
def NIG_NLL(y, gamma, v, alpha, beta, reduce=True): |
|
''' |
|
Negative log loss for Normal-Inverse Gamma distribution |
|
|
|
For learning uncertainties with evidential regression |
|
''' |
|
|
|
O = 2 * beta * (1 + v) |
|
|
|
nll = 0.5 * torch.log(np.pi/v) \ |
|
- alpha * torch.log(O) \ |
|
+ (alpha + 0.5) * (torch.log(v * (y - gamma)**2 + O)) \ |
|
+ torch.lgamma(alpha) \ |
|
- torch.lgamma(alpha + 0.5) |
|
|
|
return torch.mean(nll) if reduce else nll |
|
|
|
def NIG_reg(y, gamma, v, alpha, reduce = True, *_): |
|
''' |
|
Computes regularization loss for evidential regression |
|
''' |
|
Phi = (2 * v + alpha) |
|
L = (torch.abs(y - gamma) * Phi) |
|
return torch.mean(L) if reduce else L |
|
|
|
def evidential_loss(y_true, output_dict, coef = 1.0, reduce = True): |
|
''' |
|
Entire loss function for evidential regression |
|
''' |
|
|
|
gamma = output_dict['gamma'] |
|
v = output_dict['v'] |
|
alpha = output_dict['alpha'] |
|
beta = output_dict['beta'] |
|
|
|
loss_nll = NIG_NLL(y_true, gamma, v, alpha, beta, reduce) |
|
loss_reg = NIG_reg(y_true, gamma, v, alpha, reduce = reduce) |
|
|
|
return loss_nll + coef * loss_reg |