# jetclustering/src/layers/inference_oc.py
import dgl
import torch
import os
from alembic.command import current
from sklearn.cluster import DBSCAN, HDBSCAN
from torch_scatter import scatter_max, scatter_add, scatter_mean
import numpy as np
from src.dataset.functions_data import CachedIndexList, spherical_to_cartesian
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
import pandas as pd
import wandb
from src.utils.inference.per_particle_metrics import plot_event
import random
import string
def generate_random_string(length):
    """Return a random string of `length` alphanumeric characters."""
    alphabet = string.ascii_letters + string.digits
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
def create_and_store_graph_output(
    batch_g,
    model_output,
    y,
    local_rank,
    step,
    epoch,
    path_save,
    store=False,
    predict=False,
    tracking=False,
    e_corr=None,
    shap_vals=None,
    ec_x=None,  # ec_x: "global" features (what gets inputted into the final deep neural network head) for energy correction
    tracks=False,
    store_epoch=False,
    total_number_events=0,
    pred_pos=None,
    pred_ref_pt=None,
    use_gt_clusters=False,
    pids_neutral=None,
    pids_charged=None,
    pred_pid=None,
    pred_xyz_track=None,
    number_of_fakes=None
):
    """Unbatch the model output, cluster hits into predicted showers per event,
    match them to truth (and to Pandora when ``predict``) and collect one
    DataFrame of per-shower quantities for the whole batch.

    Key arguments (only those the body actually reads):
        batch_g: batched DGL graph of hits; columns of ``model_output`` are
            written into its ndata ("coords" = cols 0:3, "beta" = col 3 and,
            when ``e_corr is None`` and not ``tracking``, "correction" = col 4).
        y: truth container providing ``.batch_number``, ``.copy()`` and ``.mask()``
            (project type; assumed one entry per true particle — TODO confirm).
        use_gt_clusters: use ground-truth ``particle_number`` instead of HDBSCAN.
        total_number_events: running global event counter, carried across batches
            and used as ``number_in_batch`` in the per-event DataFrames.
        number_of_fakes: count of fake-shower entries appended at the tail of
            ``e_corr``/``pred_pos``/``pred_pid`` (forwarded downstream).

    Returns:
        ``(df_batch_pandora, df_batch1, total_number_events)`` when ``predict``,
        else just ``df_batch1`` (the HDBSCAN-clustered showers DataFrame).

    NOTE(review): ``pids_neutral``, ``pids_charged`` and ``pred_xyz_track`` are
    accepted but never used in this function.
    """
    number_of_showers_total = 0
    number_of_showers_total1 = 0
    number_of_fake_showers_total1 = 0
    # Write per-hit model outputs back onto the batched graph so they are
    # split per event by dgl.unbatch below.
    batch_g.ndata["coords"] = model_output[:, 0:3]
    batch_g.ndata["beta"] = model_output[:, 3]
    if not tracking:
        if e_corr is None:
            batch_g.ndata["correction"] = model_output[:, 4]
    graphs = dgl.unbatch(batch_g)
    batch_id = y.batch_number.view(-1)  # y[:, -1].view(-1)
    df_list = []  # NOTE(review): never appended to; only df_list1/df_list_pandora are used
    df_list1 = []
    df_list_pandora = []
    total_number_candidates = 0  # NOTE(review): unused (related slicing is commented out below)
    for i in range(0, len(graphs)):
        # Per-event truth: mask the truth container down to this event.
        mask = batch_id == i
        dic = {}
        dic["graph"] = graphs[i]
        y1 = y.copy()
        y1.mask(mask)
        dic["part_true"] = y1  # y[mask]
        X = dic["graph"].ndata["coords"]
        # if shap_vals is not None:
        #     dic["shap_values"] = shap_vals
        # if ec_x is not None:
        #     dic["ec_x"] = ec_x ## ? No mask ?!?
        if predict:
            # Condensation-point clustering (currently only computed, not used
            # further down — the matching below uses the HDBSCAN labels).
            labels_clustering = clustering_obtain_labels(
                X, dic["graph"].ndata["beta"].view(-1), model_output.device
            )
        if use_gt_clusters:
            labels_hdb = dic["graph"].ndata["particle_number"].type(torch.int64)
        else:
            labels_hdb = hfdb_obtain_labels(X, model_output.device)
        num_clusters = len(labels_hdb.unique())  # NOTE(review): unused
        # if labels_hdb.min() == 0 and labels_hdb.sum() == 0:
        #     labels_hdb += 1  # Quick hack
        #     raise Exception("!!!! Labels==0 !!!!")
        if predict:
            labels_pandora = get_labels_pandora(tracks, dic, model_output.device)
            num_clusters_pandora = len(labels_pandora.unique())  # NOTE(review): unused
        particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
        # current_number_candidates = num_clusters
        # pred_pos_batch = pred_pos[total_number_candidates:total_number_candidates+current_number_candidates]
        # pred_ref_pt_batch = pred_ref_pt[total_number_candidates:total_number_candidates+current_number_candidates]
        # pred_pid_batch = pred_pid[total_number_candidates:total_number_candidates+current_number_candidates]
        # e_corr_batch = e_corr[total_number_candidates:total_number_candidates+current_number_candidates]
        """if predict:
        shower_p_unique = torch.unique(labels_clustering)
        shower_p_unique, row_ind, col_ind, i_m_w, iou_m_c = match_showers(
        labels_clustering,
        dic,
        particle_ids,
        model_output,
        local_rank,
        i,
        path_save,
        tracks=tracks,
        )"""
        # Hungarian matching of HDBSCAN showers to true particles.
        shower_p_unique_hdb, row_ind_hdb, col_ind_hdb, i_m_w_hdb, iou_m = match_showers(
            labels_hdb,
            dic,
            particle_ids,
            model_output,
            local_rank,
            i,
            path_save,
            tracks=tracks,
            hdbscan=True,
        )
        if predict:
            # Same matching for Pandora's clusters/PFOs, for comparison.
            (
                shower_p_unique_pandora,
                row_ind_pandora,
                col_ind_pandora,
                i_m_w_pandora,
                iou_m_pandora,
            ) = match_showers(
                labels_pandora,
                dic,
                particle_ids,
                model_output,
                local_rank,
                i,
                path_save,
                pandora=True,
                tracks=tracks,
            )
        # # if len(row_ind_hdb) < len(dic["part_true"]):
        # print(len(row_ind_hdb), len(dic["part_true"]))
        # print("storing event", local_rank, step, i)
        # path_graphs_all_comparing = os.path.join(path_save, "graphs_all_comparing")
        # if not os.path.exists(path_graphs_all_comparing):
        #     os.makedirs(path_graphs_all_comparing)
        '''torch.save(
        dic,
        path_save
        + "/graphs_all_comparing_Gregor/"
        + str(local_rank)
        + "_"
        + str(step)
        + "_"
        + str(i)
        + ".pt",
        )'''
        # torch.save(
        #     dic,
        #     path_save
        #     + "/graphs/"
        #     + str(local_rank)
        #     + "_"
        #     + str(step)
        #     + "_"
        #     + str(i)
        #     + ".pt",
        # )
        # Only build rows when at least one non-noise predicted shower exists
        # (index 0 is reserved for noise).
        if len(shower_p_unique_hdb) > 1:
            # df_event, number_of_showers_total = generate_showers_data_frame(
            #     labels_clustering,
            #     labels_clustering,
            #     dic,
            #     shower_p_unique,
            #     particle_ids,
            #     row_ind,
            #     col_ind,
            #     i_m_w,
            #     e_corr=e_corr,
            #     number_of_showers_total=number_of_showers_total,
            #     step=step,
            #     number_in_batch=i,
            #     tracks=tracks,
            # )
            # if pred_pos is not None:
            # Apply temporary correction
            import math  # NOTE(review): only needed by the commented-out correction below
            # phi = math.atan2(pred_pos[:, 1], pred_pos[:, 0])
            # phi = torch.atan2(pred_pos[:, 1], pred_pos[:, 0])
            # theta = torch.acos(pred_pos[:, 2] / torch.norm(pred_pos, dim=1))
            # pred_pos = spherical_to_cartesian(theta, phi, torch.norm(pred_pos, dim=1), normalized=True)
            # pred_pos= pred_pos.to(model_output.device)
            df_event1, number_of_showers_total1, number_of_fake_showers_total1 = generate_showers_data_frame(
                labels_hdb,
                dic,
                shower_p_unique_hdb,
                particle_ids,
                row_ind_hdb,
                col_ind_hdb,
                i_m_w_hdb,
                e_corr=e_corr,
                number_of_showers_total=number_of_showers_total1,
                step=step,
                number_in_batch=total_number_events,
                tracks=tracks,
                ec_x=ec_x,
                shap_vals=shap_vals,
                pred_pos=pred_pos,
                pred_ref_pt=pred_ref_pt,
                pred_pid=pred_pid,
                save_plots_to_folder=path_save + "/ML_Model_evt_plots_debugging",
                number_of_fakes=number_of_fakes,
                number_of_fake_showers_total=number_of_fake_showers_total1,
            )
            if len(df_event1) > 1:
                df_list1.append(df_event1)
            if predict:
                df_event_pandora = generate_showers_data_frame(
                    labels_pandora,
                    dic,
                    shower_p_unique_pandora,
                    particle_ids,
                    row_ind_pandora,
                    col_ind_pandora,
                    i_m_w_pandora,
                    pandora=True,
                    tracking=tracking,
                    step=step,
                    number_in_batch=total_number_events,
                    tracks=tracks,
                    save_plots_to_folder=path_save + "/Pandora_evt_plots_debugging",
                )
                # A tuple return means the "no matches" fallback fired downstream.
                if df_event_pandora is not None and type(df_event_pandora) is not tuple:
                    df_list_pandora.append(df_event_pandora)
        # NOTE(review): incremented for every event, including those with no
        # usable showers — confirm this matches the intended global numbering.
        total_number_events = total_number_events + 1
    # print("number of showers total", number_of_showers_total)
    # number_of_showers_total = number_of_showers_total + len(shower_p_unique_hdb)
    # print("number of showers total", number_of_showers_total)
    # NOTE(review): pd.concat raises on an empty list — assumes at least one
    # event in the batch produced a showers DataFrame.
    df_batch1 = pd.concat(df_list1)
    if predict:
        df_batch_pandora = pd.concat(df_list_pandora)
    else:
        df_batch = []
        df_batch_pandora = []
    #
    if store:
        store_at_batch_end(
            path_save,
            df_batch1,
            df_batch_pandora,
            # df_batch,
            local_rank,
            step,
            epoch,
            predict=predict,
            store=store_epoch,
        )
    if predict:
        return df_batch_pandora, df_batch1, total_number_events
    else:
        return df_batch1
