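"""Energy-fraction evaluation for object-condensation clustering.

``calc_energy_loss`` measures how much of each particle's hit energy ends up
in the cluster built around its condensation center; ``get_clustering``
performs the greedy, radius-based assignment of hits to those centers.
"""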
import dgl
import torch
from torch_scatter import scatter_max


def calc_energy_loss(
batch, cluster_space_coords, beta, beta_stabilizing="soft_q_scaling", qmin=0.1, radius=0.7,
e_frac_loss_return_particles=False, y=None, select_centers_by_particle=True
):
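    """Energy-fraction metrics for the clusters built around condensation centers.

    For every graph in ``batch``, condensation centers are selected (the
    highest-beta hit of each true particle if ``select_centers_by_particle``,
    otherwise the overall highest-beta hits), hits within ``radius`` of a
    center in ``cluster_space_coords`` are grouped into its cluster, and the
    clustered energy is compared to the particle's true hit energy. ``y`` is
    assumed to carry per-particle truth with the PID in column 6; this is
    inferred from the indexing below, not from upstream documentation.
    """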
    # select_centers_by_particle: if True, pretend the hit-to-particle
    # assignment is known and take the highest-beta hit of each particle
    # as its condensation center.
list_graphs = dgl.unbatch(batch)
node_counter = 0
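    # Condensation charge per hit: q = arctanh(beta)^2 + qmin, as in the
    # object condensation paper; the modes differ only in how beta is kept
    # away from 1 so that arctanh stays finite.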
if beta_stabilizing == "paper":
q = beta.arctanh() ** 2 + qmin
elif beta_stabilizing == "clip":
beta = beta.clip(0.0, 1 - 1e-4)
q = beta.arctanh() ** 2 + qmin
elif beta_stabilizing == "soft_q_scaling":
q = (beta.clip(0.0, 1 - 1e-4) / 1.002).arctanh() ** 2 + qmin
else:
        raise ValueError(f"beta_stabilizing mode {beta_stabilizing} is not known")
loss_E_frac = []
loss_E_frac_true = []
particle_ids_all = []
    reco_count = {}  # per-PID count of particles that got a condensation center
    non_reco_count = {}  # per-PID count of particles that did not get one
    total_count = {}  # per-PID count of all particles seen
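    # Score each event (graph) in the batch independently.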
for g in list_graphs:
particle_id = g.ndata["particle_number"]
number_of_objects = len(particle_id.unique())
print("No. of objects", number_of_objects)
        n_nodes = g.number_of_nodes()
        q_g = q[node_counter : n_nodes + node_counter]
        betas = beta[node_counter : n_nodes + node_counter]
        # Fallback centers: the overall highest-beta hits (descending sort).
        _, indices = torch.sort(betas.view(-1), descending=True)
        selected_centers = indices[0:number_of_objects]
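        # Highest-beta hit of each true particle (argmax of beta per particle id).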
_, selected_centers_particles = scatter_max(
betas.flatten().cpu(), particle_id.cpu().long() - 1
)
assert selected_centers.shape[0] == number_of_objects
if select_centers_by_particle:
selected_centers = selected_centers_particles.to(g.device)
        all_particles = set((particle_id.unique() - 1).long().tolist())
        reco_particles = set((particle_id[selected_centers] - 1).long().tolist())
        non_reco_particles = all_particles - reco_particles
        if y is not None:
            # Tally per PID how many particles did / did not get a center.
            part_pids = y[:, 6].long()
            for particle in all_particles:
                curr_pid = part_pids[particle].item()
                total_count[curr_pid] = total_count.get(curr_pid, 0) + 1
                if particle in reco_particles:
                    reco_count[curr_pid] = reco_count.get(curr_pid, 0) + 1
                else:
                    non_reco_count[curr_pid] = non_reco_count.get(curr_pid, 0) + 1
        X = cluster_space_coords[node_counter : n_nodes + node_counter]
        if radius == "dynamic":
            # Use half the distance from each center to its nearest other
            # center, floored at 0.1, as the clustering radius.
            centers = X[selected_centers]
            dist = torch.cdist(centers, centers, p=2)
            nearest = torch.sort(dist, dim=1).values[:, 1]  # column 0 is the self-distance
            current_radius = max(0.1, (nearest / 2).min().item())
            print("Current radius", current_radius)
        else:
            print("Radius", radius)
            current_radius = radius
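        # clusterings[i] = index of the center whose cluster hit i joined,
        # or -1 if the hit is farther than the radius from every center.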
clusterings = get_clustering(selected_centers, X, betas, td=current_radius)
clusterings = clusterings.to(g.device)
        node_counter += n_nodes
        frac_energy = []
        frac_energy_true = []
        particle_ids = []
        for counter, alpha in enumerate(selected_centers):
id_particle = particle_id[alpha]
true_mask_particle = particle_id == id_particle
true_energy = torch.sum(g.ndata["e_hits"][true_mask_particle])
            mask_clustering_particle = clusterings == counter
            clustered_energy = torch.sum(g.ndata["e_hits"][mask_clustering_particle])
            clustered_energy_true = torch.sum(
                g.ndata["e_hits"][
                    mask_clustering_particle & true_mask_particle.flatten()
                ]
            )  # only count energy that was assigned to the correct particle
frac_energy.append(clustered_energy / (true_energy + 1e-7))
frac_energy_true.append(clustered_energy_true / (true_energy + 1e-7))
particle_ids.append(id_particle.cpu().long().item())
        frac_energy = torch.stack(frac_energy, dim=0)
        frac_energy_true = torch.stack(frac_energy_true, dim=0)
        if not e_frac_loss_return_particles:
            frac_energy = torch.mean(frac_energy)
            frac_energy_true = torch.mean(frac_energy_true)
loss_E_frac.append(frac_energy)
loss_E_frac_true.append(frac_energy_true)
particle_ids_all.append(particle_ids)
if e_frac_loss_return_particles:
return loss_E_frac, [loss_E_frac_true, particle_ids_all, reco_count, non_reco_count, total_count]
loss_E_frac = torch.mean(torch.stack(loss_E_frac, dim=0))
loss_E_frac_true = torch.mean(torch.stack(loss_E_frac_true, dim=0))
return loss_E_frac, loss_E_frac_true
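
# Hypothetical usage sketch (tensor contents are made up; assumes nodes carry
# "particle_number" (1-based ids) and "e_hits" features, as used above):
#
#     g = dgl.graph(([], []), num_nodes=4)
#     g.ndata["particle_number"] = torch.tensor([1, 1, 2, 2])
#     g.ndata["e_hits"] = torch.rand(4, 1)
#     loss_E_frac, loss_E_frac_true = calc_energy_loss(
#         dgl.batch([g]),
#         cluster_space_coords=torch.randn(4, 3),
#         beta=torch.rand(4),
#         radius=0.7,
#     )
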
def get_clustering(index_alpha_i, X, betas, td=0.7):
    """Greedily assign each point to the first condensation point within td.

    Points farther than td from every condensation point keep cluster id -1.
    """
    n_points = betas.size(0)
    unassigned = torch.arange(n_points, device=betas.device)
    clustering = -1 * torch.ones(n_points, dtype=torch.long, device=betas.device)
    for counter, index_condpoint in enumerate(index_alpha_i):
        d = torch.norm(X[unassigned] - X[index_condpoint], dim=-1)
        within_radius = d < td
        clustering[unassigned[within_radius]] = counter
        unassigned = unassigned[~within_radius]
    # Every condensation point belongs to its own cluster, regardless of radius.
    for counter, index_condpoint in enumerate(index_alpha_i):
        clustering[index_condpoint] = counter
    return clustering
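

if __name__ == "__main__":
    # Minimal smoke test for get_clustering on synthetic data (illustrative
    # only; the points, betas, and radius below are arbitrary).
    torch.manual_seed(0)
    X = torch.randn(20, 3)  # 20 points in a 3-dimensional cluster space
    betas = torch.rand(20)
    # Take the three highest-beta points as condensation centers.
    centers = torch.argsort(betas, descending=True)[:3]
    clustering = get_clustering(centers, X, betas, td=0.7)
    # Points within td of a center inherit its cluster id; the rest stay -1.
    print("cluster ids:", clustering.tolist())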