import os
import random
import string

import dgl
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import wandb
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import DBSCAN, HDBSCAN
from torch_scatter import scatter_max, scatter_add, scatter_mean

from src.dataset.functions_data import CachedIndexList, spherical_to_cartesian
from src.utils.inference.per_particle_metrics import plot_event

def generate_random_string(length):
    letters = string.ascii_letters + string.digits
    return "".join(random.choice(letters) for _ in range(length))

def create_and_store_graph_output(
    batch_g,
    model_output,
    y,
    local_rank,
    step,
    epoch,
    path_save,
    store=False,
    predict=False,
    tracking=False,
    e_corr=None,
    shap_vals=None,
    ec_x=None,  # "global" features (input to the final DNN head) for energy correction
    tracks=False,
    store_epoch=False,
    total_number_events=0,
    pred_pos=None,
    pred_ref_pt=None,
    use_gt_clusters=False,
    pids_neutral=None,
    pids_charged=None,
    pred_pid=None,
    pred_xyz_track=None,
    number_of_fakes=None,
):
    number_of_showers_total = 0
    number_of_showers_total1 = 0
    number_of_fake_showers_total1 = 0
    batch_g.ndata["coords"] = model_output[:, 0:3]
    batch_g.ndata["beta"] = model_output[:, 3]
    if not tracking:
        if e_corr is None:
            batch_g.ndata["correction"] = model_output[:, 4]
    graphs = dgl.unbatch(batch_g)
    batch_id = y.batch_number.view(-1)  # y[:, -1].view(-1)
    df_list = []
    df_list1 = []
    df_list_pandora = []
    total_number_candidates = 0
    for i in range(0, len(graphs)):
        mask = batch_id == i
        dic = {}
        dic["graph"] = graphs[i]
        y1 = y.copy()
        y1.mask(mask)
        dic["part_true"] = y1  # y[mask]
        X = dic["graph"].ndata["coords"]
        # if shap_vals is not None:
        #     dic["shap_values"] = shap_vals
        # if ec_x is not None:
        #     dic["ec_x"] = ec_x  # ? No mask ?!?
        if predict:
            labels_clustering = clustering_obtain_labels(
                X, dic["graph"].ndata["beta"].view(-1), model_output.device
            )
        if use_gt_clusters:
            labels_hdb = dic["graph"].ndata["particle_number"].type(torch.int64)
        else:
            labels_hdb = hfdb_obtain_labels(X, model_output.device)
        num_clusters = len(labels_hdb.unique())
        # if labels_hdb.min() == 0 and labels_hdb.sum() == 0:
        #     labels_hdb += 1  # Quick hack
        #     raise Exception("!!!! Labels==0 !!!!")
        if predict:
            labels_pandora = get_labels_pandora(tracks, dic, model_output.device)
            num_clusters_pandora = len(labels_pandora.unique())
        particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
        # current_number_candidates = num_clusters
        # pred_pos_batch = pred_pos[total_number_candidates:total_number_candidates + current_number_candidates]
        # pred_ref_pt_batch = pred_ref_pt[total_number_candidates:total_number_candidates + current_number_candidates]
        # pred_pid_batch = pred_pid[total_number_candidates:total_number_candidates + current_number_candidates]
        # e_corr_batch = e_corr[total_number_candidates:total_number_candidates + current_number_candidates]
        """if predict:
            shower_p_unique = torch.unique(labels_clustering)
            shower_p_unique, row_ind, col_ind, i_m_w, iou_m_c = match_showers(
                labels_clustering,
                dic,
                particle_ids,
                model_output,
                local_rank,
                i,
                path_save,
                tracks=tracks,
            )"""
        shower_p_unique_hdb, row_ind_hdb, col_ind_hdb, i_m_w_hdb, iou_m = match_showers(
            labels_hdb,
            dic,
            particle_ids,
            model_output,
            local_rank,
            i,
            path_save,
            tracks=tracks,
            hdbscan=True,
        )
        if predict:
            (
                shower_p_unique_pandora,
                row_ind_pandora,
                col_ind_pandora,
                i_m_w_pandora,
                iou_m_pandora,
            ) = match_showers(
                labels_pandora,
                dic,
                particle_ids,
                model_output,
                local_rank,
                i,
                path_save,
                pandora=True,
                tracks=tracks,
            )
        # if len(row_ind_hdb) < len(dic["part_true"]):
        #     print(len(row_ind_hdb), len(dic["part_true"]))
        #     print("storing event", local_rank, step, i)
        # path_graphs_all_comparing = os.path.join(path_save, "graphs_all_comparing")
        # if not os.path.exists(path_graphs_all_comparing):
        #     os.makedirs(path_graphs_all_comparing)
        '''torch.save(
            dic,
            path_save
            + "/graphs_all_comparing_Gregor/"
            + str(local_rank)
            + "_"
            + str(step)
            + "_"
            + str(i)
            + ".pt",
        )'''
        # torch.save(
        #     dic,
        #     path_save
        #     + "/graphs/"
        #     + str(local_rank)
        #     + "_"
        #     + str(step)
        #     + "_"
        #     + str(i)
        #     + ".pt",
        # )
        if len(shower_p_unique_hdb) > 1:
            # df_event, number_of_showers_total = generate_showers_data_frame(
            #     labels_clustering,
            #     dic,
            #     shower_p_unique,
            #     particle_ids,
            #     row_ind,
            #     col_ind,
            #     i_m_w,
            #     e_corr=e_corr,
            #     number_of_showers_total=number_of_showers_total,
            #     step=step,
            #     number_in_batch=i,
            #     tracks=tracks,
            # )
            # if pred_pos is not None:
            #     # Apply temporary correction
            #     phi = torch.atan2(pred_pos[:, 1], pred_pos[:, 0])
            #     theta = torch.acos(pred_pos[:, 2] / torch.norm(pred_pos, dim=1))
            #     pred_pos = spherical_to_cartesian(theta, phi, torch.norm(pred_pos, dim=1), normalized=True)
            #     pred_pos = pred_pos.to(model_output.device)
            df_event1, number_of_showers_total1, number_of_fake_showers_total1 = generate_showers_data_frame(
                labels_hdb,
                dic,
                shower_p_unique_hdb,
                particle_ids,
                row_ind_hdb,
                col_ind_hdb,
                i_m_w_hdb,
                e_corr=e_corr,
                number_of_showers_total=number_of_showers_total1,
                step=step,
                number_in_batch=total_number_events,
                tracks=tracks,
                ec_x=ec_x,
                shap_vals=shap_vals,
                pred_pos=pred_pos,
                pred_ref_pt=pred_ref_pt,
                pred_pid=pred_pid,
                save_plots_to_folder=path_save + "/ML_Model_evt_plots_debugging",
                number_of_fakes=number_of_fakes,
                number_of_fake_showers_total=number_of_fake_showers_total1,
            )
            if len(df_event1) > 1:
                df_list1.append(df_event1)
            if predict:
                df_event_pandora = generate_showers_data_frame(
                    labels_pandora,
                    dic,
                    shower_p_unique_pandora,
                    particle_ids,
                    row_ind_pandora,
                    col_ind_pandora,
                    i_m_w_pandora,
                    pandora=True,
                    tracking=tracking,
                    step=step,
                    number_in_batch=total_number_events,
                    tracks=tracks,
                    save_plots_to_folder=path_save + "/Pandora_evt_plots_debugging",
                )
                # The empty-event fallback returns a tuple ([], 0, 0); keep only real DataFrames.
                if df_event_pandora is not None and not isinstance(df_event_pandora, tuple):
                    df_list_pandora.append(df_event_pandora)
        total_number_events = total_number_events + 1
        # print("number of showers total", number_of_showers_total)
        # number_of_showers_total = number_of_showers_total + len(shower_p_unique_hdb)
        # print("number of showers total", number_of_showers_total)
    # Guard against an empty batch: pd.concat raises on an empty list.
    df_batch1 = pd.concat(df_list1) if df_list1 else pd.DataFrame()
    if predict:
        df_batch_pandora = pd.concat(df_list_pandora) if df_list_pandora else pd.DataFrame()
    else:
        df_batch = []
        df_batch_pandora = []
    if store:
        store_at_batch_end(
            path_save,
            df_batch1,
            df_batch_pandora,
            # df_batch,
            local_rank,
            step,
            epoch,
            predict=predict,
            store=store_epoch,
        )
    if predict:
        return df_batch_pandora, df_batch1, total_number_events
    else:
        return df_batch1
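
# Minimal sketch of how this entry point is typically driven from an evaluation
# loop. The names `model`, `loader` and `out_dir` are placeholders, not part of
# this module:
#
#     total_events = 0
#     for step, (batch_g, y) in enumerate(loader):
#         model_output = model(batch_g)
#         df_pandora, df_ml, total_events = create_and_store_graph_output(
#             batch_g, model_output, y, local_rank=0, step=step, epoch=0,
#             path_save=out_dir, store=True, predict=True,
#             total_number_events=total_events,
#         )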

def store_at_batch_end(
    path_save,
    df_batch1,
    df_batch_pandora,
    # df_batch,
    local_rank=0,
    step=0,
    epoch=None,
    predict=False,
    store=False,
):
    if predict:
        path_save_ = (
            path_save
            + "/"
            + str(local_rank)
            + "_"
            + str(step)
            + "_"
            + str(epoch)
            + ".pt"
        )
        # if store and predict:
        #     df_batch.to_pickle(path_save_)
        #     log_efficiency(df_batch, clustering=True)
    path_save_ = (
        path_save
        + "/"
        + str(local_rank)
        + "_"
        + str(step)
        + "_"
        + str(epoch)
        + "_hdbscan.pt"
    )
    if store and predict:
        df_batch1.to_pickle(path_save_)
    if predict:
        path_save_pandora = (
            path_save
            + "/"
            + str(local_rank)
            + "_"
            + str(step)
            + "_"
            + str(epoch)
            + "_pandora.pt"
        )
        if store and predict:
            df_batch_pandora.to_pickle(path_save_pandora)
    log_efficiency(df_batch1)
    if predict:
        log_efficiency(df_batch_pandora, pandora=True)
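
# With store=True and predict=True this writes two pickled DataFrames per batch
# under path_save: "{local_rank}_{step}_{epoch}_hdbscan.pt" (ML clusters) and
# "{local_rank}_{step}_{epoch}_pandora.pt" (Pandora baseline).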

def log_efficiency(df, pandora=False, clustering=False):
    mask = ~np.isnan(df["reco_showers_E"])
    eff = np.sum(~np.isnan(df["pred_showers_E"][mask].values)) / len(
        df["pred_showers_E"][mask].values
    )
    if pandora:
        wandb.log({"efficiency validation pandora": eff})
    elif clustering:
        wandb.log({"efficiency validation clustering": eff})
    else:
        wandb.log({"efficiency validation": eff})

def generate_showers_data_frame(
    labels,
    dic,
    shower_p_unique,
    particle_ids,
    row_ind,
    col_ind,
    i_m_w,
    pandora=False,
    tracking=False,
    e_corr=None,
    number_of_showers_total=None,
    step=0,
    number_in_batch=0,
    tracks=False,
    shap_vals=None,
    ec_x=None,
    pred_pos=None,
    pred_pid=None,
    save_plots_to_folder="",
    pred_ref_pt=None,
    number_of_fake_showers_total=None,
    number_of_fakes=None,
):
    shap = shap_vals is not None
    e_pred_showers = scatter_add(dic["graph"].ndata["e_hits"].view(-1), labels)
    if pandora:
        e_pred_showers_cali = scatter_mean(
            dic["graph"].ndata["pandora_cluster_energy"].view(-1), labels
        )
        e_pred_showers_pfo = scatter_mean(
            dic["graph"].ndata["pandora_pfo_energy"].view(-1), labels
        )
        # px_pred_pfo = scatter_mean(dic["graph"].ndata["hit_px"], labels)
        # py_pred_pfo = scatter_mean(dic["graph"].ndata["hit_py"], labels)
        # pz_pred_pfo = scatter_mean(dic["graph"].ndata["hit_pz"], labels)
        # p_pred_pfo = scatter_mean(dic["graph"].ndata["pos_pxpypz"], labels)  # FIX THIS: pos_pxpypz has shape [-1, 3]
        calc_pandora_momentum = "pandora_momentum" in dic["graph"].ndata
        if calc_pandora_momentum:
            px_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_momentum"][:, 0], labels
            )
            py_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_momentum"][:, 1], labels
            )
            pz_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_momentum"][:, 2], labels
            )
            ref_pt_px_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_reference_point"][:, 0], labels
            )
            ref_pt_py_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_reference_point"][:, 1], labels
            )
            ref_pt_pz_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_reference_point"][:, 2], labels
            )
            pandora_pid = scatter_mean(dic["graph"].ndata["pandora_pid"], labels)
            ref_pt_pred_pfo = torch.stack(
                (ref_pt_px_pred_pfo, ref_pt_py_pred_pfo, ref_pt_pz_pred_pfo), dim=1
            )
            # p_pred_pandora = scatter_mean(dic["graph"].ndata["pandora_momentum"], labels)
            p_pred_pandora = torch.stack((px_pred_pfo, py_pred_pfo, pz_pred_pfo), dim=1)
            p_size_pandora = torch.norm(p_pred_pandora, dim=1)
            pxyz_pred_pfo = (
                p_pred_pandora  # / torch.norm(p_pred_pandora, dim=1).view(-1, 1)
            )
    else:
        if e_corr is None:
            corrections_per_shower = get_correction_per_shower(labels, dic)
            e_pred_showers_cali = e_pred_showers * corrections_per_shower
        else:
            corrections_per_shower = e_corr.view(-1)
            # Guard against number_of_fakes=None (the default)
            if number_of_fakes is not None and number_of_fakes > 0:
                corrections_per_shower_fakes = corrections_per_shower[-number_of_fakes:]
                corrections_per_shower = corrections_per_shower[:-number_of_fakes]
    e_reco_showers = scatter_add(
        dic["graph"].ndata["e_hits"].view(-1),
        dic["graph"].ndata["particle_number"].long(),
    )
    row_ind = torch.Tensor(row_ind).to(e_pred_showers.device).long()
    col_ind = torch.Tensor(col_ind).to(e_pred_showers.device).long()
    if torch.sum(particle_ids == 0) > 0:
        # Particle id 0 is noise: match_showers shifted row_ind by +1 in that
        # case, so undo the shift here before indexing the truth arrays.
        row_ind_ = row_ind - 1
    else:
        # Without a noise class, index 0 already corresponds to particle 1.
        row_ind_ = row_ind
    pred_showers = shower_p_unique
    energy_t = (
        dic["part_true"].E_corrected.view(-1).to(e_pred_showers.device)
    )  # dic["part_true"][:, 3].to(e_pred_showers.device)
    vertex = dic["part_true"].vertex.to(e_pred_showers.device)
    pos_t = dic["part_true"].coord.to(e_pred_showers.device)
    pid_t = dic["part_true"].pid.to(e_pred_showers.device)
    is_track_per_shower = scatter_add(
        (dic["graph"].ndata["hit_type"] == 1).int(), labels
    )
    is_track = torch.zeros(energy_t.shape).to(e_pred_showers.device)
    if shap:
        # NaN-filled placeholders; matched entries are overwritten below
        matched_shap_vals = torch.zeros((energy_t.shape[0], ec_x.shape[1])) * torch.nan
        matched_shap_vals = matched_shap_vals.numpy()
        matched_ec_x = torch.zeros((energy_t.shape[0], ec_x.shape[1])) * torch.nan
        matched_ec_x = matched_ec_x.numpy()
    index_matches = col_ind + 1
    index_matches = index_matches.to(e_pred_showers.device).long()
    matched_es = torch.zeros_like(energy_t) * torch.nan
    matched_positions = torch.zeros((energy_t.shape[0], 3)) * torch.nan
    matched_positions = matched_positions.to(e_pred_showers.device)
    matched_ref_pt = torch.zeros((energy_t.shape[0], 3)) * torch.nan
    matched_ref_pt = matched_ref_pt.to(e_pred_showers.device)
    matched_pid = torch.zeros_like(energy_t) * torch.nan
    # NB: casting to long turns the NaN sentinel into an undefined integer
    matched_pid = matched_pid.to(e_pred_showers.device).long()
    matched_positions_pfo = torch.zeros((energy_t.shape[0], 3)) * torch.nan
    matched_positions_pfo = matched_positions_pfo.to(e_pred_showers.device)
    matched_pandora_pid = (torch.zeros((energy_t.shape[0])) * torch.nan).to(
        e_pred_showers.device
    )
    matched_ref_pts_pfo = torch.zeros((energy_t.shape[0], 3)) * torch.nan
    matched_ref_pts_pfo = matched_ref_pts_pfo.to(e_pred_showers.device)
    matched_es = matched_es.to(e_pred_showers.device)
    matched_es[row_ind_] = e_pred_showers[index_matches]
    if pandora:
        matched_es_cali = matched_es.clone()
        matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
        matched_es_cali_pfo = matched_es.clone()
        matched_es_cali_pfo[row_ind_] = e_pred_showers_pfo[index_matches]
        matched_pandora_pid[row_ind_] = pandora_pid[index_matches]
        if calc_pandora_momentum:
            matched_positions_pfo[row_ind_] = pxyz_pred_pfo[index_matches]
            matched_ref_pts_pfo[row_ind_] = ref_pt_pred_pfo[index_matches]
        is_track[row_ind_] = is_track_per_shower[index_matches].float()
    else:
        if e_corr is None:
            matched_es_cali = matched_es.clone()
            matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
            calibration_per_shower = matched_es.clone()
            calibration_per_shower[row_ind_] = corrections_per_shower[index_matches]
        else:
            matched_es_cali = matched_es.clone()
            number_of_showers = e_pred_showers[index_matches].shape[0]  # does NOT include the fake showers
            # number_of_fake_showers = e_pred_showers.shape[0] - number_of_showers
            matched_es_cali[row_ind_] = corrections_per_shower[
                number_of_showers_total : number_of_showers_total + number_of_showers
            ]  # * e_pred_showers[index_matches]
            # if len(row_ind) and len(index_matches):
            #     assert row_ind.max() < len(is_track)
            #     assert index_matches.max() < len(is_track_per_shower)
            is_track[row_ind_] = is_track_per_shower[index_matches].float()
            if pred_pos is not None:
                matched_positions[row_ind_] = pred_pos[
                    number_of_showers_total : number_of_showers_total + number_of_showers
                ]
                matched_ref_pt[row_ind_] = pred_ref_pt[
                    number_of_showers_total : number_of_showers_total + number_of_showers
                ]
                matched_pid[row_ind_] = pred_pid[
                    number_of_showers_total : number_of_showers_total + number_of_showers
                ]
            if shap:
                matched_shap_vals[row_ind_.cpu()] = shap_vals[index_matches.cpu()]
                matched_ec_x[row_ind_.cpu()] = ec_x[index_matches.cpu()]
            calibration_per_shower = matched_es.clone()
            calibration_per_shower[row_ind_] = corrections_per_shower[
                number_of_showers_total : number_of_showers_total + number_of_showers
            ]
            number_of_showers_total = number_of_showers_total + number_of_showers
    intersection_E = torch.zeros_like(energy_t) * torch.nan
    if len(col_ind) > 0:
        ie_e = obtain_intersection_values(i_m_w, row_ind, col_ind, dic)
        intersection_E[row_ind_] = ie_e.to(e_pred_showers.device)
        # Mark matched predictions, plus class 0 (noise for both Pandora and
        # DBSCAN/HDBSCAN); the remaining entries are fake showers.
        pred_showers[index_matches] = -1
        pred_showers[0] = -1
        mask = pred_showers != -1
        number_of_fake_showers = mask.sum()
        fake_showers_e = e_pred_showers[mask]
        if e_corr is None or pandora:
            fake_showers_e_cali = e_pred_showers_cali[mask]
            # fakes_positions = dic["graph"].ndata["coords"][mask]
        else:
            # fake_showers_e_cali = corrections_per_shower[number_of_showers_total:number_of_showers_total + number_of_showers][mask]
            # fakes_positions = torch.zeros((fake_showers_e.shape[0], 3)) * torch.nan
            # fake_showers_e_cali = fake_showers_e
            # fakes_pid_pred = torch.zeros((fake_showers_e.shape[0])) * torch.nan  # just for now, for debugging
            # The last number_of_fakes entries of the per-shower predictions
            # belong to fake showers; slice out this event's share of them.
            fakes_positions = pred_pos[-number_of_fakes:][
                number_of_fake_showers_total : number_of_fake_showers_total + number_of_fake_showers
            ]
            fake_showers_e_cali = e_corr[-number_of_fakes:][
                number_of_fake_showers_total : number_of_fake_showers_total + number_of_fake_showers
            ]
            fakes_pid_pred = pred_pid[-number_of_fakes:][
                number_of_fake_showers_total : number_of_fake_showers_total + number_of_fake_showers
            ]
            fake_showers_e_reco = e_reco_showers[-number_of_fakes:][
                number_of_fake_showers_total : number_of_fake_showers_total + number_of_fake_showers
            ]
            fakes_positions = fakes_positions.to(e_pred_showers.device)
            fake_showers_e_cali = fake_showers_e_cali.to(e_pred_showers.device)
            fakes_pid_pred = fakes_pid_pred.to(e_pred_showers.device)
            fake_showers_e_reco = fake_showers_e_reco.to(e_pred_showers.device)
            # fakes_pid_pred = pred_pid[number_of_showers_total:number_of_showers_total + number_of_showers][mask]
        if pandora:
            fake_pandora_pid = pandora_pid[mask]
            if calc_pandora_momentum:
                fake_positions_pfo = pxyz_pred_pfo[mask]
                fakes_positions_ref = ref_pt_pred_pfo[mask]
        if not pandora:
            if e_corr is None:
                fake_showers_e_cali_factor = corrections_per_shower[mask]
            else:
                fake_showers_e_cali_factor = fake_showers_e_cali
        fake_showers_showers_e_truw = torch.zeros((fake_showers_e.shape[0])) * torch.nan
        fake_showers_vertex = torch.zeros((fake_showers_e.shape[0], 3)) * torch.nan
        fakes_is_track = is_track_per_shower[mask]
        fakes_positions_t = torch.zeros((fake_showers_e.shape[0], 3)) * torch.nan
        if not pandora:
            number_of_fake_showers_total = (
                number_of_fake_showers_total + number_of_fake_showers
            )
"""if shap: | |
fake_showers_shap_vals = torch.zeros((fake_showers_e.shape[0], shap_vals_t.shape[1])) * ( | |
torch.nan | |
) | |
fake_showers_ec_x_t = torch.zeros((fake_showers_e.shape[0], ec_x_t.shape[1])) * ( | |
torch.nan | |
) | |
#fake_showers_shap_vals = fake_showers_shap_vals.to(e_pred_showers.device) | |
#fake_showers_ec_x_t = fake_showers_ec_x_t.to(e_pred_showers.device) | |
shap_vals_t = torch.cat((torch.tensor(shap_vals_t), fake_showers_shap_vals), dim=0) | |
ec_x_t = torch.cat((torch.tensor(ec_x_t), fake_showers_ec_x_t), dim=0) | |
""" | |
fake_showers_showers_e_truw = fake_showers_showers_e_truw.to( | |
e_pred_showers.device | |
) | |
fakes_positions_t = fakes_positions_t.to(e_pred_showers.device) | |
fake_showers_vertex = fake_showers_vertex.to(e_pred_showers.device) | |
energy_t = torch.cat( | |
(energy_t, fake_showers_showers_e_truw), | |
dim=0, | |
) | |
vertex = torch.cat((vertex, fake_showers_vertex), dim=0) | |
pid_t = torch.cat( | |
(pid_t.view(-1), fake_showers_showers_e_truw), | |
dim=0, | |
) | |
pos_t = torch.cat( | |
(pos_t, fakes_positions_t), | |
dim=0, | |
) | |
e_reco = torch.cat((e_reco_showers[1:], fake_showers_showers_e_truw), dim=0) | |
e_pred = torch.cat((matched_es, fake_showers_e), dim=0) | |
e_pred_cali = torch.cat((matched_es_cali, fake_showers_e_cali), dim=0) | |
if pred_pos is not None: | |
e_pred_pos = torch.cat((matched_positions, fakes_positions), dim=0) | |
e_pred_pid = torch.cat((matched_pid, fakes_pid_pred), dim=0) | |
e_pred_ref_pt = torch.cat((matched_ref_pt, fakes_positions), dim=0) | |
if pandora: | |
e_pred_cali_pfo = torch.cat( | |
(matched_es_cali_pfo, fake_showers_e_cali), dim=0 | |
) | |
positions_pfo = torch.cat((matched_positions_pfo, fake_positions_pfo), dim=0) | |
pandora_pid = torch.cat((matched_pandora_pid, fake_pandora_pid), dim=0) | |
ref_pts_pfo = torch.cat((matched_ref_pts_pfo, fakes_positions_ref), dim=0) | |
if not pandora: | |
calibration_factor = torch.cat( | |
(calibration_per_shower, fake_showers_e_cali_factor), dim=0 | |
) | |
if shap: | |
# pad | |
matched_shap_vals = torch.cat( | |
( | |
torch.tensor(matched_shap_vals), | |
torch.zeros((fake_showers_e.shape[0], shap_vals.shape[1])), | |
), | |
dim=0, | |
) | |
matched_ec_x = torch.cat( | |
( | |
torch.tensor(matched_ec_x), | |
torch.zeros((fake_showers_e.shape[0], ec_x.shape[1])), | |
), | |
dim=0, | |
) | |
e_pred_t = torch.cat( | |
( | |
intersection_E, | |
torch.zeros_like(fake_showers_e) * (torch.nan), | |
), | |
dim=0, | |
) | |
# e_pred_t_pandora = torch.cat( | |
# ( | |
# intersection_E, | |
# torch.zeros_like(fake_showers_e) * (-200), | |
# torch.zeros_like(fake_showers_e_pandora) * (-100), | |
# ), | |
# dim=0, | |
# ) | |
        is_track = torch.cat((is_track, fakes_is_track.to(is_track.device)), dim=0)
        if pandora:
            d = {
                "true_showers_E": energy_t.detach().cpu(),
                "reco_showers_E": e_reco.detach().cpu(),
                "pred_showers_E": e_pred.detach().cpu(),
                "e_pred_and_truth": e_pred_t.detach().cpu(),
                "pandora_calibrated_E": e_pred_cali.detach().cpu(),
                "pandora_calibrated_pfo": e_pred_cali_pfo.detach().cpu(),
                "pandora_calibrated_pos": positions_pfo.detach().cpu().tolist(),
                "pandora_ref_pt": ref_pts_pfo.detach().cpu().tolist(),
                "pid": pid_t.detach().cpu(),
                "pandora_pid": pandora_pid.detach().cpu(),
                "step": torch.ones_like(energy_t.detach().cpu()) * step,
                "number_batch": torch.ones_like(energy_t.detach().cpu())
                * number_in_batch,
                "is_track_in_cluster": is_track.detach().cpu(),
                "vertex": vertex.detach().cpu().tolist(),
            }
        else:
            d = {
                "true_showers_E": energy_t.detach().cpu(),
                "reco_showers_E": e_reco.detach().cpu(),
                "pred_showers_E": e_pred.detach().cpu(),
                "e_pred_and_truth": e_pred_t.detach().cpu(),
                "pid": pid_t.detach().cpu(),
                "calibration_factor": calibration_factor.detach().cpu(),
                "calibrated_E": e_pred_cali.detach().cpu(),
                "step": torch.ones_like(energy_t.detach().cpu()) * step,
                "number_batch": torch.ones_like(energy_t.detach().cpu())
                * number_in_batch,
                "is_track_in_cluster": is_track.detach().cpu(),
                "vertex": vertex.detach().cpu().tolist(),
            }
        if pred_pos is not None:
            pred_pos1 = e_pred_pos.detach().cpu()
            pred_pid1 = e_pred_pid.detach().cpu()
            pred_ref_pt1 = e_pred_ref_pt.detach().cpu()
            # Stored as lists: pandas handles per-row lists better than 2D tensors
            d["pred_pos_matched"] = pred_pos1.tolist()
            d["pred_pid_matched"] = pred_pid1.tolist()
            d["pred_ref_pt_matched"] = pred_ref_pt1.tolist()
        """if shap:
            print("Adding ec_x and shap_values to the DataFrame")
            d["ec_x"] = ec_x_t
            d["shap_values"] = shap_vals_t"""
        if shap:
            d["shap_values"] = matched_shap_vals.tolist()
            d["ec_x"] = matched_ec_x.tolist()
        d["true_pos"] = pos_t.detach().cpu().tolist()
        df = pd.DataFrame(data=d)
        if save_plots_to_folder:
            event_numbers = np.unique(df.number_batch)
            for evt in event_numbers:
                if len(df[df.number_batch == evt]):
                    # Random suffix so repeated calls do not overwrite plots
                    rndstr = generate_random_string(5)
                    plot_event(
                        df[df.number_batch == evt],
                        pandora,
                        save_plots_to_folder + str(evt) + rndstr,
                        graph=dic["graph"].to("cpu"),
                        y=dic["part_true"],
                        labels=dic["graph"].ndata["particle_number"].long(),
                        is_track_in_cluster=df.is_track_in_cluster,
                    )
                    '''plot_event(
                        df[df.number_batch == evt],
                        pandora,
                        save_plots_to_folder + "_CLUSTERING_" + str(evt) + rndstr,
                        graph=dic["graph"].to("cpu"),
                        y=dic["part_true"],
                        labels=labels.detach().cpu(),
                        is_track_in_cluster=df.is_track_in_cluster,
                    )'''
        if number_of_showers_total is None:
            return df
        else:
            return df, number_of_showers_total, number_of_fake_showers_total
    else:
        return [], 0, 0
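
# The per-event DataFrame has one row per truth shower followed by rows for the
# fake (unmatched) predicted showers; truth columns are NaN on the fake rows.
# A hedged reading sketch (the path is only an example of the layout written by
# store_at_batch_end):
#
#     df = pd.read_pickle("out/0_10_0_hdbscan.pt")
#     matched = df[~df["pred_showers_E"].isna() & ~df["reco_showers_E"].isna()]
#     response = matched["calibrated_E"] / matched["true_showers_E"]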

def get_correction_per_shower(labels, dic):
    # For each predicted cluster, use the energy correction of its highest-beta
    # hit (the condensation point) as the cluster's correction factor.
    unique_labels = torch.unique(labels)
    list_corr = []
    for ii, pred_label in enumerate(unique_labels):
        if ii == 0:
            if pred_label != 0:
                # No noise cluster present: prepend a zero correction so that
                # index 0 still corresponds to the noise class.
                list_corr.append(dic["graph"].ndata["correction"][0].view(-1) * 0)
        mask = labels == pred_label
        corrections_E_label = dic["graph"].ndata["correction"][mask]
        betas_label_indmax = torch.argmax(dic["graph"].ndata["beta"][mask])
        list_corr.append(corrections_E_label[betas_label_indmax].view(-1))
    corrections = torch.cat(list_corr, dim=0)
    return corrections
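
# Toy example: labels = [1, 1, 2] with betas [0.9, 0.2, 0.5] returns the
# corrections of hit 0 (cluster 1) and hit 2 (cluster 2), preceded by a zero
# entry for the absent noise cluster -> a tensor of three correction factors.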

def obtain_intersection_matrix(shower_p_unique, particle_ids, labels, dic, e_hits):
    len_pred_showers = len(shower_p_unique)
    intersection_matrix = torch.zeros((len_pred_showers, len(particle_ids))).to(
        shower_p_unique.device
    )
    intersection_matrix_w = torch.zeros((len_pred_showers, len(particle_ids))).to(
        shower_p_unique.device
    )
    for index, id in enumerate(particle_ids):
        counts = torch.zeros_like(labels)
        mask_p = dic["graph"].ndata["particle_number"] == id
        h_hits = e_hits.clone()
        counts[mask_p] = 1
        h_hits[~mask_p] = 0
        intersection_matrix[:, index] = scatter_add(counts, labels)
        intersection_matrix_w[:, index] = scatter_add(h_hits, labels.to(h_hits.device))
    return intersection_matrix, intersection_matrix_w
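
# intersection_matrix[p, t] counts the hits shared between predicted cluster p
# and truth particle t; intersection_matrix_w holds the shared hit energy
# instead of the hit count. Both are consumed by match_showers below.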

def obtain_union_matrix(shower_p_unique, particle_ids, labels, dic):
    len_pred_showers = len(shower_p_unique)
    union_matrix = torch.zeros((len_pred_showers, len(particle_ids)))
    for index, id in enumerate(particle_ids):
        counts = torch.zeros_like(labels)
        mask_p = dic["graph"].ndata["particle_number"] == id
        for index_pred, id_pred in enumerate(shower_p_unique):
            mask_pred_p = labels == id_pred
            mask_union = mask_pred_p + mask_p
            union_matrix[index_pred, index] = torch.sum(mask_union)
    return union_matrix

def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.7, td=0.03):
    """
    Returns a clustering of hits -> cluster_index, based on the GravNet model
    output (predicted betas and cluster-space coordinates) and the clustering
    parameters tbeta and td.

    Takes torch.Tensors as input.
    """
    n_points = betas.size(0)
    select_condpoints = betas > tbeta
    # Get indices passing the threshold
    indices_condpoints = select_condpoints.nonzero()
    # Order them by decreasing beta value
    indices_condpoints = indices_condpoints[(-betas[select_condpoints]).argsort()]
    # Assign points to condensation points.
    # Only previously unassigned points are assigned (no overwriting);
    # points still unassigned at the end are background (-1).
    unassigned = torch.arange(n_points).to(betas.device)
    clustering = -1 * torch.ones(n_points, dtype=torch.long).to(betas.device)
    while len(indices_condpoints) > 0 and len(unassigned) > 0:
        index_condpoint = indices_condpoints[0]
        d = torch.norm(X[unassigned] - X[index_condpoint][0], dim=-1)
        assigned_to_this_condpoint = unassigned[d < td]
        clustering[assigned_to_this_condpoint] = index_condpoint[0]
        unassigned = unassigned[~(d < td)]
        # Recompute the remaining condensation points among unassigned hits
        indices_condpoints = find_condpoints(betas, unassigned, tbeta)
    return clustering
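
# Hedged usage sketch with random inputs (shapes only; real betas and
# cluster-space coordinates come from the model):
#
#     betas = torch.rand(100)
#     X = torch.rand(100, 3)
#     cluster_idx = get_clustering(betas, X, tbeta=0.7, td=0.03)
#     # cluster_idx[i] is the hit index of the condensation point that hit i
#     # was assigned to, or -1 for background.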

def find_condpoints(betas, unassigned, tbeta):
    n_points = betas.size(0)
    select_condpoints = betas > tbeta
    device = betas.device
    mask_unassigned = torch.zeros(n_points).to(device)
    mask_unassigned[unassigned] = True
    select_condpoints = mask_unassigned.to(bool) * select_condpoints
    # Get indices passing the threshold
    indices_condpoints = select_condpoints.nonzero()
    # Order them by decreasing beta value
    indices_condpoints = indices_condpoints[(-betas[select_condpoints]).argsort()]
    return indices_condpoints

def obtain_intersection_values(intersection_matrix_w, row_ind, col_ind, dic):
    list_intersection_E = []
    particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
    if torch.sum(particle_ids == 0) > 0:
        # Also remove the MC particle corresponding to noise
        intersection_matrix_wt = torch.transpose(intersection_matrix_w[1:, 1:], 1, 0)
        row_ind = row_ind - 1
    else:
        intersection_matrix_wt = torch.transpose(intersection_matrix_w[1:, :], 1, 0)
    for i in range(0, len(col_ind)):
        list_intersection_E.append(
            intersection_matrix_wt[row_ind[i], col_ind[i]].view(-1)
        )
    if len(list_intersection_E) > 0:
        return torch.cat(list_intersection_E, dim=0)
    else:
        return 0

def plot_iou_matrix(iou_matrix, image_path, hdbscan=False):
    iou_matrix = torch.transpose(iou_matrix[1:, :], 1, 0)
    fig, ax = plt.subplots()
    iou_matrix = iou_matrix.detach().cpu().numpy()
    ax.matshow(iou_matrix, cmap=plt.cm.Blues)
    for i in range(0, iou_matrix.shape[1]):
        for j in range(0, iou_matrix.shape[0]):
            c = np.round(iou_matrix[j, i], 1)
            ax.text(i, j, str(c), va="center", ha="center")
    fig.savefig(image_path, bbox_inches="tight")
    plt.close(fig)  # free the figure to avoid accumulating open figures
    if hdbscan:
        wandb.log({"iou_matrix_hdbscan": wandb.Image(image_path)})
    else:
        wandb.log({"iou_matrix": wandb.Image(image_path)})

def match_showers(
    labels,
    dic,
    particle_ids,
    model_output,
    local_rank,
    i,
    path_save,
    pandora=False,
    tracks=False,
    hdbscan=False,
):
    iou_threshold = 0.25
    shower_p_unique = torch.unique(labels)
    if torch.sum(labels == 0) == 0:
        # Ensure a noise class (label 0) is always present at index 0
        shower_p_unique = torch.cat(
            (
                torch.Tensor([0]).to(shower_p_unique.device).view(-1),
                shower_p_unique.view(-1),
            ),
            dim=0,
        )
    e_hits = dic["graph"].ndata["e_hits"].view(-1)
    i_m, i_m_w = obtain_intersection_matrix(
        shower_p_unique, particle_ids, labels, dic, e_hits
    )
    i_m = i_m.to(model_output.device)
    i_m_w = i_m_w.to(model_output.device)
    u_m = obtain_union_matrix(shower_p_unique, particle_ids, labels, dic)
    u_m = u_m.to(model_output.device)
    iou_matrix = i_m / u_m
    if torch.sum(particle_ids == 0) > 0:
        # Also remove the MC particle corresponding to noise
        iou_matrix_num = (
            torch.transpose(iou_matrix[1:, 1:], 1, 0).clone().detach().cpu().numpy()
        )
    else:
        iou_matrix_num = (
            torch.transpose(iou_matrix[1:, :], 1, 0).clone().detach().cpu().numpy()
        )
    iou_matrix_num[iou_matrix_num < iou_threshold] = 0
    row_ind, col_ind = linear_sum_assignment(-iou_matrix_num)
    # Drop assignments whose IoU is zero (i.e. below the threshold): such a
    # shower is not really matched.
    mask_matching_matrix = iou_matrix_num[row_ind, col_ind] > 0
    row_ind = row_ind[mask_matching_matrix]
    col_ind = col_ind[mask_matching_matrix]
    if torch.sum(particle_ids == 0) > 0:
        row_ind = row_ind + 1
    if i == 0 and local_rank == 0:
        if path_save is not None:
            if pandora:
                image_path = path_save + "/example_1_clustering_pandora.png"
            else:
                image_path = path_save + "/example_1_clustering.png"
            # plot_iou_matrix(iou_matrix, image_path, hdbscan)
    # row_ind holds the matched truth particles; col_ind holds the indices of
    # the predicted showers they are matched to.
    return shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix
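
# The matching is a Hungarian assignment on the negated IoU matrix. For
# instance, with two truth particles and two predicted showers and
#
#     iou = [[0.8, 0.1],
#            [0.3, 0.6]]
#
# linear_sum_assignment(-iou) pairs truth 0 with prediction 0 and truth 1 with
# prediction 1 (0.8 + 0.6 beats 0.1 + 0.3). Entries below iou_threshold (0.25)
# are zeroed first, and any pair left with IoU 0 is discarded afterwards.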

def clustering_obtain_labels(X, betas, device):
    clustering = get_clustering(betas, X)
    # Remap the raw condensation-point indices to contiguous cluster ids
    map_from = list(np.unique(clustering.detach().cpu()))
    cluster_id = map(lambda x: map_from.index(x), clustering.detach().cpu())
    clustering_ordered = torch.Tensor(list(cluster_id)).long()
    if torch.unique(clustering)[0] != -1:
        # No background points: shift by one so label 0 stays reserved for noise
        clustering = clustering_ordered + 1
    else:
        clustering = clustering_ordered
    clustering = clustering.view(-1).long().to(device)
    return clustering

def hfdb_obtain_labels(X, device, eps=0.1):
    # HDBSCAN labels noise as -1; shifting by +1 makes 0 the noise class
    hdb = HDBSCAN(min_cluster_size=8, min_samples=8, cluster_selection_epsilon=eps).fit(
        X.detach().cpu()
    )
    labels_hdb = hdb.labels_ + 1
    labels_hdb = np.reshape(labels_hdb, (-1))
    labels_hdb = torch.Tensor(labels_hdb).long().to(device)
    return labels_hdb
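
# Note: sklearn.cluster.HDBSCAN requires scikit-learn >= 1.3; on older versions
# the import at the top of this module fails.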

def dbscan_obtain_labels(X, device):
    # Set eps to 1/30 of the smallest coordinate range of the point cloud
    distance_scale = (
        (torch.min(torch.abs(torch.min(X, dim=0)[0] - torch.max(X, dim=0)[0])) / 30)
        .view(-1)
        .detach()
        .cpu()
        .numpy()[0]
    )
    db = DBSCAN(eps=distance_scale, min_samples=15).fit(X.detach().cpu())
    # DBSCAN labels are -1, 0, ...; our cluster 0 is noise, so add 1
    labels = db.labels_ + 1
    labels = np.reshape(labels, (-1))
    labels = torch.Tensor(labels).long().to(device)
    return labels

class CachedIndexList:
    def __init__(self, lst):
        self.lst = lst
        self.cache = {}

    def index(self, value):
        if value in self.cache:
            return self.cache[value]
        else:
            idx = self.lst.index(value)
            self.cache[value] = idx
            return idx
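
# NB: this local class shadows the CachedIndexList imported from
# src.dataset.functions_data at the top of the module, so get_labels_pandora
# below uses this definition. It memoizes list.index lookups, which are O(n)
# per call, so repeated label remapping stays fast.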

def get_labels_pandora(tracks, dic, device):
    if tracks:
        labels_pandora = dic["graph"].ndata["pandora_pfo"].long()
    else:
        labels_pandora = dic["graph"].ndata["pandora_cluster"].long()
    labels_pandora = labels_pandora + 1
    # Remap the (sparse) Pandora cluster ids to contiguous labels 0, 1, ...
    map_from = list(np.unique(labels_pandora.detach().cpu()))
    map_from = CachedIndexList(map_from)
    cluster_id = map(lambda x: map_from.index(x), labels_pandora.detach().cpu().numpy())
    labels_pandora = torch.Tensor(list(cluster_id)).long().to(device)
    return labels_pandora