jetclustering / src /layers /object_cond_infonet.py
gregorkrzmanc's picture
.
e75a247
raw
history blame
3.51 kB
from typing import Tuple, Union
import numpy as np
import torch
from torch_scatter import scatter_max, scatter_add, scatter_mean
from src.layers.loss_fill_space_torch import LLFillSpace
import dgl
def infonet_updated(g, qmin, xj, bj, max_samples=50):
    """Compute a contrastive (InfoNCE-style) loss plus two beta terms per batch.

    The batched graph ``g`` is split into its component graphs. For every
    graph with more than one distinct ``particle_number`` label, each particle
    contributes one contrastive term: the nodes of that particle are scored
    against the particle's highest-charge node (positive) and against up to
    ``max_samples`` sampled nodes from every *other* particle (negatives).

    Args:
        g: Batched DGL graph whose node data contains ``"particle_number"``.
            Labels are presumably the contiguous integers ``1..K`` within each
            graph (they are used as ``label - 1`` scatter indices and the
            ``nn``-th unique label is paired with row ``nn`` of the scatter
            result) -- TODO confirm against the data loader.
        qmin: Offset added to ``arctanh(beta) ** 2`` to form the charge.
        xj: Per-node embedding for all nodes of the batch, concatenated in the
            order returned by ``dgl.unbatch`` (assumed ``(num_nodes, dim)``).
        bj: Per-node beta values, aligned with ``xj`` along the first axis.
        max_samples: Number of node indices kept per particle when building
            the negative set. Defaults to 50, the previously hard-coded value.

    Returns:
        Tuple ``(total, loss_beta, loss_beta_zero, contrastive)`` where
        ``total`` is the sum of the other three, the beta terms are averaged
        over graphs and ``contrastive`` is averaged over the particles that
        contributed a term.

    Raises:
        ValueError: If ``max_samples`` is smaller than 1.
    """
    if max_samples < 1:
        raise ValueError("max_samples must be at least 1")

    list_graphs = dgl.unbatch(g)
    num_graphs = len(list_graphs)

    loss_total = 0
    Loss_beta = 0
    Loss_beta_zero = 0
    number_particles_accounted_for = 0
    node_counter = 0

    for graph_eval in list_graphs:
        non = graph_eval.number_of_nodes()
        xj_graph = xj[node_counter : node_counter + non]
        bj_graph = bj[node_counter : node_counter + non]
        # Advance the running offset so the next graph reads its own rows;
        # without this every graph would be sliced from the start of the batch.
        node_counter += non

        labels = graph_eval.ndata["particle_number"]
        part_num = labels.view(-1).to(torch.long)

        # Per particle, pick the node with the largest charge as the
        # representative ("alpha") node.
        q_graph = bj_graph.arctanh() ** 2 + qmin
        _, index_alpha = scatter_max(q_graph.view(-1), part_num - 1)
        x_alpha = xj_graph[index_alpha]
        b_alpha = bj_graph[index_alpha]

        number_of_particles = torch.unique(labels)
        num_particles = len(number_of_particles)

        # Mean beta over all nodes of the graph. (An exponential variant,
        # sum(exp(10 * beta)) / non, was previously tried here.)
        beta_zero_loss = torch.sum(bj_graph) / non

        # A contrastive term needs at least one other particle for negatives.
        if num_particles > 1:
            masks = [labels == number_of_particles[nn] for nn in range(num_particles)]

            # Fixed-width table of node indices per particle: the first
            # ``max_samples`` members, padded with the first member when the
            # particle has fewer nodes than that.
            samples = torch.zeros(
                (num_particles, max_samples),
                dtype=torch.long,
                device=x_alpha.device,
            )
            for nn, positives_of_class in enumerate(masks):
                pos_indx = torch.where(positives_of_class)[0]
                if len(pos_indx) > max_samples:
                    samples[nn, :] = pos_indx[:max_samples]
                else:
                    samples[nn, : len(pos_indx)] = pos_indx
                    samples[nn, len(pos_indx) :] = pos_indx[0]

            for nn, positives_of_class in enumerate(masks):
                xj_ = xj_graph[positives_of_class]
                x_alpha_ = x_alpha[nn]

                # Positive score: dot product of each member with its alpha node.
                dot_products = (xj_ * x_alpha_.unsqueeze(0)).sum(dim=1)
                dot_products_exp = torch.exp(dot_products)

                # Negatives: sampled nodes of every particle except this one.
                neg_indx = torch.cat((samples[:nn], samples[nn + 1 :])).reshape(-1)
                xj_neg = xj_graph[neg_indx]
                dot_neg = torch.exp(
                    torch.tensordot(xj_, xj_neg, dims=([1], [1]))
                ).sum(dim=1)

                ratio = dot_products_exp / dot_neg
                loss = torch.mean(-torch.log(ratio))
                if torch.isnan(loss).any():
                    # Dump the quantities that actually entered the loss.
                    print(dot_products)
                    print(dot_neg)
                    print(ratio)
                    print(torch.log(ratio))

                loss_total = loss_total + loss
                number_particles_accounted_for = number_particles_accounted_for + 1

        # NOTE(review): the source this was reconstructed from had lost its
        # indentation; the beta terms are accumulated once per graph here
        # (matching the per-graph normalisation below) -- confirm that they
        # were not meant to be restricted to multi-particle graphs.
        Loss_beta = Loss_beta + torch.sum((1 - b_alpha)) / len(b_alpha)
        Loss_beta_zero = Loss_beta_zero + beta_zero_loss

    if number_particles_accounted_for > 0:
        loss_total = loss_total / number_particles_accounted_for
    else:
        # No graph had more than one particle: return a zero tensor instead
        # of dividing the integer accumulator by zero.
        loss_total = xj.new_zeros(())

    if num_graphs > 0:
        Loss_beta = Loss_beta / num_graphs
        Loss_beta_zero = Loss_beta_zero / num_graphs

    loss_total_ = loss_total + Loss_beta + Loss_beta_zero
    return loss_total_, Loss_beta, Loss_beta_zero, loss_total