import torch
from torch_scatter import scatter_max
import dgl


def infonet_updated(g, qmin, xj, bj):
    """InfoNCE-style contrastive loss over a batched DGL graph.

    Per graph, the highest-charge node of each particle (with charge
    q = arctanh(beta)^2 + qmin) serves as the positive anchor; up to 50
    nodes from every other particle in the same graph serve as negatives.

    Returns (combined loss, beta loss, beta-zero loss, contrastive loss).
    """
    list_graphs = dgl.unbatch(g)
    loss_total = 0
    Loss_beta_zero = 0
    number_particles_accounted_for = 0
    node_counter = 0
    Loss_beta = 0
    for graph_eval in list_graphs:
        non = graph_eval.number_of_nodes()
        # Slice this graph's nodes out of the batched tensors.
        xj_graph = xj[node_counter : node_counter + non]
        bj_graph = bj[node_counter : node_counter + non]
        node_counter += non  # advance into the batched node tensors
        # Node charge; arctanh sends beta -> 1 to unbounded charge.
        q_graph = bj_graph.arctanh() ** 2 + qmin
        part_num = graph_eval.ndata["particle_number"].view(-1).to(torch.long)
        # Condensation anchor: the highest-charge node of each particle
        # (particle numbers are 1-indexed, hence the shift).
        q_alpha, index_alpha = scatter_max(q_graph.view(-1), part_num - 1)
        x_alpha = xj_graph[index_alpha]
        number_of_particles = torch.unique(graph_eval.ndata["particle_number"])
        # Up to 50 candidate negative indices per particle.
        indx = torch.zeros((len(number_of_particles), 50)).to(x_alpha.device)
        b_alpha = bj_graph[index_alpha]
        beta_zero_loss = torch.sum(bj_graph) / non
        if len(number_of_particles) > 1:
            # First pass: collect up to 50 node indices per particle,
            # padding short particles by repeating their first node.
            for nn in range(len(number_of_particles)):
                idx_part = number_of_particles[nn]
                positives_of_class = graph_eval.ndata["particle_number"] == idx_part
                pos_indx = torch.where(positives_of_class)[0]
                if len(pos_indx) > 50:
                    indx[nn, :] = pos_indx[:50]
                else:
                    indx[nn, : len(pos_indx)] = pos_indx
                    indx[nn, len(pos_indx) :] = pos_indx[0]
            # Second pass: one InfoNCE term per particle.
            for nn in range(len(number_of_particles)):
                idx_part = number_of_particles[nn]
                positives_of_class = graph_eval.ndata["particle_number"] == idx_part
                xj_ = xj_graph[positives_of_class]
                x_alpha_ = x_alpha[nn]
                # Similarity of every node of the particle to its anchor.
                dot_products = (xj_ * x_alpha_.unsqueeze(0)).sum(dim=1)
                dot_products_exp = torch.exp(dot_products)
                # Negatives: the stored indices of all other particles.
                indx_copy = indx.clone()
                indx_copy[nn] = -1
                indx_copy = indx_copy.view(-1)
                neg_indx = indx_copy[indx_copy > -1]
                xj_neg = xj_graph[neg_indx.long()]
                # Sum of exp-similarities to the negatives, per positive node.
                dot_neg = torch.sum(
                    torch.exp(torch.tensordot(xj_, xj_neg, dims=([1], [1]))), dim=1
                )
                loss = torch.mean(-torch.log(dot_products_exp / dot_neg))
                if torch.isnan(loss).any():
                    # Debug output for NaNs in the contrastive term.
                    print(dot_products_exp)
                    print(dot_neg)
                    print(dot_products_exp / dot_neg)
                    print(torch.log(dot_products_exp / dot_neg))
                loss_total = loss_total + loss
                number_particles_accounted_for += 1
        # Attractive beta loss on the anchors, plus overall beta suppression.
        Loss_beta = Loss_beta + torch.mean(1 - b_alpha)
        Loss_beta_zero = Loss_beta_zero + beta_zero_loss

    # Guard against batches in which every graph holds a single particle.
    loss_total = loss_total / max(number_particles_accounted_for, 1)
    Loss_beta = Loss_beta / len(list_graphs)
    Loss_beta_zero = Loss_beta_zero / len(list_graphs)
    loss_total_ = loss_total + Loss_beta + Loss_beta_zero

    return loss_total_, Loss_beta, Loss_beta_zero, loss_total
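
The core of the inner loop above is a per-particle InfoNCE term, `-log(exp(x · x_alpha) / Σ_neg exp(x · x_neg))`, averaged over the particle's nodes. A minimal NumPy sketch with toy embeddings (all shapes and values here are illustrative; no DGL or torch_scatter involved):

```python
import numpy as np

rng = np.random.default_rng(0)
xj_pos = rng.normal(size=(4, 3))   # embeddings of one particle's nodes
x_alpha = xj_pos[0]                # anchor: its highest-charge node
xj_neg = rng.normal(size=(6, 3))   # stored nodes of the other particles

# Charge transform used to pick the anchor: unbounded as beta -> 1.
beta = np.array([0.1, 0.5, 0.9, 0.99])
q = np.arctanh(beta) ** 2 + 0.1    # qmin = 0.1

# Numerator: exp-similarity of every node to its particle's anchor.
num = np.exp(xj_pos @ x_alpha)
# Denominator: exp-similarity of every node to all negatives, summed.
den = np.exp(xj_pos @ xj_neg.T).sum(axis=1)
# Mean InfoNCE term for this particle, as in the torch version.
loss = np.mean(-np.log(num / den))
```

Note that, as in the function above, the denominator contains only negatives; the textbook InfoNCE denominator would also include the positive pair, which keeps each term non-negative.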