File size: 1,157 Bytes
4f5540c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import numpy as np

def NIG_NLL(y, gamma, v, alpha, beta, reduce=True):
    """Evaluate the Normal-Inverse-Gamma negative log-likelihood term.

    This is the data-fit part of an evidential regression loss.  With
    ``omega = 2 * beta * (1 + v)``, every element evaluates to::

        0.5 * log(pi / v)
        - alpha * log(omega)
        + (alpha + 0.5) * log(v * (y - gamma)**2 + omega)
        + lgamma(alpha)
        - lgamma(alpha + 0.5)

    Args:
        y: Values compared against ``gamma`` (presumably the regression
            targets -- confirm against the caller).
        gamma: Tensor subtracted from ``y`` to form the residual.
        v: Tensor; passed through ``torch.log`` as ``pi / v``.
        alpha: Tensor; passed to ``torch.lgamma``.
        beta: Tensor or scalar entering only through ``omega``.
        reduce: When truthy, the mean over all elements is returned;
            otherwise the element-wise values are returned unchanged.

    Returns:
        A scalar tensor (``reduce`` truthy) or the element-wise tensor.

    NOTE(review): no input validation happens here.  The logarithms and
    ``lgamma`` presumably need ``v``, ``alpha`` and ``beta`` to be strictly
    positive to stay finite -- confirm that the producing layer enforces it.
    """
    omega = 2 * beta * (1 + v)
    squared_residual = (y - gamma)**2

    # Individual summands of the log-likelihood, named for readability.
    log_scale = 0.5 * torch.log(np.pi/v)
    log_omega = alpha * torch.log(omega)
    log_fit = (alpha + 0.5) * torch.log(v * squared_residual + omega)
    lgamma_alpha = torch.lgamma(alpha)
    lgamma_alpha_half = torch.lgamma(alpha + 0.5)

    # Summed strictly left to right so rounding matches the one-expression form.
    per_element = log_scale - log_omega + log_fit + lgamma_alpha - lgamma_alpha_half

    if reduce:
        return torch.mean(per_element)
    return per_element

def NIG_reg(y, gamma, v, alpha, reduce = True, *_):
    """Evaluate the evidential-regression regularisation term.

    Each element is ``|y - gamma| * (2 * v + alpha)``: the absolute
    residual weighted by ``2 * v + alpha``.

    Args:
        y: Values compared against ``gamma``.
        gamma: Tensor subtracted from ``y`` to form the residual.
        v: Contributes ``2 * v`` to the residual weight.
        alpha: Contributes ``alpha`` to the residual weight.
        reduce: When truthy, the mean over all elements is returned;
            otherwise the element-wise values are returned.
        *_: Any further positional arguments are accepted and ignored.
            Because they come after ``reduce``, a fifth positional
            argument binds to ``reduce`` rather than being discarded.

    Returns:
        A scalar tensor (``reduce`` truthy) or the element-wise tensor.
    """
    weight = 2 * v + alpha
    penalty = torch.abs(y - gamma) * weight

    if reduce:
        return torch.mean(penalty)
    return penalty

def evidential_loss(y_true, output_dict, coef = 1.0, reduce = True):
    """Combine the NIG likelihood and regularisation terms into one loss.

    Args:
        y_true: Values forwarded as ``y`` to both ``NIG_NLL`` and
            ``NIG_reg``.
        output_dict: Mapping that must contain the keys ``'gamma'``,
            ``'v'``, ``'alpha'`` and ``'beta'``; a missing key raises
            ``KeyError``.
        coef: Multiplier applied to the regularisation term before it is
            added to the likelihood term.
        reduce: Forwarded to both terms; it decides whether each one is
            averaged or kept element-wise.

    Returns:
        ``NIG_NLL(...) + coef * NIG_reg(...)``.
    """
    # Keys are looked up in the same order as a sequence of plain lookups.
    gamma, v, alpha, beta = (
        output_dict[key] for key in ('gamma', 'v', 'alpha', 'beta')
    )

    nll_term = NIG_NLL(y_true, gamma, v, alpha, beta, reduce)
    # beta is not used by the regularisation term.
    reg_term = NIG_reg(y_true, gamma, v, alpha, reduce = reduce)

    return nll_term + coef * reg_term