# jetclustering/src/layers/obj_cond_inf.py
from typing import Tuple, Union
import numpy as np
import torch
from torch_scatter import scatter_max, scatter_add, scatter_mean
from src.layers.object_cond import assert_no_nans, scatter_count, batch_cluster_indices
import dgl


def calc_energy_loss(
batch, cluster_space_coords, beta, beta_stabilizing="soft_q_scaling", qmin=0.1, radius=0.7,
e_frac_loss_return_particles=False, y=None, select_centers_by_particle=True
):
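    """Evaluate per-object energy collection for an object-condensation clustering.

    For every graph in `batch`, one condensation center per particle is chosen (the
    highest-beta hit of each particle, or simply the highest-beta hits overall when
    `select_centers_by_particle` is False). Hits are clustered around these centers
    within `radius` in cluster space, and for each particle the clustered energy and
    the correctly assigned energy are divided by its true hit energy. The means over
    particles and graphs are returned, or the per-particle values (plus per-PID
    reconstruction counts) when `e_frac_loss_return_particles` is True.
    """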
# select_centers_by_particle: if True, we pretend we know which hits belong to which particle...
list_graphs = dgl.unbatch(batch)
node_counter = 0
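    # convert per-hit beta into the condensation charge q = arctanh(beta)^2 + qmin;
    # the modes only differ in how beta is clipped/scaled to keep arctanh finite near 1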
if beta_stabilizing == "paper":
q = beta.arctanh() ** 2 + qmin
elif beta_stabilizing == "clip":
beta = beta.clip(0.0, 1 - 1e-4)
q = beta.arctanh() ** 2 + qmin
elif beta_stabilizing == "soft_q_scaling":
q = (beta.clip(0.0, 1 - 1e-4) / 1.002).arctanh() ** 2 + qmin
else:
        raise ValueError(f"beta_stabilizing mode {beta_stabilizing} is not known")
loss_E_frac = []
loss_E_frac_true = []
particle_ids_all = []
    reco_count = {}  # per-PID count of particles that own at least one selected center
    non_reco_count = {}  # per-PID count of particles that own no selected center
    total_count = {}  # per-PID count of all particles encountered
for g in list_graphs:
particle_id = g.ndata["particle_number"]
number_of_objects = len(particle_id.unique())
print("No. of objects", number_of_objects)
non = g.number_of_nodes()
q_g = q[node_counter : non + node_counter]
betas = beta[node_counter : non + node_counter]
        # rank hits by beta and keep the highest-beta hits as candidate condensation centers
        _, indices = torch.sort(betas.view(-1), descending=True)
        selected_centers = indices[0:number_of_objects]
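        # scatter_max gives, for each particle id, the index of its highest-beta hit;
        # these are used as the centers when select_centers_by_particle is True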
_, selected_centers_particles = scatter_max(
betas.flatten().cpu(), particle_id.cpu().long() - 1
)
assert selected_centers.shape[0] == number_of_objects
if select_centers_by_particle:
selected_centers = selected_centers_particles.to(g.device)
all_particles = set((particle_id.unique()-1).long().tolist())
reco_particles = set((particle_id[selected_centers]-1).long().tolist())
non_reco_particles = all_particles - reco_particles
part_pids = y[:, 6].long()
for particle in all_particles:
curr_pid = part_pids[particle].item()
if curr_pid in total_count:
total_count[curr_pid] += 1
else:
total_count[curr_pid] = 1
if particle in reco_particles:
if curr_pid in reco_count:
reco_count[curr_pid] += 1
else:
reco_count[curr_pid] = 1
else:
if curr_pid in non_reco_count:
non_reco_count[curr_pid] += 1
else:
non_reco_count[curr_pid] = 1
X = cluster_space_coords[node_counter : non + node_counter]
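        # "dynamic" clustering radius: half the distance from each center to its nearest
        # other center, reduced to the minimum over centers and floored at 0.1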
if radius == "dynamic":
            center_dists = torch.cdist(X[selected_centers], X[selected_centers], p=2)
            # column 1 of the argsort is the nearest *other* center (column 0 is the center itself)
            pick_ = torch.argsort(center_dists, dim=1)[:, 1]
            current_radius = center_dists.gather(1, pick_.view(-1, 1)) / 2
            current_radius = max(0.1, current_radius.flatten().min().item())
print("Current radius", current_radius)
else:
print("Radius", radius)
current_radius = radius
clusterings = get_clustering(selected_centers, X, betas, td=current_radius)
clusterings = clusterings.to(g.device)
node_counter += non
counter = 0
frac_energy = []
frac_energy_true = []
particle_ids = []
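        # for each center, compute the cluster's collected energy over the particle's true
        # hit energy (frac_energy), and the correctly assigned energy over the true hit
        # energy (frac_energy_true)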
for alpha in selected_centers:
id_particle = particle_id[alpha]
true_mask_particle = particle_id == id_particle
true_energy = torch.sum(g.ndata["e_hits"][true_mask_particle])
mask_clustering_particle = clusterings == counter
clustered_energy = torch.sum(g.ndata["e_hits"][mask_clustering_particle])
            # only count energy from hits that are both in this cluster and truly from this particle
            clustered_energy_true = torch.sum(
                g.ndata["e_hits"][mask_clustering_particle & true_mask_particle.flatten()]
            )
counter += 1
frac_energy.append(clustered_energy / (true_energy + 1e-7))
frac_energy_true.append(clustered_energy_true / (true_energy + 1e-7))
particle_ids.append(id_particle.cpu().long().item())
        frac_energy = torch.stack(frac_energy, dim=0)
        frac_energy_true = torch.stack(frac_energy_true, dim=0)
        if not e_frac_loss_return_particles:
            frac_energy = torch.mean(frac_energy)
            frac_energy_true = torch.mean(frac_energy_true)
loss_E_frac.append(frac_energy)
loss_E_frac_true.append(frac_energy_true)
particle_ids_all.append(particle_ids)
if e_frac_loss_return_particles:
return loss_E_frac, [loss_E_frac_true, particle_ids_all, reco_count, non_reco_count, total_count]
loss_E_frac = torch.mean(torch.stack(loss_E_frac, dim=0))
loss_E_frac_true = torch.mean(torch.stack(loss_E_frac_true, dim=0))
return loss_E_frac, loss_E_frac_true


def get_clustering(index_alpha_i, X, betas, td=0.7):
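    """Greedily assign points to condensation centers.

    Centers are processed in the order given by `index_alpha_i`; every still-unassigned
    point within distance `td` of the current center receives that center's label.
    Points close to no center keep the label -1, and each center is finally forced to
    carry its own label.
    """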
n_points = betas.size(0)
    unassigned = torch.arange(n_points, device=betas.device)
    clustering = torch.full((n_points,), -1, dtype=torch.long, device=betas.device)
counter = 0
for index_condpoint in index_alpha_i:
d = torch.norm(X[unassigned] - X[index_condpoint], dim=-1)
assigned_to_this_condpoint = unassigned[d < td]
clustering[assigned_to_this_condpoint] = counter
unassigned = unassigned[~(d < td)]
counter = counter + 1
counter = 0
for index_condpoint in index_alpha_i:
clustering[index_condpoint] = counter
counter = counter + 1
return clustering
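

if __name__ == "__main__":
    # Illustrative sketch (not part of the original module): runs get_clustering on
    # synthetic data to show the output shape and the -1 label for hits that fall near
    # no center. All values here (20 hits, 3D cluster space, 3 centers, td=1.0) are
    # made up for demonstration only.
    torch.manual_seed(0)
    X_demo = torch.randn(20, 3)  # fake cluster-space coordinates
    betas_demo = torch.rand(20)  # fake per-hit condensation strengths
    centers_demo = torch.topk(betas_demo, k=3).indices  # highest-beta hits as centers
    labels = get_clustering(centers_demo, X_demo, betas_demo, td=1.0)
    print("cluster labels:", labels.tolist())  # -1 == not assigned to any center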