def store_at_batch_end(
    path_save,
    df_batch1,
    df_batch_pandora,
    # df_batch,
    local_rank=0,
    step=0,
    epoch=None,
    predict=False,
    store=False,
):
    """Pickle the per-batch shower DataFrames (model + Pandora) when both
    `store` and `predict` are set, and log matching efficiency to wandb.

    Files are written under `path_save` as
    "<rank>_<step>_<epoch>_hdbscan.pt" / "..._pandora.pt".
    """
    if predict:
        # Path of the (currently unused) plain clustering dump; kept for parity
        # with the commented-out df_batch pickling.
        path_save_ = f"{path_save}/{local_rank}_{step}_{epoch}.pt"
    # if store and predict:
    #     df_batch.to_pickle(path_save_)
    #     log_efficiency(df_batch, clustering=True)
    path_save_ = f"{path_save}/{local_rank}_{step}_{epoch}_hdbscan.pt"
    if store and predict:
        df_batch1.to_pickle(path_save_)
    if predict:
        path_save_pandora = f"{path_save}/{local_rank}_{step}_{epoch}_pandora.pt"
        if store and predict:
            df_batch_pandora.to_pickle(path_save_pandora)
    # Efficiency is logged regardless of whether anything was pickled.
    log_efficiency(df_batch1)
    if predict:
        log_efficiency(df_batch_pandora, pandora=True)
def log_efficiency(df, pandora=False, clustering=False):
    """Log to wandb the fraction of true showers (rows with a non-NaN reco
    energy) that got matched to a predicted shower (non-NaN pred energy)."""
    has_reco = ~np.isnan(df["reco_showers_E"])
    matched_preds = df["pred_showers_E"][has_reco].values
    eff = np.sum(~np.isnan(matched_preds)) / len(matched_preds)
    if pandora:
        key = "efficiency validation pandora"
    elif clustering:
        key = "efficiency validation clustering"
    else:
        key = "efficiency validation"
    wandb.log({key: eff})
def generate_showers_data_frame(
    labels,
    dic,
    shower_p_unique,
    particle_ids,
    row_ind,
    col_ind,
    i_m_w,
    pandora=False,
    tracking=False,
    e_corr=None,
    number_of_showers_total=None,
    step=0,
    number_in_batch=0,
    tracks=False,
    shap_vals=None,
    ec_x=None,
    pred_pos=None,
    pred_pid=None,
    save_plots_to_folder="",
    pred_ref_pt=None,
    number_of_fake_showers_total=None,
    number_of_fakes=None
):
    """Build the per-shower DataFrame for one event.

    Rows are: one row per true particle (with NaNs for quantities of unmatched
    particles) followed by one row per fake (unmatched) predicted shower, with
    NaNs for the truth quantities.

    Args (selection):
        labels: per-hit predicted cluster id (0 = noise).
        dic: {"graph": per-event DGL graph, "part_true": truth container}.
        shower_p_unique / particle_ids: unique predicted / true cluster ids.
        row_ind, col_ind: Hungarian match from match_showers — row_ind indexes
            particles, col_ind indexes predicted showers (noise slot stripped).
        i_m_w: energy-weighted intersection matrix from match_showers.
        e_corr / pred_pos / pred_ref_pt / pred_pid: per-candidate regression
            outputs, flattened across the whole batch; slices are taken with the
            running counters number_of_showers_total / number_of_fake_showers_total.
        number_of_fakes: size of the fake-candidate tail of those flat tensors.
            NOTE(review): compared with ``> 0`` below — assumes it is not None
            whenever ``e_corr`` is given; confirm callers guarantee this.

    Returns:
        df when number_of_showers_total is None, else
        (df, number_of_showers_total, number_of_fake_showers_total);
        ([], 0, 0) when there are no matches at all (len(col_ind) == 0).
    """
    shap = shap_vals is not None
    # Raw predicted shower energy: sum of hit energies per predicted cluster.
    e_pred_showers = scatter_add(dic["graph"].ndata["e_hits"].view(-1), labels)
    if pandora:
        # Pandora's own calibrated cluster / PFO energies (mean over hits of a
        # per-hit constant, i.e. effectively the per-cluster value).
        e_pred_showers_cali = scatter_mean(
            dic["graph"].ndata["pandora_cluster_energy"].view(-1), labels
        )
        e_pred_showers_pfo = scatter_mean(
            dic["graph"].ndata["pandora_pfo_energy"].view(-1), labels
        )
        # px_pred_pfo = scatter_mean(dic["graph"].ndata["hit_px"], labels)
        # py_pred_pfo = scatter_mean(dic["graph"].ndata["hit_py"], labels)
        # pz_pred_pfo = scatter_mean(dic["graph"].ndata["hit_pz"], labels)
        # p_pred_pfo = scatter_mean(dic["graph"].ndata["pos_pxpypz"], labels) # FIX THIS: the shape of pos_pxpypz is [-1, 3]
        calc_pandora_momentum = "pandora_momentum" in dic["graph"].ndata
        if calc_pandora_momentum:
            px_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_momentum"][:, 0], labels
            )
            py_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_momentum"][:, 1], labels
            )
            pz_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_momentum"][:, 2], labels
            )
            ref_pt_px_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_reference_point"][:, 0], labels
            )
            ref_pt_py_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_reference_point"][:, 1], labels
            )
            ref_pt_pz_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_reference_point"][:, 2], labels
            )
            pandora_pid = scatter_mean(
                dic["graph"].ndata["pandora_pid"], labels
            )
            ref_pt_pred_pfo = torch.stack(
                (ref_pt_px_pred_pfo, ref_pt_py_pred_pfo, ref_pt_pz_pred_pfo), dim=1
            )
            # p_pred_pandora = scatter_mean(dic["graph"].ndata["pandora_momentum"], labels)
            p_pred_pandora = torch.stack((px_pred_pfo, py_pred_pfo, pz_pred_pfo), dim=1)
            p_size_pandora = torch.norm(p_pred_pandora, dim=1)  # NOTE(review): unused
            pxyz_pred_pfo = (
                p_pred_pandora  # / torch.norm(p_pred_pandora, dim=1).view(-1, 1)
            )
    else:
        if e_corr is None:
            # Per-cluster correction from the highest-beta hit of each cluster.
            corrections_per_shower = get_correction_per_shower(labels, dic)
            e_pred_showers_cali = e_pred_showers * corrections_per_shower
        else:
            # e_corr is flat over the batch: real candidates first, then a tail
            # of number_of_fakes fake candidates.
            corrections_per_shower = e_corr.view(-1)
            # NOTE(review): raises TypeError if number_of_fakes is None here.
            if number_of_fakes > 0:
                corrections_per_shower_fakes = corrections_per_shower[-number_of_fakes:]  # NOTE(review): unused
                corrections_per_shower = corrections_per_shower[:-number_of_fakes]
    # True shower energies: sum of hit energies per true particle id.
    e_reco_showers = scatter_add(
        dic["graph"].ndata["e_hits"].view(-1),
        dic["graph"].ndata["particle_number"].long(),
    )
    row_ind = torch.Tensor(row_ind).to(e_pred_showers.device).long()
    col_ind = torch.Tensor(col_ind).to(e_pred_showers.device).long()
    if torch.sum(particle_ids == 0) > 0:
        # particle id can be 0 because there is noise
        # then row ind 0 in any case corresponds to particle 1.
        # if there is particle_id 0 then row_ind should be +1?
        row_ind_ = row_ind - 1
    else:
        # if there is no zero then index 0 corresponds to particle 1.
        row_ind_ = row_ind
    pred_showers = shower_p_unique
    energy_t = (
        dic["part_true"].E_corrected.view(-1).to(e_pred_showers.device)
    )  # dic["part_true"][:, 3].to(e_pred_showers.device)
    vertex = dic["part_true"].vertex.to(e_pred_showers.device)
    pos_t = dic["part_true"].coord.to(e_pred_showers.device)
    pid_t = dic["part_true"].pid.to(e_pred_showers.device)
    # Number of track hits (hit_type == 1) inside each predicted cluster.
    is_track_per_shower = scatter_add((dic["graph"].ndata["hit_type"] == 1), labels).int()
    is_track = torch.zeros(energy_t.shape).to(e_pred_showers.device)
    if shap:
        # NaN-initialised buffers, one row per true particle.
        matched_shap_vals = torch.zeros((energy_t.shape[0], ec_x.shape[1])) * (
            torch.nan
        )
        matched_shap_vals = matched_shap_vals.numpy()
        matched_ec_x = torch.zeros((energy_t.shape[0], ec_x.shape[1])) * (torch.nan)
        matched_ec_x = matched_ec_x.numpy()
    # +1 to skip the noise slot at index 0 of the per-cluster tensors.
    index_matches = col_ind + 1
    index_matches = index_matches.to(e_pred_showers.device).long()
    # All "matched_*" buffers start as NaN and get filled only at matched rows.
    matched_es = torch.zeros_like(energy_t) * (torch.nan)
    matched_positions = torch.zeros((energy_t.shape[0], 3)) * (torch.nan)
    matched_positions = matched_positions.to(e_pred_showers.device)
    matched_ref_pt = torch.zeros((energy_t.shape[0], 3)) * (torch.nan)
    matched_ref_pt = matched_ref_pt.to(e_pred_showers.device)
    matched_pid = torch.zeros_like(energy_t) * (torch.nan)
    matched_pid = matched_pid.to(e_pred_showers.device).long()
    matched_positions_pfo = torch.zeros((energy_t.shape[0], 3)) * (torch.nan)
    matched_positions_pfo = matched_positions_pfo.to(e_pred_showers.device)
    matched_pandora_pid = (torch.zeros((energy_t.shape[0])) * (torch.nan)).to(e_pred_showers.device)
    matched_ref_pts_pfo = torch.zeros((energy_t.shape[0], 3)) * (torch.nan)
    matched_ref_pts_pfo = matched_ref_pts_pfo.to(e_pred_showers.device)
    matched_es = matched_es.to(e_pred_showers.device)
    matched_es[row_ind_] = e_pred_showers[index_matches]
    if pandora:
        matched_es_cali = matched_es.clone()
        matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
        matched_es_cali_pfo = matched_es.clone()
        matched_es_cali_pfo[row_ind_] = e_pred_showers_pfo[index_matches]
        # NOTE(review): pandora_pid is only defined when calc_pandora_momentum
        # — this line would raise NameError otherwise; confirm the input graphs
        # always carry "pandora_momentum" in the pandora path.
        matched_pandora_pid[row_ind_] = pandora_pid[index_matches]
        if calc_pandora_momentum:
            matched_positions_pfo[row_ind_] = pxyz_pred_pfo[index_matches]
            matched_ref_pts_pfo[row_ind_] = ref_pt_pred_pfo[index_matches]
        is_track[row_ind_] = is_track_per_shower[index_matches].float()
    else:
        if e_corr is None:
            matched_es_cali = matched_es.clone()
            matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
            calibration_per_shower = matched_es.clone()
            calibration_per_shower[row_ind_] = corrections_per_shower[index_matches]
        else:
            matched_es_cali = matched_es.clone()
            number_of_showers = e_pred_showers[index_matches].shape[0]  # DOESN'T INCLUDE THE FAKE SHOWERS
            # number_of_fake_showers = e_pred_showers.shape[0] - number_of_showers
            # Here the calibrated energy IS the regressed value (correction is
            # an absolute prediction, not a multiplicative factor).
            matched_es_cali[row_ind_] = (
                corrections_per_shower[
                    number_of_showers_total : number_of_showers_total
                    + number_of_showers
                ]
                # * e_pred_showers[index_matches]
            )
            # if len(row_ind) and len(index_matches):
            #     assert row_ind.max() < len(is_track)
            #     assert index_matches.max() < len(is_track_per_shower)
            is_track[row_ind_] = is_track_per_shower[index_matches].float()
            if pred_pos is not None:
                matched_positions[row_ind_] = pred_pos[number_of_showers_total : number_of_showers_total
                    + number_of_showers]
                matched_ref_pt[row_ind_] = pred_ref_pt[number_of_showers_total : number_of_showers_total + number_of_showers]
                matched_pid[row_ind_] = pred_pid[number_of_showers_total : number_of_showers_total + number_of_showers]
            if shap:
                matched_shap_vals[row_ind_.cpu()] = shap_vals[index_matches.cpu()]
                matched_ec_x[row_ind_.cpu()] = ec_x[index_matches.cpu()]
            calibration_per_shower = matched_es.clone()
            calibration_per_shower[row_ind_] = corrections_per_shower[
                number_of_showers_total : number_of_showers_total + number_of_showers
            ]
            # Advance the running offset into the flat per-candidate tensors.
            number_of_showers_total = number_of_showers_total + number_of_showers
    intersection_E = torch.zeros_like(energy_t) * (torch.nan)
    if len(col_ind) > 0:
        # Shared (truth ∩ prediction) energy for every matched pair.
        ie_e = obtain_intersection_values(i_m_w, row_ind, col_ind, dic)
        intersection_E[row_ind_] = ie_e.to(e_pred_showers.device)
        # Mark matched predictions (and the noise slot 0) so that what remains
        # is the set of fake showers.
        pred_showers[index_matches] = -1
        pred_showers[
            0
        ] = (
            -1
        )  # This takes into account that the class 0 for pandora and for dbscan is noise
        mask = pred_showers != -1
        number_of_fake_showers = mask.sum()
        fakes_in_event = mask.sum()  # NOTE(review): unused duplicate of the line above
        fake_showers_e = e_pred_showers[mask]
        if e_corr is None or pandora:
            fake_showers_e_cali = e_pred_showers_cali[mask]
            # fakes_positions = dic["graph"].ndata["coords"][mask]
        else:
            # fake_showers_e_cali = corrections_per_shower[number_of_showers_total:number_of_showers_total+number_of_showers][mask]# * (torch.nan)
            # fakes_positions = torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)
            # fake_showers_e_cali = fake_showers_e
            # fakes_pid_pred = torch.zeros((fake_showers_e.shape[0])) * (torch.nan) # just for now for debugigng
            # fakes_positions = fakes_positions.to(e_pred_showers.device)
            # fakes_pid_pred = fakes_pid_pred.to(e_pred_showers.device)
            # Fake-candidate predictions live in the tail of the flat tensors;
            # slice this event's share using the running fake counter.
            fakes_positions = pred_pos[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total+number_of_fake_showers]
            fake_showers_e_cali = e_corr[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total+number_of_fake_showers]
            fakes_pid_pred = pred_pid[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total+number_of_fake_showers]
            # NOTE(review): slicing e_reco_showers (per-particle truth sums)
            # with the fake-candidate window looks suspicious — confirm intent.
            fake_showers_e_reco = e_reco_showers[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total+number_of_fake_showers]
            fakes_positions = fakes_positions.to(e_pred_showers.device)
            fake_showers_e_cali = fake_showers_e_cali.to(e_pred_showers.device)
            fakes_pid_pred = fakes_pid_pred.to(e_pred_showers.device)
            fake_showers_e_reco = fake_showers_e_reco.to(e_pred_showers.device)  # NOTE(review): unused afterwards
            # fakes_pid_pred = pred_pid[number_of_showers_total:number_of_showers_total+number_of_showers][mask]
            # fakes_positions = fakes_positions.to(e_pred_showers.device)
        if pandora:
            fake_pandora_pid = (torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)).to(e_pred_showers.device)
            fake_pandora_pid = pandora_pid[mask]  # overwrites the NaN buffer above
            if calc_pandora_momentum:
                fake_positions_pfo = torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)
                fake_positions_pfo = fake_positions_pfo.to(e_pred_showers.device)
                fake_positions_pfo = pxyz_pred_pfo[mask]
                fakes_positions_ref = (torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)).to(e_pred_showers.device)
                fakes_positions_ref = ref_pt_pred_pfo[mask]
        if not pandora:
            if e_corr is None:
                fake_showers_e_cali_factor = corrections_per_shower[mask]
            else:
                fake_showers_e_cali_factor = fake_showers_e_cali
        # NaN truth placeholders for the fake-shower rows.
        fake_showers_showers_e_truw = torch.zeros((fake_showers_e.shape[0])) * (
            torch.nan
        )
        fake_showers_vertex = torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)
        fakes_is_track = (torch.zeros((fake_showers_e.shape[0])) * (torch.nan)).to(e_pred_showers.device)
        fakes_is_track = is_track_per_shower[mask]  # overwrites the NaN buffer above
        fakes_positions_t = torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)
        if not pandora:
            number_of_fake_showers_total = number_of_fake_showers_total + number_of_fake_showers
        """if shap:
        fake_showers_shap_vals = torch.zeros((fake_showers_e.shape[0], shap_vals_t.shape[1])) * (
        torch.nan
        )
        fake_showers_ec_x_t = torch.zeros((fake_showers_e.shape[0], ec_x_t.shape[1])) * (
        torch.nan
        )
        #fake_showers_shap_vals = fake_showers_shap_vals.to(e_pred_showers.device)
        #fake_showers_ec_x_t = fake_showers_ec_x_t.to(e_pred_showers.device)
        shap_vals_t = torch.cat((torch.tensor(shap_vals_t), fake_showers_shap_vals), dim=0)
        ec_x_t = torch.cat((torch.tensor(ec_x_t), fake_showers_ec_x_t), dim=0)
        """
        fake_showers_showers_e_truw = fake_showers_showers_e_truw.to(
            e_pred_showers.device
        )
        fakes_positions_t = fakes_positions_t.to(e_pred_showers.device)
        fake_showers_vertex = fake_showers_vertex.to(e_pred_showers.device)
        # Append the fake-shower rows after the per-particle rows.
        energy_t = torch.cat(
            (energy_t, fake_showers_showers_e_truw),
            dim=0,
        )
        vertex = torch.cat((vertex, fake_showers_vertex), dim=0)
        pid_t = torch.cat(
            (pid_t.view(-1), fake_showers_showers_e_truw),
            dim=0,
        )
        pos_t = torch.cat(
            (pos_t, fakes_positions_t),
            dim=0,
        )
        # e_reco_showers[1:] drops the noise particle (index 0).
        e_reco = torch.cat((e_reco_showers[1:], fake_showers_showers_e_truw), dim=0)
        e_pred = torch.cat((matched_es, fake_showers_e), dim=0)
        e_pred_cali = torch.cat((matched_es_cali, fake_showers_e_cali), dim=0)
        if pred_pos is not None:
            e_pred_pos = torch.cat((matched_positions, fakes_positions), dim=0)
            e_pred_pid = torch.cat((matched_pid, fakes_pid_pred), dim=0)
            # NOTE(review): fake rows reuse fakes_positions here (not a separate
            # reference-point prediction) — confirm intended.
            e_pred_ref_pt = torch.cat((matched_ref_pt, fakes_positions), dim=0)
        if pandora:
            e_pred_cali_pfo = torch.cat(
                (matched_es_cali_pfo, fake_showers_e_cali), dim=0
            )
            positions_pfo = torch.cat((matched_positions_pfo, fake_positions_pfo), dim=0)
            pandora_pid = torch.cat((matched_pandora_pid, fake_pandora_pid), dim=0)
            ref_pts_pfo = torch.cat((matched_ref_pts_pfo, fakes_positions_ref), dim=0)
        if not pandora:
            calibration_factor = torch.cat(
                (calibration_per_shower, fake_showers_e_cali_factor), dim=0
            )
        if shap:
            # pad
            matched_shap_vals = torch.cat(
                (
                    torch.tensor(matched_shap_vals),
                    torch.zeros((fake_showers_e.shape[0], shap_vals.shape[1])),
                ),
                dim=0,
            )
            matched_ec_x = torch.cat(
                (
                    torch.tensor(matched_ec_x),
                    torch.zeros((fake_showers_e.shape[0], ec_x.shape[1])),
                ),
                dim=0,
            )
        e_pred_t = torch.cat(
            (
                intersection_E,
                torch.zeros_like(fake_showers_e) * (torch.nan),
            ),
            dim=0,
        )
        # e_pred_t_pandora = torch.cat(
        #     (
        #         intersection_E,
        #         torch.zeros_like(fake_showers_e) * (-200),
        #         torch.zeros_like(fake_showers_e_pandora) * (-100),
        #     ),
        #     dim=0,
        # )
        is_track = torch.cat((is_track, fakes_is_track.to(is_track.device)), dim=0)
        if pandora:
            d = {
                "true_showers_E": energy_t.detach().cpu(),
                "reco_showers_E": e_reco.detach().cpu(),
                "pred_showers_E": e_pred.detach().cpu(),
                "e_pred_and_truth": e_pred_t.detach().cpu(),
                "pandora_calibrated_E": e_pred_cali.detach().cpu(),
                "pandora_calibrated_pfo": e_pred_cali_pfo.detach().cpu(),
                "pandora_calibrated_pos": positions_pfo.detach().cpu().tolist(),
                "pandora_ref_pt": ref_pts_pfo.detach().cpu().tolist(),
                "pid": pid_t.detach().cpu(),
                "pandora_pid":pandora_pid.detach().cpu(),
                "step": torch.ones_like(energy_t.detach().cpu()) * step,
                "number_batch": torch.ones_like(energy_t.detach().cpu())
                * number_in_batch,
                "is_track_in_cluster": is_track.detach().cpu(),
                "vertex": vertex.detach().cpu().tolist()
            }
        else:
            d = {
                "true_showers_E": energy_t.detach().cpu(),
                "reco_showers_E": e_reco.detach().cpu(),
                "pred_showers_E": e_pred.detach().cpu(),
                "e_pred_and_truth": e_pred_t.detach().cpu(),
                "pid": pid_t.detach().cpu(),
                "calibration_factor": calibration_factor.detach().cpu(),
                "calibrated_E": e_pred_cali.detach().cpu(),
                "step": torch.ones_like(energy_t.detach().cpu()) * step,
                "number_batch": torch.ones_like(energy_t.detach().cpu())
                * number_in_batch,
                "is_track_in_cluster": is_track.detach().cpu(),
                "vertex": vertex.detach().cpu().tolist()
            }
        if pred_pos is not None:
            pred_pos1 = e_pred_pos.detach().cpu()
            pred_pid1 = e_pred_pid.detach().cpu()
            pred_ref_pt1 = e_pred_ref_pt.detach().cpu()
            d["pred_pos_matched"] = (
                pred_pos1.tolist()
            )  # Otherwise it doesn't work nicely with Pandas DataFrames
            d["pred_pid_matched"] = pred_pid1.tolist()
            d["pred_ref_pt_matched"] = pred_ref_pt1.tolist()
        """if shap:
        print("Adding ec_x and shap_values to the DataFrame")
        d["ec_x"] = ec_x_t
        d["shap_values"] = shap_vals_t"""
        if shap:
            d["shap_values"] = matched_shap_vals.tolist()
            d["ec_x"] = matched_ec_x.tolist()
        d["true_pos"] = pos_t.detach().cpu().tolist()
        df = pd.DataFrame(data=d)
        if save_plots_to_folder:
            # Debug plots: one figure per event number present in the frame.
            event_numbers = np.unique(df.number_batch)
            for evt in event_numbers:
                if len(df[df.number_batch == evt]):
                    # Random string
                    rndstr = generate_random_string(5)
                    plot_event(
                        df[df.number_batch == evt],
                        pandora,
                        save_plots_to_folder + str(evt) + rndstr,
                        graph=dic["graph"].to("cpu"),
                        y=dic["part_true"],
                        labels=dic["graph"].ndata["particle_number"].long(),
                        is_track_in_cluster=df.is_track_in_cluster
                    )
                    '''plot_event(
                    df[df.number_batch == evt],
                    pandora,
                    save_plots_to_folder + "_CLUSTERING_" + str(evt) + rndstr,
                    graph=dic["graph"].to("cpu"),
                    y=dic["part_true"],
                    labels=labels.detach().cpu(),
                    is_track_in_cluster=df.is_track_in_cluster
                    )'''
        if number_of_showers_total is None:
            return df
        else:
            return df, number_of_showers_total, number_of_fake_showers_total
    else:
        # No matches at all in this event.
        return [], 0, 0
def get_correction_per_shower(labels, dic):
    """Pick, for every cluster label, the energy-correction factor predicted at
    the cluster's highest-beta hit (its condensation point).

    Returns a 1-D tensor with one correction per unique label. If label 0
    (noise) is absent, a zero placeholder is prepended so that index 0 always
    corresponds to the noise cluster.
    """
    corr_all = dic["graph"].ndata["correction"]
    beta_all = dic["graph"].ndata["beta"]
    per_label = []
    for pos, lab in enumerate(torch.unique(labels)):
        if pos == 0 and lab != 0:
            # No noise cluster in this event: reserve slot 0 with a zero.
            per_label.append(corr_all[0].view(-1) * 0)
        in_cluster = labels == lab
        best_hit = torch.argmax(beta_all[in_cluster])
        per_label.append(corr_all[in_cluster][best_hit].view(-1))
    return torch.cat(per_label, dim=0)
def obtain_intersection_matrix(shower_p_unique, particle_ids, labels, dic, e_hits):
    """Build two (n_pred_showers x n_particles) overlap matrices between
    predicted clusters and true particles: one counting shared hits and one
    summing the shared hits' energies."""
    n_pred = len(shower_p_unique)
    n_true = len(particle_ids)
    counts_matrix = torch.zeros((n_pred, n_true)).to(shower_p_unique.device)
    energy_matrix = torch.zeros((n_pred, n_true)).to(shower_p_unique.device)
    truth_labels = dic["graph"].ndata["particle_number"]
    for col, pid in enumerate(particle_ids):
        hit_in_particle = truth_labels == pid
        # Indicator of this particle's hits, scattered per predicted cluster.
        one_hot = torch.zeros_like(labels)
        one_hot[hit_in_particle] = 1
        # Same hits, but weighted by their energy.
        weighted = e_hits.clone()
        weighted[~hit_in_particle] = 0
        counts_matrix[:, col] = scatter_add(one_hot, labels)
        energy_matrix[:, col] = scatter_add(weighted, labels.to(weighted.device))
    return counts_matrix, energy_matrix
def obtain_union_matrix(shower_p_unique, particle_ids, labels, dic):
    """Return the (n_pred_showers x n_particles) matrix of union sizes, in
    hits, between each predicted cluster and each true particle."""
    truth_labels = dic["graph"].ndata["particle_number"]
    union = torch.zeros((len(shower_p_unique), len(particle_ids)))
    for col, pid in enumerate(particle_ids):
        in_particle = truth_labels == pid
        for row, pred_id in enumerate(shower_p_unique):
            in_pred = labels == pred_id
            # Boolean + boolean acts as logical OR; summing counts the union.
            union[row, col] = torch.sum(in_pred + in_particle)
    return union
def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.7, td=0.03):
    """
    Returns a clustering of hits -> cluster_index, based on the GravNet model
    output (predicted betas and cluster space coordinates) and the clustering
    parameters tbeta and td.

    Greedy assignment: repeatedly take the highest-beta unassigned condensation
    point (beta > tbeta) and attach every still-unassigned hit within distance
    td of it. Hits never claimed by any condensation point keep label -1.
    """
    num_hits = betas.size(0)
    device = betas.device
    is_cond = betas > tbeta
    cond_idx = is_cond.nonzero()
    # Highest beta first.
    cond_idx = cond_idx[(-betas[is_cond]).argsort()]
    remaining = torch.arange(num_hits).to(device)
    assignment = -1 * torch.ones(num_hits, dtype=torch.long).to(device)
    while len(cond_idx) > 0 and len(remaining) > 0:
        cp = cond_idx[0]
        dist = torch.norm(X[remaining] - X[cp][0], dim=-1)
        close = dist < td
        # The cluster id is the hit index of its condensation point.
        assignment[remaining[close]] = cp[0]
        remaining = remaining[~close]
        # Recompute candidate condensation points among what is left.
        cond_idx = find_condpoints(betas, remaining, tbeta)
    return assignment
def find_condpoints(betas, unassigned, tbeta):
    """Return the indices (as an (N, 1) tensor) of still-unassigned hits with
    beta above tbeta, ordered by decreasing beta."""
    device = betas.device
    still_free = torch.zeros(betas.size(0)).to(device)
    still_free[unassigned] = True
    candidates = still_free.to(bool) * (betas > tbeta)
    idx = candidates.nonzero()
    # Sort the surviving indices by descending beta.
    return idx[(-betas[candidates]).argsort()]
def obtain_intersection_values(intersection_matrix_w, row_ind, col_ind, dic):
    """For each matched (particle, predicted shower) pair, read the shared
    energy out of the weighted intersection matrix.

    The matrix is transposed so rows index particles and columns index
    predicted showers, with the noise slot(s) stripped; when a noise particle
    (id 0) exists, particle rows are shifted accordingly. Returns a 1-D tensor
    of shared energies, or 0 when there are no matched pairs.
    """
    particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
    if torch.sum(particle_ids == 0) > 0:
        # removing also the MC particle corresponding to noise
        matrix_t = torch.transpose(intersection_matrix_w[1:, 1:], 1, 0)
        row_ind = row_ind - 1
    else:
        matrix_t = torch.transpose(intersection_matrix_w[1:, :], 1, 0)
    shared = [matrix_t[row_ind[k], col_ind[k]].view(-1) for k in range(0, len(col_ind))]
    if len(shared) > 0:
        return torch.cat(shared, dim=0)
    return 0
def plot_iou_matrix(iou_matrix, image_path, hdbscan=False):
    """Render the IoU matrix (particles x predicted showers, noise shower row
    stripped) as an annotated heatmap, save it to image_path and log the image
    to wandb under a key that reflects the clustering algorithm."""
    matrix = torch.transpose(iou_matrix[1:, :], 1, 0).detach().cpu().numpy()
    fig, ax = plt.subplots()
    ax.matshow(matrix, cmap=plt.cm.Blues)
    n_rows, n_cols = matrix.shape
    for col in range(0, n_cols):
        for row in range(0, n_rows):
            cell = np.round(matrix[row, col], 1)
            ax.text(col, row, str(cell), va="center", ha="center")
    fig.savefig(image_path, bbox_inches="tight")
    key = "iou_matrix_hdbscan" if hdbscan else "iou_matrix"
    wandb.log({key: wandb.Image(image_path)})
def match_showers(
    labels,
    dic,
    particle_ids,
    model_output,
    local_rank,
    i,
    path_save,
    pandora=False,
    tracks=False,
    hdbscan=False,
):
    """Match predicted showers to true particles by IoU, via Hungarian
    assignment, discarding pairs below an IoU threshold of 0.25.

    Returns (shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix), where
    row_ind indexes the matched particles and col_ind the predicted showers
    they are matched to (relative to the noise-stripped matrix).
    """
    iou_threshold = 0.25
    shower_p_unique = torch.unique(labels)
    # Guarantee a slot for the noise cluster (label 0) even when absent.
    if torch.sum(labels == 0) == 0:
        zero_slot = torch.Tensor([0]).to(shower_p_unique.device).view(-1)
        shower_p_unique = torch.cat((zero_slot, shower_p_unique.view(-1)), dim=0)
    e_hits = dic["graph"].ndata["e_hits"].view(-1)
    i_m, i_m_w = obtain_intersection_matrix(
        shower_p_unique, particle_ids, labels, dic, e_hits
    )
    i_m = i_m.to(model_output.device)
    i_m_w = i_m_w.to(model_output.device)
    u_m = obtain_union_matrix(shower_p_unique, particle_ids, labels, dic)
    u_m = u_m.to(model_output.device)
    iou_matrix = i_m / u_m
    noise_particle_present = torch.sum(particle_ids == 0) > 0
    if noise_particle_present:
        # Also strip the MC particle corresponding to noise.
        iou_np = (
            torch.transpose(iou_matrix[1:, 1:], 1, 0).clone().detach().cpu().numpy()
        )
    else:
        iou_np = (
            torch.transpose(iou_matrix[1:, :], 1, 0).clone().detach().cpu().numpy()
        )
    iou_np[iou_np < iou_threshold] = 0
    row_ind, col_ind = linear_sum_assignment(-iou_np)
    # Drop assignments whose IoU is zero (i.e. fell below the threshold).
    keep = iou_np[row_ind, col_ind] > 0
    row_ind = row_ind[keep]
    col_ind = col_ind[keep]
    if noise_particle_present:
        row_ind = row_ind + 1
    if i == 0 and local_rank == 0:
        if path_save is not None:
            # Only the path is prepared; the actual plot call is disabled.
            if pandora:
                image_path = path_save + "/example_1_clustering_pandora.png"
            else:
                image_path = path_save + "/example_1_clustering.png"
            # plot_iou_matrix(iou_matrix, image_path, hdbscan)
    # row_ind are particles that are matched and col_ind the ind of preds they are matched to
    return shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix
def clustering_obtain_labels(X, betas, device):
    """Run condensation-point clustering and remap the raw cluster ids (hit
    indices) to a dense 0..K range, reserving label 0 for noise when no hit
    was left unassigned (-1)."""
    raw = get_clustering(betas, X)
    raw_cpu = raw.detach().cpu()
    unique_ids = list(np.unique(raw_cpu))
    dense = torch.Tensor([unique_ids.index(v) for v in raw_cpu]).long()
    # Without any -1 (noise) assignments, shift by one so 0 stays free.
    if torch.unique(raw)[0] != -1:
        dense = dense + 1
    return torch.Tensor(dense.view(-1)).long().to(device)
def hfdb_obtain_labels(X, device, eps=0.1):
    """Cluster coordinates with HDBSCAN and shift the labels by +1 so that
    noise (-1) becomes cluster 0."""
    clusterer = HDBSCAN(
        min_cluster_size=8, min_samples=8, cluster_selection_epsilon=eps
    ).fit(X.detach().cpu())
    shifted = np.reshape(clusterer.labels_ + 1, (-1))
    return torch.Tensor(shifted).long().to(device)
def dbscan_obtain_labels(X, device):
    """Cluster coordinates with DBSCAN, deriving eps as 1/30 of the smallest
    per-dimension coordinate extent; labels are shifted by +1 so noise is 0."""
    lo = torch.min(X, dim=0)[0]
    hi = torch.max(X, dim=0)[0]
    eps = (torch.min(torch.abs(lo - hi)) / 30).view(-1).detach().cpu().numpy()[0]
    model = DBSCAN(eps=eps, min_samples=15).fit(X.detach().cpu())
    # DBSCAN has clustering labels -1,0,.., our cluster 0 is noise so we add 1
    shifted = np.reshape(model.labels_ + 1, (-1))
    return torch.Tensor(shifted).long().to(device)
class CachedIndexList:
    """A thin list wrapper whose `.index()` memoizes results, making repeated
    lookups of the same value O(1) after the first."""

    def __init__(self, lst):
        self.lst = lst
        self.cache = {}

    def index(self, value):
        # EAFP: hit the cache first, fall back to a linear scan once.
        try:
            return self.cache[value]
        except KeyError:
            found = self.lst.index(value)
            self.cache[value] = found
            return found
def get_labels_pandora(tracks, dic, device):
    """Read Pandora's per-hit cluster (or PFO, when tracks=True) assignment and
    remap it to dense labels 0..K; the +1 shift makes -1 (unassigned) land on
    the noise label 0 before remapping."""
    key = "pandora_pfo" if tracks else "pandora_cluster"
    raw = dic["graph"].ndata[key].long() + 1
    lookup = CachedIndexList(list(np.unique(raw.detach().cpu())))
    dense = [lookup.index(v) for v in raw.detach().cpu().numpy()]
    return torch.Tensor(dense).long().to(device)