import torch
import torch.nn as nn
import numpy as np
from torch import Tensor
from torch_scatter import scatter_min, scatter_max, scatter_mean, scatter_add
from src.layers.GravNetConv import GravNetConv
from typing import Tuple, Union, List
import dgl

# PDG ids covered by the one-hot PID encoding; slot 0 is reserved for "other"
onehot_particles_arr = [
    -2212,
    -211,
    -14,
    -13,
    -11,
    11,
    12,
    13,
    14,
    22,
    111,
    130,
    211,
    2112,
    2212,
    1000010048,
    1000020032,
    1000040064,
    1000050112,
    1000060096,
    1000080128,
]
pid_dict = {i + 1: onehot_particles_arr[i] for i in range(len(onehot_particles_arr))}
pid_dict[0] = "other"


def safe_index(arr, index):
    # One-hot index (or zero if it's not in the array)
    if index not in arr:
        return 0
    return arr.index(index) + 1
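# Examples: safe_index(onehot_particles_arr, 22) -> 10 (the photon's one-hot slot),
# while an id not in the list maps to the "other" slot:
# safe_index(onehot_particles_arr, 321) -> 0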


def assert_no_nans(x):
    """
    Raises AssertionError if there is a nan in the tensor
    """
    if torch.isnan(x).any():
        print(x)
    assert not torch.isnan(x).any()


# FIXME: Use a logger instead of this
DEBUG = False


def debug(*args, **kwargs):
    if DEBUG:
        print(*args, **kwargs)


def calc_energy_pred(
    batch,
    g,
    cluster_index_per_event,
    is_sig,
    q,
    beta,
    energy_correction,
    pid_results,
    hit_mom,
):
    td = 0.7  # distance threshold in clustering space
    batch_number = torch.max(batch) + 1
    energies = []
    pid_outputs = []
    momenta = []
    for i in range(0, batch_number):
        mask_batch = batch == i
        is_sig_i = is_sig[mask_batch]
        # Filter everything down to the signal hits of this event, so that all
        # per-hit tensors below line up
        X = g.ndata["pos_hits_xyz"][mask_batch][is_sig_i]
        cluster_index_i = cluster_index_per_event[mask_batch][is_sig_i] - 1
        q_i = q[mask_batch][is_sig_i]
        betas = beta[mask_batch][is_sig_i]
        q_alpha_i, index_alpha_i = scatter_max(q_i, cluster_index_i)
        n_points = betas.size(0)
        unassigned = torch.arange(n_points).to(betas.device)
        clustering = -1 * torch.ones(n_points, dtype=torch.long, device=betas.device)
        counter = 0
        for index_condpoint in index_alpha_i:
            d = torch.norm(X[unassigned] - X[index_condpoint], dim=-1)
            assigned_to_this_condpoint = unassigned[d < td]
            clustering[assigned_to_this_condpoint] = counter
            unassigned = unassigned[~(d < td)]
            counter = counter + 1
        # Make sure every condensation point belongs to its own cluster
        counter = 0
        for index_condpoint in index_alpha_i:
            clustering[index_condpoint] = counter
            counter = counter + 1
        # If there are unassigned hits (-1), shift by one so that scatter indices
        # are non-negative; bucket 0 then collects the unassigned hits
        if torch.sum(clustering == -1) > 0:
            clustering_ = clustering + 1
        else:
            clustering_ = clustering
        clus_values = torch.unique(clustering)
        e_c = g.ndata["e_hits"][mask_batch][is_sig_i].view(-1) * energy_correction[
            mask_batch
        ][is_sig_i].view(-1)
        mom_c = hit_mom[mask_batch][is_sig_i].view(-1)
        # Aggregate the "PID embeddings", energies and momenta per cluster
        pid_results_i = scatter_add(
            pid_results[mask_batch][is_sig_i],
            clustering_.long().to(pid_results.device),
            dim=0,
        )
        e_objects = scatter_add(e_c, clustering_.long().to(e_c.device))
        mom_objects = scatter_add(mom_c, clustering_.long().to(mom_c.device))
        # Drop the bucket of unassigned hits, if any
        e_objects = e_objects[clus_values != -1]
        pid_results_i = pid_results_i[clus_values != -1]
        mom_objects = mom_objects[clus_values != -1]
        energies.append(e_objects)
        pid_outputs.append(pid_results_i)
        momenta.append(mom_objects)
    return (
        torch.cat(energies, dim=0),
        torch.cat(pid_outputs, dim=0),
        torch.cat(momenta, dim=0),
    )
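
# A hedged sketch of a call, mirroring how calc_energy_pred is used further down
# in calc_LV_Lbeta; it returns three tensors concatenated over events, one entry
# per inferred cluster:
#   e_pred, pid_pred, mom_pred = calc_energy_pred(
#       batch, g, cluster_index_per_event, is_sig, q, beta,
#       energy_correction, predicted_pid, momentum,
#   )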


def calc_pred_pid(batch, g, cluster_index_per_event, is_sig, q, beta, pred_pid):
    outputs = []
    batch_number = torch.max(batch) + 1
    for i in range(0, batch_number):
        mask_batch = batch == i
        is_sig_i = is_sig[mask_batch]
        pid = pred_pid[mask_batch][is_sig_i].view(-1)
        outputs.append(pid)
    return torch.cat(outputs, dim=0)


def calc_LV_Lbeta(
    original_coords,
    g,
    y,
    distance_threshold,
    energy_correction,
    momentum: torch.Tensor,
    beta: torch.Tensor,
    cluster_space_coords: torch.Tensor,  # Predicted by model
    cluster_index_per_event: torch.Tensor,  # Truth hit->cluster index
    batch: torch.Tensor,
    predicted_pid: torch.Tensor,  # Predicted PID embeddings; aggregated by summing over clusters, then passed through the post_pid_pool_module MLP
    post_pid_pool_module: torch.nn.Module = None,  # MLP applied to the pooled embeddings to get the PID predictions
    # From here on just parameters
    qmin: float = 0.1,
    s_B: float = 1.0,
    noise_cluster_index: int = 0,  # cluster_index entries with this value are noise
    beta_stabilizing="soft_q_scaling",
    huberize_norm_for_V_attractive=False,
    beta_term_option="paper",
    return_components=False,
    return_regression_resolution=False,
    clust_space_dim=3,
    frac_combinations=0,  # fraction of all possible pairs to be used for the clustering loss
    attr_weight=1.0,
    repul_weight=1.0,
    fill_loss_weight=0.0,
    use_average_cc_pos=0.0,
    hgcal_implementation=False,
    hit_energies=None,
    tracking=False,
    dis=False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], dict]:
""" | |
Calculates the L_V and L_beta object condensation losses. | |
Concepts: | |
- A hit belongs to exactly one cluster (cluster_index_per_event is (n_hits,)), | |
and to exactly one event (batch is (n_hits,)) | |
- A cluster index of `noise_cluster_index` means the cluster is a noise cluster. | |
There is typically one noise cluster per event. Any hit in a noise cluster | |
is a 'noise hit'. A hit in an object is called a 'signal hit' for lack of a | |
better term. | |
- An 'object' is a cluster that is *not* a noise cluster. | |
beta_stabilizing: Choices are ['paper', 'clip', 'soft_q_scaling']: | |
paper: beta is sigmoid(model_output), q = beta.arctanh()**2 + qmin | |
clip: beta is clipped to 1-1e-4, q = beta.arctanh()**2 + qmin | |
soft_q_scaling: beta is sigmoid(model_output), q = (clip(beta)/1.002).arctanh()**2 + qmin | |
huberize_norm_for_V_attractive: Huberizes the norms when used in the attractive potential | |
beta_term_option: Choices are ['paper', 'short-range-potential']: | |
Choosing 'short-range-potential' introduces a short range potential around high | |
beta points, acting like V_attractive. | |
Note this function has modifications w.r.t. the implementation in 2002.03605: | |
- The norms for V_repulsive are now Gaussian (instead of linear hinge) | |
""" | |
    # Remove dummy rows added for the dataloader  # TODO think of a better way to do this
    device = beta.device
    if torch.isnan(beta).any():
        print("There are nans in beta!", len(beta[torch.isnan(beta)]))
        beta = torch.nan_to_num(beta, nan=0.0)
    assert_no_nans(beta)

    # ________________________________
    # Calculate a bunch of needed counts and indices locally

    # cluster_index: unique index over events
    # E.g. cluster_index_per_event=[0, 0, 1, 2, 0, 0, 1], batch=[0, 0, 0, 0, 1, 1, 1]
    #  -> cluster_index=[0, 0, 1, 2, 3, 3, 4]
    cluster_index, n_clusters_per_event = batch_cluster_indices(
        cluster_index_per_event, batch
    )
    n_clusters = n_clusters_per_event.sum()
    n_hits, cluster_space_dim = cluster_space_coords.size()
    batch_size = batch.max() + 1
    n_hits_per_event = scatter_count(batch)

    # Index of cluster -> event (n_clusters,)
    batch_cluster = scatter_counts_to_indices(n_clusters_per_event)

    # Per-hit boolean, indicating whether a hit is signal or noise
    is_noise = cluster_index_per_event == noise_cluster_index
    is_sig = ~is_noise
    n_hits_sig = is_sig.sum()
    n_sig_hits_per_event = scatter_count(batch[is_sig])

    # Per-cluster boolean, indicating whether a cluster is an object or noise
    is_object = scatter_max(is_sig.long(), cluster_index)[0].bool()
    is_noise_cluster = ~is_object

    # FIXME: This assumes noise_cluster_index == 0!!
    # Not sure how to do this in a performant way in case noise_cluster_index != 0
    if noise_cluster_index != 0:
        raise NotImplementedError
    object_index_per_event = cluster_index_per_event[is_sig] - 1

    object_index, n_objects_per_event = batch_cluster_indices(
        object_index_per_event, batch[is_sig]
    )
    n_hits_per_object = scatter_count(object_index)
    # print("n_hits_per_object", n_hits_per_object)
    batch_object = batch_cluster[is_object]
    n_objects = is_object.sum()

    assert object_index.size() == (n_hits_sig,)
    assert is_object.size() == (n_clusters,)
    assert torch.all(n_hits_per_object > 0)
    assert object_index.max() + 1 == n_objects

    # ________________________________
    # L_V term

    # Calculate q
    if hgcal_implementation:
        q = (beta.arctanh() / 1.01) ** 2 + qmin
    elif beta_stabilizing == "paper":
        q = beta.arctanh() ** 2 + qmin
    elif beta_stabilizing == "clip":
        beta = beta.clip(0.0, 1 - 1e-4)
        q = beta.arctanh() ** 2 + qmin
    elif beta_stabilizing == "soft_q_scaling":
        q = (beta.clip(0.0, 1 - 1e-4) / 1.002).arctanh() ** 2 + qmin
    else:
        raise ValueError(f"beta_stabilizing mode {beta_stabilizing} is not known")
    assert_no_nans(q)
    assert q.device == device
    assert q.size() == (n_hits,)
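    # Worked example of the charge transform (illustrative values): with qmin=0.1
    # and beta=0.8, q = arctanh(0.8)**2 + 0.1 = 1.0986**2 + 0.1 ≈ 1.31, so
    # high-beta hits carry a large charge and win the scatter_max below.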

    # Calculate q_alpha, the max q per object, and the indices of said maxima
    # assert hit_energies.shape == q.shape
    # q_alpha, index_alpha = scatter_max(hit_energies[is_sig], object_index)
    q_alpha, index_alpha = scatter_max(q[is_sig], object_index)
    assert q_alpha.size() == (n_objects,)

    # Get the cluster space coordinates and betas for these maxima hits too
    x_alpha = cluster_space_coords[is_sig][index_alpha]
    x_alpha_original = original_coords[is_sig][index_alpha]
    if use_average_cc_pos > 0:
        #! This is a function of beta and q, so maybe we could also do it with only q
        x_alpha_sum = scatter_add(
            q[is_sig].view(-1, 1).repeat(1, 3) * cluster_space_coords[is_sig],
            object_index,
            dim=0,
        )  # * beta[is_sig].view(-1, 1).repeat(1, 3)
        qbeta_alpha_sum = scatter_add(q[is_sig], object_index) + 1e-9  # * beta[is_sig]
        div_fac = 1 / qbeta_alpha_sum
        div_fac = torch.nan_to_num(div_fac, nan=0)
        x_alpha_mean = torch.mul(x_alpha_sum, div_fac.view(-1, 1).repeat(1, 3))
        # Interpolate between the q-weighted mean position and the highest-q hit
        x_alpha = use_average_cc_pos * x_alpha_mean + (1 - use_average_cc_pos) * x_alpha
    if dis:
        # Beta-weighted mean of the per-hit distance thresholds per object
        phi_sum = scatter_add(
            beta[is_sig].view(-1) * distance_threshold[is_sig].view(-1),
            object_index,
            dim=0,
        )
        phi_alpha_sum = scatter_add(beta[is_sig].view(-1), object_index) + 1e-9
        phi_alpha = phi_sum / phi_alpha_sum
    beta_alpha = beta[is_sig][index_alpha]
    assert x_alpha.size() == (n_objects, cluster_space_dim)
    assert beta_alpha.size() == (n_objects,)
    if not tracking:
        positions_particles_pred = g.ndata["pos_hits_xyz"][is_sig][index_alpha]
        positions_particles_pred = (
            positions_particles_pred + distance_threshold[is_sig][index_alpha]
        )
        # e_particles_pred = g.ndata["e_hits"][is_sig][index_alpha]
        # e_particles_pred = e_particles_pred * energy_correction[is_sig][index_alpha]
        # Prediction updated to follow the end-to-end paper approach: sum the hits
        # in the object and multiply by the correction factor of alpha (the cluster center)
        # e_particles_pred = (scatter_add(g.ndata["e_hits"][is_sig].view(-1), object_index) * energy_correction[is_sig][index_alpha].view(-1)).view(-1, 1)
        e_particles_pred, pid_particles_pred, mom_particles_pred = calc_energy_pred(
            batch,
            g,
            cluster_index_per_event,
            is_sig,
            q,
            beta,
            energy_correction,
            predicted_pid,
            momentum,
        )
    if fill_loss_weight > 0:
        fill_loss = fill_loss_weight * LLFillSpace()(cluster_space_coords, batch)
    else:
        fill_loss = 0
    # pid_particles_pred = post_pid_pool_module(
    #     pid_particles_pred
    # )  # Project the pooled PID embeddings to the final "one-hot encoding" space
    # pid_particles_pred = calc_pred_pid(
    #     batch, g, cluster_index_per_event, is_sig, q, beta, predicted_pid
    # )
    if not tracking:
        x_particles = y[:, 0:3]
        e_particles = y[:, 3]
        mom_particles_true = y[:, 4]
        mass_particles_true = y[:, 5]
        # particles_mask = y[:, 6]
        mom_particles_true = mom_particles_true.to(device)
        # Invariant mass from the predictions: m^2 = E^2 - p^2 (natural units)
        mass_particles_pred = e_particles_pred**2 - mom_particles_pred**2
        mass_particles_true = mass_particles_true.to(device)
        mass_particles_pred[mass_particles_pred < 0] = 0.0
        mass_particles_pred = torch.sqrt(mass_particles_pred)
        loss_mass = torch.nn.MSELoss()(
            mass_particles_true, mass_particles_pred
        )  # only logged, not used in the loss function
        pid_id_particles = y[:, 6].unsqueeze(1).long()
        # 21 known PDG ids + 1 "other" slot
        pid_particles_true = torch.zeros((pid_id_particles.shape[0], 22))
        part_idx_onehot = [
            safe_index(onehot_particles_arr, i)
            for i in pid_id_particles.flatten().tolist()
        ]
        pid_particles_true[
            torch.arange(pid_id_particles.shape[0]), part_idx_onehot
        ] = 1.0
        # if return_regression_resolution:
        #     e_particles_pred = e_particles_pred.detach().flatten()
        #     e_particles = e_particles.detach().flatten()
        #     positions_particles_pred = positions_particles_pred.detach().flatten()
        #     x_particles = x_particles.detach().flatten()
        #     mom_particles_pred = mom_particles_pred.detach().flatten().to("cpu")
        #     mom_particles_true = mom_particles_true.detach().flatten().to("cpu")
        #     return (
        #         {
        #             "momentum_res": (
        #                 (mom_particles_pred - mom_particles_true) / mom_particles_true
        #             ).tolist(),
        #             "e_res": ((e_particles_pred - e_particles) / e_particles).tolist(),
        #             "pos_res": (
        #                 (positions_particles_pred - x_particles) / x_particles
        #             ).tolist(),
        #         },
        #         pid_particles_true,
        #         pid_particles_pred,
        #     )
        # Per-hit energy prediction: sum of the hit energies in the object, scaled
        # by each hit's predicted correction factor
        e_particles_pred_per_object = scatter_add(
            g.ndata["e_hits"][is_sig].view(-1), object_index
        )  # * energy_correction[is_sig][index_alpha].view(-1)).view(-1, 1)
        e_particle_pred_per_particle = e_particles_pred_per_object[
            object_index
        ] * energy_correction[is_sig].view(-1)
        e_true = y[:, 3].clone()
        e_true = e_true.to(e_particles_pred_per_object.device)
        e_true_particle = e_true[object_index]
        # Per-hit energy loss, weighted by an arctanh^2 term so high-beta hits dominate
        L_i = (e_particle_pred_per_particle - e_true_particle) ** 2 / e_true_particle
        B_i = (beta[is_sig].arctanh() / 1.01) ** 2 + 1e-3
        loss_E = torch.sum(L_i * B_i) / torch.sum(B_i)
        # loss_E = torch.mean(
        #     torch.square(
        #         (e_particles_pred.to(device) - e_particles.to(device))
        #         / e_particles.to(device)
        #     )
        # )
        loss_momentum = torch.mean(
            torch.square(
                (mom_particles_pred.to(device) - mom_particles_true.to(device))
                / mom_particles_true.to(device)
            )
        )
        # loss_ce = torch.nn.BCELoss()
        loss_mse = torch.nn.MSELoss()
        loss_x = loss_mse(positions_particles_pred.to(device), x_particles.to(device))
        # loss_x = 0.  # TEMPORARILY, there is some issue with the X loss and it goes to infinity
        # loss_particle_ids = loss_ce(
        #     pid_particles_pred.to(device), pid_particles_true.to(device)
        # )
        # pid_true = pid_particles_true.argmax(dim=1).detach().tolist()
        # pid_pred = pid_particles_pred.argmax(dim=1).detach().tolist()
        # pid_true = [pid_dict[i.long().item()] for i in pid_true]
        # pid_pred = [pid_dict[i.long().item()] for i in pid_pred]
    # Connectivity matrix from hit (row) -> cluster (column)
    # Index to matrix, e.g.:
    # [1, 3, 1, 0] --> [
    #     [0, 1, 0, 0],
    #     [0, 0, 0, 1],
    #     [0, 1, 0, 0],
    #     [1, 0, 0, 0],
    # ]
    M = torch.nn.functional.one_hot(cluster_index).long()

    # Anti-connectivity matrix; be sure not to connect hits to clusters in different events!
    M_inv = get_inter_event_norms_mask(batch, n_clusters_per_event) - M

    # Throw away noise cluster columns; we never need them
    M = M[:, is_object]
    M_inv = M_inv[:, is_object]
    assert M.size() == (n_hits, n_objects)
    assert M_inv.size() == (n_hits, n_objects)

    # Calculate all norms
    # Warning: should not be used without a mask!
    # Contains norms between hits and objects from different events
    # (n_hits, 1, cluster_space_dim) - (1, n_objects, cluster_space_dim)
    # gives (n_hits, n_objects, cluster_space_dim)
    norms = (cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)).norm(dim=-1)
    assert norms.size() == (n_hits, n_objects)
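    # Shape sketch of the broadcast above: cluster_space_coords (n_hits, D) is viewed
    # as (n_hits, 1, D) and x_alpha (n_objects, D) as (1, n_objects, D); their
    # difference broadcasts to (n_hits, n_objects, D), and .norm(dim=-1) reduces it
    # to the (n_hits, n_objects) distance matrix.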

    L_clusters = torch.tensor(0.0).to(device)
    if frac_combinations != 0:
        L_clusters = L_clusters_calc(
            batch, cluster_space_coords, cluster_index, frac_combinations, q
        )

    # -------
    # Attractive potential term
    # First get all the relevant norms: We only want norms of signal hits
    # w.r.t. the object they belong to, i.e. no noise hits and no noise clusters.
    # First select all norms of all signal hits w.r.t. all objects, mask out later
    if hgcal_implementation:
        N_k = torch.sum(M, dim=0)  # number of hits per object
        # Squared distances for the HGCAL-style potentials
        norms = torch.sum(
            torch.square(cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)),
            dim=-1,
        )
        norms_att = norms[is_sig]
        #! Attractive function as in line 159 of the object condensation reference code
        norms_att = torch.log(
            torch.exp(torch.Tensor([1]).to(norms_att.device)) * norms_att / 2 + 1
        )
    # Power-scale the norms
    elif huberize_norm_for_V_attractive:
        norms_att = norms[is_sig]
        # Huberized version (linear but times 4)
        # Be sure to not move 'off-diagonal' away from zero
        # (i.e. norms of hits w.r.t. clusters they do _not_ belong to)
        norms_att = huber(norms_att + 1e-5, 4.0)
    else:
        norms_att = norms[is_sig]
        # Paper version is simply norms squared (no need for mask)
        norms_att = norms_att**2
    assert norms_att.size() == (n_hits_sig, n_objects)

    # Now apply the mask to keep only norms of signal hits w.r.t. the object
    # they belong to
    norms_att *= M[is_sig]

    # Final potential term
    # (n_sig_hits, 1) * (1, n_objects) * (n_sig_hits, n_objects)
    V_attractive = q[is_sig].unsqueeze(-1) * q_alpha.unsqueeze(0) * norms_att
    assert V_attractive.size() == (n_hits_sig, n_objects)
    # Sum over hits, then sum per event, then divide by n_hits_per_event, then sum over events
    if hgcal_implementation:
        #! Each shower is accounted for separately
        V_attractive = V_attractive.sum(dim=0)  # (n_objects,)
        #! Divide by the number of hits in each object, so every object counts equally
        V_attractive = V_attractive.view(-1) / (N_k.view(-1) + 1e-3)
        # if not tracking:
        #     #! add to terms function (divide by total number of showers per event)
        #     # L_V_attractive = scatter_add(V_attractive, object_index) / n_objects
        #     # L_V_attractive = torch.mean(
        #     #     V_attractive
        #     # )  # V_attractive has size n_objects, so this is a per-shower metric
        #     per_shower_weight = torch.exp(1 / (e_particles_pred_per_object + 0.4))
        #     soft_m = torch.nn.Softmax(dim=0)
        #     per_shower_weight = soft_m(per_shower_weight) * len(V_attractive)
        #     L_V_attractive = torch.mean(V_attractive * per_shower_weight)
        # else:
        # Weight classes by bin:
        # if tracking:
        #     e_true = y[:, 5].clone()
        #     # e_true_particle = e_true[object_index]
        #     label = 1 * (e_true > 4)
        #     V = label.size(0)
        #     n_classes = 2
        #     label_count = torch.bincount(label)
        #     label_count = label_count[label_count.nonzero()].squeeze()
        #     cluster_sizes = torch.zeros(n_classes).long().to(label_count.device)
        #     cluster_sizes[torch.unique(label)] = label_count
        #     weight = (V - cluster_sizes).float() / V
        #     weight *= (cluster_sizes > 0).float()
        #     per_shower_weight = weight[label]
        #     soft_m = torch.nn.Softmax(dim=0)
        #     per_shower_weight = soft_m(per_shower_weight) * len(V_attractive)
        #     L_V_attractive = torch.mean(V_attractive * per_shower_weight)
        # else:
        L_V_attractive = torch.mean(V_attractive)
    else:
        #! In comparison, this works per hit
        V_attractive = (
            scatter_add(V_attractive.sum(dim=0), batch_object) / n_hits_per_event
        )
        assert V_attractive.size() == (batch_size,)
        L_V_attractive = V_attractive.sum()

    # -------
    # Repulsive potential term
    # Get all the relevant norms: We want norms of any hit w.r.t. objects
    # they do *not* belong to, i.e. no noise clusters.
    # We do however want to keep norms of noise hits w.r.t. objects

    # Power-scale the norms: Gaussian scaling term instead of a cone
    # Mask out the norms of hits w.r.t. the cluster they belong to
    if hgcal_implementation:
        # Here `norms` holds squared distances, so these are Gaussians of two
        # widths; the narrower second one gives a short-range repulsive term
        norms_rep = torch.exp(-norms / 2) * M_inv
        norms_rep2 = torch.exp(-norms * 5) * M_inv
        V_repulsive2 = q.unsqueeze(1) * q_alpha.unsqueeze(0) * norms_rep2
    else:
        norms_rep = torch.exp(-4.0 * norms**2) * M_inv
    # (n_hits, 1) * (1, n_objects) * (n_hits, n_objects)
    V_repulsive = q.unsqueeze(1) * q_alpha.unsqueeze(0) * norms_rep
    # No need to apply a V = max(0, V); by construction V >= 0
    assert V_repulsive.size() == (n_hits, n_objects)

    # Sum over hits, then sum per event, then divide by n_hits_per_event, then sum up events
    nope = n_objects_per_event - 1
    nope[nope == 0] = 1
    if hgcal_implementation:
        #! Sum the repulsive terms of each object
        L_V_repulsive = V_repulsive.sum(dim=0)  # (n_objects,)
        number_of_repulsive_terms_per_object = torch.sum(M_inv, dim=0)
        L_V_repulsive = L_V_repulsive.view(
            -1
        ) / number_of_repulsive_terms_per_object.view(-1)
        L_V_repulsive2 = V_repulsive2.sum(dim=0)  # (n_objects,)
        L_V_repulsive2 = L_V_repulsive2.view(-1)
        # if not tracking:
        #     #! add to terms function (divide by total number of showers per event)
        #     # L_V_repulsive = scatter_add(L_V_repulsive, object_index) / n_objects
        #     per_shower_weight = torch.exp(1 / (e_particles_pred_per_object + 0.4))
        #     soft_m = torch.nn.Softmax(dim=0)
        #     per_shower_weight = soft_m(per_shower_weight) * len(L_V_repulsive)
        #     L_V_repulsive = torch.mean(L_V_repulsive * per_shower_weight)
        # else:
        # if tracking:
        #     L_V_repulsive = torch.mean(L_V_repulsive * per_shower_weight)
        # else:
        L_V_repulsive = torch.mean(L_V_repulsive)
        L_V_repulsive2 = torch.mean(L_V_repulsive2)
    else:
        L_V_repulsive = (
            scatter_add(V_repulsive.sum(dim=0), batch_object)
            / (n_hits_per_event * nope)
        ).sum()
        # The short-range term only exists in the HGCAL branch; fall back to the
        # weighted standard repulsive loss so the L_V sum below stays well-defined
        L_V_repulsive2 = repul_weight * L_V_repulsive

    L_V = (
        attr_weight * L_V_attractive
        # + repul_weight * L_V_repulsive
        + L_V_repulsive2
        # + L_clusters
        # + fill_loss
    )
    if L_clusters != 0:
        print(
            "L_clusters is",
            100 * (L_clusters / L_V).detach().cpu().item(),
            "% of L_V. L_clusters value:",
            L_clusters.detach().cpu().item(),
        )
    # else:
    #     print("L_clusters is ZERO")

    # ________________________________
    # L_beta term

    # -------
    # L_beta noise term
    n_noise_hits_per_event = scatter_count(batch[is_noise])
    n_noise_hits_per_event[n_noise_hits_per_event == 0] = 1
    L_beta_noise = (
        s_B
        * (
            (scatter_add(beta[is_noise], batch[is_noise])) / n_noise_hits_per_event
        ).sum()
    )
    # print("L_beta_noise", L_beta_noise / batch_size)

    # -------
    # L_beta signal term
    if hgcal_implementation:
        # Version one:
        beta_per_object_c = scatter_add(beta[is_sig], object_index)
        beta_alpha = beta[is_sig][index_alpha]
        L_beta_sig = torch.mean(
            1 - beta_alpha + 1 - torch.clip(beta_per_object_c, 0, 1)
        )
        # This is also per object, so no division by the batch size
        # Version 2, with the LSE approximation for the max:
        # eps = 1e-3
        # beta_per_object = scatter_add(torch.exp(beta[is_sig] / eps), object_index)
        # beta_pen = 1 - eps * torch.log(beta_per_object)
        # beta_per_object_c = scatter_add(beta[is_sig], object_index)
        # beta_pen = beta_pen + 1 - torch.clip(beta_per_object_c, 0, 1)
        # L_beta_sig = beta_pen.sum() / len(beta_pen)
        # L_beta_sig = L_beta_sig / 4
        L_beta_noise = L_beta_noise / batch_size
        # ? note: the training that worked quite well divided this by the batch size (1/4)
    elif beta_term_option == "paper":
        beta_alpha = beta[is_sig][index_alpha]
        L_beta_sig = torch.sum(  # maybe 0.5 for a less aggressive loss
            scatter_add((1 - beta_alpha), batch_object) / n_objects_per_event
        )
        # print("L_beta_sig", L_beta_sig / batch_size)
        # beta_exp = beta[is_sig]
        # beta_exp[index_alpha] = 0
        # # L_exp = torch.mean(beta_exp)
        # beta_exp = torch.exp(0.5 * beta_exp)
        # L_exp = torch.mean(scatter_add(beta_exp, batch) / n_hits_per_event)
    elif beta_term_option == "short-range-potential":
        # First collect the norms: We only want norms of hits w.r.t. the object they
        # belong to (like in V_attractive).
        # Apply the transformation first, then apply the mask to keep only the norms
        # we want, then sum over hits, so the result is (n_objects,)
        norms_beta_sig = (1.0 / (20.0 * norms[is_sig] ** 2 + 1.0) * M[is_sig]).sum(
            dim=0
        )
        assert torch.all(norms_beta_sig >= 1.0) and torch.all(
            norms_beta_sig <= n_hits_per_object
        )
        # Subtract from 1. to remove the self-interaction, divide by the number of
        # hits per object
        norms_beta_sig = (1.0 - norms_beta_sig) / n_hits_per_object
        assert torch.all(norms_beta_sig >= -1.0) and torch.all(norms_beta_sig <= 0.0)
        norms_beta_sig *= beta_alpha
        # Conclusion:
        # lower beta --> higher loss (less negative)
        # higher norms --> higher loss

        # Sum over objects, divide by the number of objects per event, then sum over events
        L_beta_norms_term = (
            scatter_add(norms_beta_sig, batch_object) / n_objects_per_event
        ).sum()
        assert L_beta_norms_term >= -batch_size and L_beta_norms_term <= 0.0

        # Log-beta term: take -0.2*log(beta_alpha + 1e-9), sum it over objects,
        # divide by n_objects_per_event, then sum over events (same pattern as above)
        # lower beta --> higher loss
        L_beta_logbeta_term = (
            scatter_add(-0.2 * torch.log(beta_alpha + 1e-9), batch_object)
            / n_objects_per_event
        ).sum()

        # Final L_beta term
        L_beta_sig = L_beta_norms_term + L_beta_logbeta_term
    else:
        valid_options = ["paper", "short-range-potential"]
        raise ValueError(
            f'beta_term_option "{beta_term_option}" is not valid, choose from {valid_options}'
        )
    L_beta = L_beta_noise + L_beta_sig
    L_alpha_coordinates = torch.mean(torch.norm(x_alpha_original - x_alpha, p=2, dim=1))

    # ________________________________
    # Returning
    # Also divide by the batch size here
    if return_components or DEBUG:
        components = dict(
            L_V=L_V / batch_size,
            L_V_attractive=L_V_attractive / batch_size,
            L_V_repulsive=L_V_repulsive / batch_size,
            L_beta=L_beta / batch_size,
            L_beta_noise=L_beta_noise / batch_size,
            L_beta_sig=L_beta_sig / batch_size,
        )
        if beta_term_option == "short-range-potential":
            components["L_beta_norms_term"] = L_beta_norms_term / batch_size
            components["L_beta_logbeta_term"] = L_beta_logbeta_term / batch_size
    if DEBUG:
        debug(formatted_loss_components_string(components))
    if torch.isnan(L_beta / batch_size):
        print("isnan!!!")
        print(L_beta, batch_size)
        print("L_beta_noise", L_beta_noise)
        print("L_beta_sig", L_beta_sig)
    if not tracking:
        e_particles_pred = e_particles_pred.detach().to("cpu").flatten()
        e_particles = e_particles.detach().to("cpu").flatten()
        positions_particles_pred = (
            positions_particles_pred.detach().to("cpu").flatten()
        )
        x_particles = x_particles.detach().to("cpu").flatten()
        mom_particles_pred = mom_particles_pred.detach().flatten().to("cpu")
        mom_particles_true = mom_particles_true.detach().flatten().to("cpu")
        resolutions = {
            "momentum_res": (
                (mom_particles_pred - mom_particles_true) / mom_particles_true
            ),
            "e_res": ((e_particles_pred - e_particles) / e_particles).tolist(),
            "pos_res": (
                (positions_particles_pred - x_particles) / x_particles
            ).tolist(),
        }
    # Also return pid_true and pid_pred here, to log the confusion matrix at each
    # validation step
    # try:
    #     L_clusters = L_clusters.detach().cpu().item()  # if L_clusters is zero
    # except:
    #     pass
    L_exp = L_beta
    if hgcal_implementation:
        if not tracking:
            return (
                L_V,  # 0
                L_beta,
                loss_E,
                loss_x,
                None,  # loss_particle_ids,  # 4
                loss_momentum,
                loss_mass,
                None,  # pid_true,
                None,  # pid_pred,
                resolutions,
                L_clusters,  # 10
                fill_loss,
                L_V_attractive,
                L_V_repulsive,
                L_alpha_coordinates,
                L_exp,
                norms_rep,  # 16
                norms_att,  # 17
                L_V_repulsive2,
            )
        else:
            return (
                L_V,  # 0
                L_beta,
                L_V_attractive,
                L_V_repulsive,
                L_beta_sig,
                L_beta_noise,
            )
    else:
        # Note: in this branch nothing is returned for tracking=True (implicitly None)
        if not tracking:
            return (
                L_V / batch_size,  # 0
                L_beta / batch_size,
                loss_E,
                loss_x,
                None,  # loss_particle_ids,  # 4
                loss_momentum,
                loss_mass,
                None,  # pid_true,
                None,  # pid_pred,
                resolutions,
                L_clusters,  # 10
                fill_loss,
                L_V_attractive / batch_size,
                L_V_repulsive / batch_size,
                L_alpha_coordinates,
                L_exp,
                norms_rep,  # 16
                norms_att,  # 17
            )


def calc_LV_Lbeta_inference(
    g,
    distance_threshold,
    energy_correction,
    momentum: torch.Tensor,
    beta: torch.Tensor,
    cluster_space_coords: torch.Tensor,  # Predicted by model
    cluster_index_per_event: torch.Tensor,  # Inferred hit->cluster index
    batch: torch.Tensor,
    predicted_pid: torch.Tensor,  # Predicted PID embeddings; aggregated by summing over clusters, then passed through the post_pid_pool_module MLP
    post_pid_pool_module: torch.nn.Module,  # MLP applied to the pooled embeddings to get the PID predictions
    # From here on just parameters
    qmin: float = 0.1,
    s_B: float = 1.0,
    beta_stabilizing="soft_q_scaling",
    huberize_norm_for_V_attractive=False,
    beta_term_option="paper",
) -> Union[Tuple[torch.Tensor, torch.Tensor], dict]:
""" | |
Calculates the L_V and L_beta object condensation losses. | |
Concepts: | |
- A hit belongs to exactly one cluster (cluster_index_per_event is (n_hits,)), | |
and to exactly one event (batch is (n_hits,)) | |
- A cluster index of `noise_cluster_index` means the cluster is a noise cluster. | |
There is typically one noise cluster per event. Any hit in a noise cluster | |
is a 'noise hit'. A hit in an object is called a 'signal hit' for lack of a | |
better term. | |
- An 'object' is a cluster that is *not* a noise cluster. | |
beta_stabilizing: Choices are ['paper', 'clip', 'soft_q_scaling']: | |
paper: beta is sigmoid(model_output), q = beta.arctanh()**2 + qmin | |
clip: beta is clipped to 1-1e-4, q = beta.arctanh()**2 + qmin | |
soft_q_scaling: beta is sigmoid(model_output), q = (clip(beta)/1.002).arctanh()**2 + qmin | |
huberize_norm_for_V_attractive: Huberizes the norms when used in the attractive potential | |
beta_term_option: Choices are ['paper', 'short-range-potential']: | |
Choosing 'short-range-potential' introduces a short range potential around high | |
beta points, acting like V_attractive. | |
Note this function has modifications w.r.t. the implementation in 2002.03605: | |
- The norms for V_repulsive are now Gaussian (instead of linear hinge) | |
""" | |
    # Remove dummy rows added for the dataloader  # TODO think of a better way to do this
    device = beta.device
    # Alert the user if there are nans
    if torch.isnan(beta).any():
        print("There are nans in beta!", len(beta[torch.isnan(beta)]))
        beta = torch.nan_to_num(beta, nan=0.0)
    assert_no_nans(beta)

    # ________________________________
    # Calculate a bunch of needed counts and indices locally

    # cluster_index: unique index over events
    # E.g. cluster_index_per_event=[0, 0, 1, 2, 0, 0, 1], batch=[0, 0, 0, 0, 1, 1, 1]
    #  -> cluster_index=[0, 0, 1, 2, 3, 3, 4]
    cluster_index, n_clusters_per_event = batch_cluster_indices(
        cluster_index_per_event, batch
    )
    n_clusters = n_clusters_per_event.sum()
    n_hits, cluster_space_dim = cluster_space_coords.size()
    batch_size = batch.max() + 1
    n_hits_per_event = scatter_count(batch)
    # At inference time every hit is treated as signal, so the noise/object
    # bookkeeping of the training loss is commented out:
    # Index of cluster -> event (n_clusters,)
    # batch_cluster = scatter_counts_to_indices(n_clusters_per_event)
    # Per-hit boolean, indicating whether a hit is signal or noise
    # is_noise = cluster_index_per_event == noise_cluster_index
    # is_sig = ~is_noise
    # n_hits_sig = is_sig.sum()
    # n_sig_hits_per_event = scatter_count(batch[is_sig])
    # Per-cluster boolean, indicating whether a cluster is an object or noise
    # is_object = scatter_max(is_sig.long(), cluster_index)[0].bool()
    # is_noise_cluster = ~is_object
    # FIXME: This assumes noise_cluster_index == 0!!
    # if noise_cluster_index != 0:
    #     raise NotImplementedError
    # object_index_per_event = cluster_index_per_event[is_sig] - 1
    # object_index, n_objects_per_event = batch_cluster_indices(
    #     object_index_per_event, batch[is_sig]
    # )
    # n_hits_per_object = scatter_count(object_index)
    # batch_object = batch_cluster[is_object]
    # n_objects = is_object.sum()
    # assert object_index.size() == (n_hits_sig,)
    # assert is_object.size() == (n_clusters,)
    # assert torch.all(n_hits_per_object > 0)
    # assert object_index.max() + 1 == n_objects

    # ________________________________
    # L_V term

    # Calculate q
    if beta_stabilizing == "paper":
        q = beta.arctanh() ** 2 + qmin
    elif beta_stabilizing == "clip":
        beta = beta.clip(0.0, 1 - 1e-4)
        q = beta.arctanh() ** 2 + qmin
    elif beta_stabilizing == "soft_q_scaling":
        q = (beta.clip(0.0, 1 - 1e-4) / 1.002).arctanh() ** 2 + qmin
    else:
        raise ValueError(f"beta_stabilizing mode {beta_stabilizing} is not known")
    if torch.isnan(beta).any():
        print("There are nans in beta!", len(beta[torch.isnan(beta)]))
        beta = torch.nan_to_num(beta, nan=0.0)
    assert_no_nans(q)
    assert q.device == device
    assert q.size() == (n_hits,)

    # Calculate q_alpha, the max q per cluster, and the indices of said maxima
    q_alpha, index_alpha = scatter_max(q, cluster_index)
    assert q_alpha.size() == (n_clusters,)
    # Get the cluster space coordinates and betas for these maxima hits too
    index_alpha -= 1  # why do we need this?
    x_alpha = cluster_space_coords[index_alpha]
    beta_alpha = beta[index_alpha]
    positions_particles_pred = g.ndata["pos_hits_xyz"][index_alpha]
    positions_particles_pred = (
        positions_particles_pred + distance_threshold[index_alpha]
    )
    is_sig_everything = torch.ones_like(batch).bool()
    e_particles_pred, pid_particles_pred, mom_particles_pred = calc_energy_pred(
        batch,
        g,
        cluster_index_per_event,
        is_sig_everything,
        q,
        beta,
        energy_correction,
        predicted_pid,
        momentum,
    )
    # Project the pooled PID embeddings to the final "one-hot encoding" space
    pid_particles_pred = post_pid_pool_module(pid_particles_pred)
    # Invariant mass from the predictions: m^2 = E^2 - p^2 (natural units)
    mass_particles_pred = e_particles_pred**2 - mom_particles_pred**2
    mass_particles_pred[mass_particles_pred < 0] = 0.0
    mass_particles_pred = torch.sqrt(mass_particles_pred)
    pid_pred = pid_particles_pred.argmax(dim=1).detach().tolist()
    return (
        pid_pred,
        pid_particles_pred,
        mass_particles_pred,
        e_particles_pred,
        mom_particles_pred,
    )


def formatted_loss_components_string(components: dict) -> str:
    """
    Formats the components returned by calc_LV_Lbeta
    """
    total_loss = components["L_V"] + components["L_beta"]
    fractions = {k: v / total_loss for k, v in components.items()}
    fkey = lambda key: f"{components[key]:+.4f} ({100. * fractions[key]:.1f}%)"
    s = (
        "  L_V                = {L_V}"
        "\n    L_V_attractive   = {L_V_attractive}"
        "\n    L_V_repulsive    = {L_V_repulsive}"
        "\n  L_beta             = {L_beta}"
        "\n    L_beta_noise     = {L_beta_noise}"
        "\n    L_beta_sig       = {L_beta_sig}".format(
            **{k: fkey(k) for k in components}
        )
    )
    if "L_beta_norms_term" in components:
        s += (
            "\n      L_beta_norms_term   = {L_beta_norms_term}"
            "\n      L_beta_logbeta_term = {L_beta_logbeta_term}".format(
                **{k: fkey(k) for k in components}
            )
        )
    if "L_noise_filter" in components:
        s += f'\n    L_noise_filter = {fkey("L_noise_filter")}'
    return s


def calc_simple_clus_space_loss(
    cluster_space_coords: torch.Tensor,  # Predicted by model
    cluster_index_per_event: torch.Tensor,  # Truth hit->cluster index
    batch: torch.Tensor,
    # From here on just parameters
    noise_cluster_index: int = 0,  # cluster_index entries with this value are noise
    huberize_norm_for_V_attractive=True,
    pred_edc: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Isolates just the V_attractive and V_repulsive parts of object condensation,
    w.r.t. the geometrical mean of the truth cluster centers (rather than the
    highest-beta point of the truth cluster).
    Most of this code is copied from `calc_LV_Lbeta`, so it's easier to try out
    different scalings for the norms without breaking the main OC function.

    `pred_edc`: Predicted estimated distance-to-center.
        This is an optional column that should be `n_hits` long. If it is passed,
        a third loss component is calculated based on the truth distance-to-center
        w.r.t. the predicted distance-to-center. This quantifies how close a hit is
        to its center, which provides an ansatz for the clustering.

    See also the 'Concepts' in the doc of `calc_LV_Lbeta`.
    """
    # ________________________________
    # Calculate a bunch of needed counts and indices locally

    # cluster_index: unique index over events
    # E.g. cluster_index_per_event=[0, 0, 1, 2, 0, 0, 1], batch=[0, 0, 0, 0, 1, 1, 1]
    #  -> cluster_index=[0, 0, 1, 2, 3, 3, 4]
    cluster_index, n_clusters_per_event = batch_cluster_indices(
        cluster_index_per_event, batch
    )
    n_hits, cluster_space_dim = cluster_space_coords.size()
    batch_size = batch.max() + 1
    n_hits_per_event = scatter_count(batch)

    # Index of cluster -> event (n_clusters,)
    batch_cluster = scatter_counts_to_indices(n_clusters_per_event)

    # Per-hit boolean, indicating whether a hit is signal or noise
    is_noise = cluster_index_per_event == noise_cluster_index
    is_sig = ~is_noise
    n_hits_sig = is_sig.sum()

    # Per-cluster boolean, indicating whether a cluster is an object or noise
    is_object = scatter_max(is_sig.long(), cluster_index)[0].bool()
    # # FIXME: This assumes noise_cluster_index == 0!!
    # # Not sure how to do this in a performant way in case noise_cluster_index != 0
    # if noise_cluster_index != 0: raise NotImplementedError
    # object_index_per_event = cluster_index_per_event[is_sig] - 1
    batch_object = batch_cluster[is_object]
    n_objects = is_object.sum()

    # ________________________________
    # Build the masks

    # Connectivity matrix from hit (row) -> cluster (column)
    # Index to matrix, e.g.:
    # [1, 3, 1, 0] --> [
    #     [0, 1, 0, 0],
    #     [0, 0, 0, 1],
    #     [0, 1, 0, 0],
    #     [1, 0, 0, 0],
    # ]
    M = torch.nn.functional.one_hot(cluster_index).long()

    # Anti-connectivity matrix; be sure not to connect hits to clusters in different events!
    M_inv = get_inter_event_norms_mask(batch, n_clusters_per_event) - M

    # Throw away noise cluster columns; we never need them
    M = M[:, is_object]
    M_inv = M_inv[:, is_object]
    assert M.size() == (n_hits, n_objects)
    assert M_inv.size() == (n_hits, n_objects)

    # ________________________________
    # Loss terms

    # First calculate all cluster centers, then throw out the noise clusters
    cluster_centers = scatter_mean(cluster_space_coords, cluster_index, dim=0)
    object_centers = cluster_centers[is_object]

    # Calculate all norms
    # Warning: should not be used without a mask!
    # Contains norms between hits and objects from different events
    # (n_hits, 1, cluster_space_dim) - (1, n_objects, cluster_space_dim)
    # gives (n_hits, n_objects, cluster_space_dim)
    norms = (cluster_space_coords.unsqueeze(1) - object_centers.unsqueeze(0)).norm(
        dim=-1
    )
    assert norms.size() == (n_hits, n_objects)

    # -------
    # Attractive loss
    # First get all the relevant norms: We only want norms of signal hits
    # w.r.t. the object they belong to, i.e. no noise hits and no noise clusters.
    # First select all norms of all signal hits w.r.t. all objects (filtering out
    # the noise), mask out later
    norms_att = norms[is_sig]

    # Power-scale the norms
    if huberize_norm_for_V_attractive:
        # Huberized version (linear but times 4)
        # Be sure to not move 'off-diagonal' away from zero
        # (i.e. norms of hits w.r.t. clusters they do _not_ belong to)
        norms_att = huber(norms_att + 1e-5, 4.0)
    else:
        # Paper version is simply norms squared (no need for mask)
        norms_att = norms_att**2
    assert norms_att.size() == (n_hits_sig, n_objects)

    # Now apply the mask to keep only norms of signal hits w.r.t. the object
    # they belong to (throw away norms w.r.t. clusters they do *not* belong to)
    norms_att *= M[is_sig]

    # Sum norms_att over hits (dim=0), then sum per event, then divide by
    # n_hits_per_event, then sum over events
    L_attractive = (
        scatter_add(norms_att.sum(dim=0), batch_object) / n_hits_per_event
    ).sum()

    # -------
    # Repulsive loss
    # Get all the relevant norms: We want norms of any hit w.r.t. objects
    # they do *not* belong to, i.e. no noise clusters.
    # We do however want to keep norms of noise hits w.r.t. objects

    # Power-scale the norms: Gaussian scaling term instead of a cone
    # Mask out the norms of hits w.r.t. the cluster they belong to
    norms_rep = torch.exp(-4.0 * norms**2) * M_inv

    # Sum over hits, then sum per event, then divide by n_hits_per_event, then sum up events
    L_repulsive = (
        scatter_add(norms_rep.sum(dim=0), batch_object) / n_hits_per_event
    ).sum()
    L_attractive /= batch_size
    L_repulsive /= batch_size

    # -------
    # Optional: edc column
    if pred_edc is not None:
        n_hits_per_cluster = scatter_count(cluster_index)
        cluster_centers_expanded = torch.index_select(cluster_centers, 0, cluster_index)
        assert cluster_centers_expanded.size() == (n_hits, cluster_space_dim)
        truth_edc = (cluster_space_coords - cluster_centers_expanded).norm(dim=-1)
        assert pred_edc.size() == (n_hits,)
        d_per_hit = (pred_edc - truth_edc) ** 2
        d_per_object = scatter_add(d_per_hit, cluster_index)[is_object]
        assert d_per_object.size() == (n_objects,)
        L_edc = (scatter_add(d_per_object, batch_object) / n_hits_per_event).sum()
        return L_attractive, L_repulsive, L_edc
    return L_attractive, L_repulsive


def huber(d, delta):
    """
    See: https://en.wikipedia.org/wiki/Huber_loss#Definition
    Multiplied by 2 w.r.t. the Wikipedia version (aligning with Jan's definition)
    """
    # 2*huber: d**2 in the quadratic region, delta**2 + 2*delta*(|d| - delta) in the
    # linear region (the delta**2 offset keeps the two branches continuous at |d| = delta)
    return torch.where(
        torch.abs(d) <= delta,
        d**2,
        delta**2 + 2.0 * delta * (torch.abs(d) - delta),
    )
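
# Example: huber(torch.tensor([0.5, 10.0]), 4.0) -> tensor([0.25, 64.0]);
# 0.5 falls in the quadratic region (0.5**2), 10.0 in the linear one
# (4**2 + 2*4*(10-4) = 64).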


def batch_cluster_indices(
    cluster_id: torch.Tensor, batch: torch.Tensor
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """
    Turns cluster indices per event into an index over the whole batch

    Example:
        cluster_id = torch.LongTensor([0, 0, 1, 1, 2, 0, 0, 1, 1, 1, 0, 0, 1])
        batch      = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
        -->
        offset = torch.LongTensor([0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 5, 5, 5])
        output = torch.LongTensor([0, 0, 1, 1, 2, 3, 3, 4, 4, 4, 5, 5, 6])
    """
    device = cluster_id.device
    assert cluster_id.device == batch.device
    # Count the number of clusters per entry in the batch
    n_clusters_per_event = scatter_max(cluster_id, batch, dim=-1)[0] + 1
    # Offsets are then a cumulative sum
    offset_values_nozero = n_clusters_per_event[:-1].cumsum(dim=-1)
    # Prefix a zero
    offset_values = torch.cat(
        (torch.zeros(1, dtype=torch.long, device=device), offset_values_nozero)
    )
    # Fill it per hit
    offset = torch.gather(offset_values, 0, batch).long()
    return offset + cluster_id, n_clusters_per_event


def get_clustering_np(
    betas: np.ndarray, X: np.ndarray, tbeta: float = 0.1, td: float = 1.0
) -> np.ndarray:
    """
    Returns a clustering of hits -> cluster_index, based on the GravNet model
    output (predicted betas and cluster space coordinates) and the clustering
    parameters tbeta and td.
    Takes numpy arrays as input.
    """
    n_points = betas.shape[0]
    select_condpoints = betas > tbeta
    # Get indices passing the threshold
    indices_condpoints = np.nonzero(select_condpoints)[0]
    # Order them by decreasing beta value
    indices_condpoints = indices_condpoints[np.argsort(-betas[select_condpoints])]
    # Assign points to condensation points
    # Only assign previously unassigned points (no overwriting)
    # Points unassigned at the end are bkg (-1)
    unassigned = np.arange(n_points)
    clustering = -1 * np.ones(n_points, dtype=np.int32)
    for index_condpoint in indices_condpoints:
        d = np.linalg.norm(X[unassigned] - X[index_condpoint], axis=-1)
        assigned_to_this_condpoint = unassigned[d < td]
        clustering[assigned_to_this_condpoint] = index_condpoint
        unassigned = unassigned[~(d < td)]
    return clustering
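
# Minimal usage sketch (hypothetical values): hits 0 and 1 fall within td of the
# highest-beta point, hit 2 condenses on its own:
#   betas = np.array([0.9, 0.2, 0.8])
#   X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
#   get_clustering_np(betas, X, tbeta=0.5, td=1.0)  # -> array([0, 0, 2], dtype=int32)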


def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.1, td=1.0):
    """
    Returns a clustering of hits -> cluster_index, based on the GravNet model
    output (predicted betas and cluster space coordinates) and the clustering
    parameters tbeta and td.
    Takes torch.Tensors as input.
    """
    n_points = betas.size(0)
    select_condpoints = betas > tbeta
    # Get indices passing the threshold
    indices_condpoints = select_condpoints.nonzero()
    # Order them by decreasing beta value
    indices_condpoints = indices_condpoints[(-betas[select_condpoints]).argsort()]
    # Assign points to condensation points
    # Only assign previously unassigned points (no overwriting)
    # Points unassigned at the end are bkg (-1)
    unassigned = torch.arange(n_points)
    clustering = -1 * torch.ones(n_points, dtype=torch.long)
    for index_condpoint in indices_condpoints:
        d = torch.norm(X[unassigned] - X[index_condpoint][0], dim=-1)
        assigned_to_this_condpoint = unassigned[d < td]
        clustering[assigned_to_this_condpoint] = index_condpoint[0]
        unassigned = unassigned[~(d < td)]
    return clustering


def scatter_count(input: torch.Tensor):
    """
    Returns ordered counts over an index array

    Example:
    >>> scatter_count(torch.Tensor([0, 0, 0, 1, 1, 2, 2]))
    tensor([3, 2, 2])

    Index assumptions work like in torch_scatter, so:
    >>> scatter_count(torch.Tensor([1, 1, 1, 2, 2, 4, 4]))
    tensor([0, 3, 2, 0, 2])
    """
    return scatter_add(torch.ones_like(input, dtype=torch.long), input.long())


def scatter_counts_to_indices(input: torch.LongTensor) -> torch.LongTensor:
    """
    Converts counts to indices. This is the inverse operation of scatter_count

    Example:
        input:  [3, 2, 2]
        output: [0, 0, 0, 1, 1, 2, 2]
    """
    return torch.repeat_interleave(
        torch.arange(input.size(0), device=input.device), input
    ).long()
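
# Round trip: scatter_counts_to_indices inverts scatter_count, e.g.
# [0, 0, 0, 1, 1, 2, 2] -> counts [3, 2, 2] -> [0, 0, 0, 1, 1, 2, 2]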


def get_inter_event_norms_mask(
    batch: torch.LongTensor, nclusters_per_event: torch.LongTensor
):
    """
    Creates a mask of (nhits x nclusters) that is 1 only if hit i is in the same
    event as cluster j

    Example:
        cluster_id_per_event = torch.LongTensor([0, 0, 1, 1, 2, 0, 0, 1, 1, 1, 0, 0, 1])
        batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
        Should return:
        torch.LongTensor([
            [1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 0, 0, 1, 1],
            [0, 0, 0, 0, 0, 1, 1],
            [0, 0, 0, 0, 0, 1, 1],
        ])
    """
    device = batch.device
    # Following the example:
    # Expand batch to the following (nevents x nhits) matrix (little hacky, boolean mask -> long):
    # [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    #  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0],
    #  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]
    batch_expanded_as_ones = (
        batch
        == torch.arange(batch.max() + 1, dtype=torch.long, device=device).unsqueeze(-1)
    ).long()
    # Then repeat_interleave it to expand to nclusters rows, and transpose to get (nhits x nclusters)
    return batch_expanded_as_ones.repeat_interleave(nclusters_per_event, dim=0).T


def isin(ar1, ar2):
    """To be replaced by torch.isin for newer releases of torch"""
    return (ar1[..., None] == ar2).any(-1)


def reincrementalize(y: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """
    Re-indexes y so that missing clusters are no longer counted.

    Example:
    >>> y = torch.LongTensor([
        0, 0, 0, 1, 1, 3, 3,
        0, 0, 0, 0, 0, 2, 2, 3, 3,
        0, 0, 1, 1,
    ])
    >>> batch = torch.LongTensor([
        0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2,
    ])
    >>> print(reincrementalize(y, batch))
    tensor([0, 0, 0, 1, 1, 2, 2, 0, 0, 0, 0, 0, 1, 1, 2, 2, 0, 0, 1, 1])
    """
    y_offset, n_per_event = batch_cluster_indices(y, batch)
    offset = y_offset - y
    n_clusters = n_per_event.sum()
    holes = (
        (~isin(torch.arange(n_clusters, device=y.device), y_offset))
        .nonzero()
        .squeeze(-1)
    )
    n_per_event_without_holes = n_per_event.clone()
    n_per_event_cumsum = n_per_event.cumsum(0)
    for hole in holes.sort(descending=True).values:
        y_offset[y_offset > hole] -= 1
        i_event = (hole > n_per_event_cumsum).long().argmin()
        n_per_event_without_holes[i_event] -= 1
    offset_per_event = torch.zeros_like(n_per_event_without_holes)
    offset_per_event[1:] = n_per_event_without_holes.cumsum(0)[:-1]
    offset_without_holes = torch.gather(offset_per_event, 0, batch).long()
    reincrementalized = y_offset - offset_without_holes
    return reincrementalized


def L_clusters_calc(batch, cluster_space_coords, cluster_index, frac_combinations, q):
    # Loss over randomly sampled pairs of hits inside and across clusters
    number_of_pairs = 0
    # Initialize before the loop, so the value accumulates over events and the
    # function is safe to call when all events are skipped
    L_clusters = torch.tensor(0.0).to(q.device)
    for batch_id in batch.unique():
        # Sample from all possible pairs within this event
        bmask = batch == batch_id
        pos_pairs_all = []
        neg_pairs_all = []
        if len(cluster_index[bmask].unique()) <= 1:
            continue
        for cluster in cluster_index[bmask].unique():
            clust_idx = cluster_index[bmask] == cluster
            # Event-level hit indices inside / outside this cluster, so that the
            # sampled pairs index correctly into the event-filtered tensors below
            pos_positions = torch.where(clust_idx)[0]
            neg_positions = torch.where(~clust_idx)[0]
            if len(neg_positions) == 0:
                continue
            total_num = (len(pos_positions) ** 2) / 2
            num = int(frac_combinations * total_num)
            pos_pairs = []
            for _ in range(num):
                pos_pairs.append(
                    [
                        pos_positions[np.random.randint(len(pos_positions))].item(),
                        pos_positions[np.random.randint(len(pos_positions))].item(),
                    ]
                )
            neg_pairs = []
            for _ in range(len(pos_pairs)):
                neg_pairs.append(
                    [
                        pos_positions[np.random.randint(len(pos_positions))].item(),
                        neg_positions[np.random.randint(len(neg_positions))].item(),
                    ]
                )
            pos_pairs_all += pos_pairs
            neg_pairs_all += neg_pairs
        pos_pairs = torch.tensor(pos_pairs_all)
        neg_pairs = torch.tensor(neg_pairs_all)
        """# do just a small sample of the pairs ...
        bmask = batch == batch_id
        # L_clusters = 0  # Loss of randomly sampled distances between points inside and outside clusters
        pos_idx, neg_idx = [], []
        for cluster in cluster_index[bmask].unique():
            clust_idx = (cluster_index == cluster)[bmask]
            perm = torch.randperm(clust_idx.sum())
            perm1 = torch.randperm((~clust_idx).sum())
            perm2 = torch.randperm(clust_idx.sum())
            # cutoff = clust_idx.sum() // 2
            pos_lst = clust_idx.nonzero()[perm]
            neg_lst = (~clust_idx).nonzero()[perm1]
            neg_lst_second = clust_idx.nonzero()[perm2]
            if len(pos_lst) % 2:
                pos_lst = pos_lst[:-1]
            if len(neg_lst) % 2:
                neg_lst = neg_lst[:-1]
            len_cap = min(len(pos_lst), len(neg_lst), len(neg_lst_second))
            if len_cap % 2:
                len_cap -= 1
            pos_lst = pos_lst[:len_cap]
            neg_lst = neg_lst[:len_cap]
            neg_lst_second = neg_lst_second[:len_cap]
            pos_pairs = pos_lst.reshape(-1, 2)
            neg_pairs = torch.cat([neg_lst, neg_lst_second], dim=1)
            neg_pairs = neg_pairs[: pos_lst.shape[0] // 2, :]
            pos_idx.append(pos_pairs)
            neg_idx.append(neg_pairs)
        pos_idx = torch.cat(pos_idx)
        neg_idx = torch.cat(neg_idx)"""
        assert pos_pairs.shape == neg_pairs.shape
        if len(pos_pairs) == 0:
            continue
        cluster_space_coords_filtered = cluster_space_coords[bmask]
        qs_filtered = q[bmask]
        pos_norms = (
            cluster_space_coords_filtered[pos_pairs[:, 0]]
            - cluster_space_coords_filtered[pos_pairs[:, 1]]
        ).norm(dim=-1)
        neg_norms = (
            cluster_space_coords_filtered[neg_pairs[:, 0]]
            - cluster_space_coords_filtered[neg_pairs[:, 1]]
        ).norm(dim=-1)
        q_pos = qs_filtered[pos_pairs[:, 0]]
        q_neg = qs_filtered[neg_pairs[:, 0]]
        q_s = torch.cat([q_pos, q_neg])
        norms_all = torch.cat([pos_norms, neg_norms])
        ys = torch.cat([torch.ones_like(pos_norms), -torch.ones_like(neg_norms)])
        L_clusters += torch.sum(
            q_s * torch.nn.HingeEmbeddingLoss(reduction="none")(norms_all, ys)
        )
        number_of_pairs += norms_all.shape[0]
    if number_of_pairs > 0:
        L_clusters = L_clusters / number_of_pairs
    return L_clusters
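
# Quick check of the hinge-embedding convention used in L_clusters_calc (margin
# defaults to 1.0): y=+1 pairs are penalized by their distance, y=-1 pairs only
# when closer than the margin:
#   torch.nn.HingeEmbeddingLoss(reduction="none")(
#       torch.tensor([0.3, 0.3]), torch.tensor([1.0, -1.0])
#   )  # -> tensor([0.3000, 0.7000])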


## deprecated code: