"""Inference-time shower reconstruction and evaluation utilities.

Given a batched DGL graph and model output (clustering coordinates, betas and
optionally an energy-correction column), this module:
  * clusters hits into predicted showers (HDBSCAN / DBSCAN / beta-condensation,
    or Pandora's clusters as a baseline),
  * matches predicted showers to ground-truth particles via IoU + Hungarian
    assignment (`match_showers`),
  * builds per-shower pandas DataFrames (energies, positions, PIDs, fakes)
    for downstream metrics/plots (`generate_showers_data_frame`),
  * stores results and logs efficiency to wandb.

NOTE(review): this file was recovered from a whitespace-mangled source; the
indentation below is reconstructed from syntax and data-flow and should be
confirmed against version control where ambiguous (marked inline).
"""
import dgl
import torch
import os
from alembic.command import current  # NOTE(review): appears unused — likely an accidental IDE auto-import
from sklearn.cluster import DBSCAN, HDBSCAN
from torch_scatter import scatter_max, scatter_add, scatter_mean
import numpy as np
from src.dataset.functions_data import CachedIndexList, spherical_to_cartesian
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
import pandas as pd
import wandb
from src.utils.inference.per_particle_metrics import plot_event
import random
import string


def generate_random_string(length):
    """Return a random alphanumeric string of the given length (used to de-duplicate plot filenames)."""
    letters = string.ascii_letters + string.digits
    return "".join(random.choice(letters) for i in range(length))


def create_and_store_graph_output(
    batch_g,
    model_output,
    y,
    local_rank,
    step,
    epoch,
    path_save,
    store=False,
    predict=False,
    tracking=False,
    e_corr=None,
    shap_vals=None,
    ec_x=None,  # ec_x: "global" features (what gets inputted into the final deep neural network head) for energy correction
    tracks=False,
    store_epoch=False,
    total_number_events=0,
    pred_pos=None,
    pred_ref_pt=None,
    use_gt_clusters=False,
    pids_neutral=None,
    pids_charged=None,
    pred_pid=None,
    pred_xyz_track=None,
    number_of_fakes=None
):
    """Cluster model output into showers per event, match them to truth, and build evaluation DataFrames.

    Per event in the batch: obtains cluster labels (ground truth, HDBSCAN, and
    optionally Pandora when ``predict``), matches predicted showers to true
    particles, and appends a per-shower DataFrame. At the end, concatenates
    everything, stores/logs via ``store_at_batch_end`` and returns either
    ``df_batch1`` (training/eval) or ``(df_batch_pandora, df_batch1,
    total_number_events)`` when ``predict`` is True.

    ``model_output`` columns: 0:3 clustering coords, 3 beta, and (when
    ``e_corr`` is None and not ``tracking``) 4 the energy correction.
    """
    number_of_showers_total = 0
    number_of_showers_total1 = 0
    number_of_fake_showers_total1 = 0
    batch_g.ndata["coords"] = model_output[:, 0:3]
    batch_g.ndata["beta"] = model_output[:, 3]
    if not tracking:
        if e_corr is None:
            # correction predicted per-hit by the model itself
            batch_g.ndata["correction"] = model_output[:, 4]
    graphs = dgl.unbatch(batch_g)
    batch_id = y.batch_number.view(-1)  # y[:, -1].view(-1)
    df_list = []
    df_list1 = []
    df_list_pandora = []
    total_number_candidates = 0
    for i in range(0, len(graphs)):
        mask = batch_id == i
        dic = {}
        dic["graph"] = graphs[i]
        y1 = y.copy()
        y1.mask(mask)
        dic["part_true"] = y1  # y[mask]
        X = dic["graph"].ndata["coords"]
        # if shap_vals is not None:
        #     dic["shap_values"] = shap_vals
        # if ec_x is not None:
        #     dic["ec_x"] = ec_x ## ? No mask ?!?
        if predict:
            labels_clustering = clustering_obtain_labels(
                X, dic["graph"].ndata["beta"].view(-1), model_output.device
            )
        if use_gt_clusters:
            labels_hdb = dic["graph"].ndata["particle_number"].type(torch.int64)
        else:
            labels_hdb = hfdb_obtain_labels(X, model_output.device)
        num_clusters = len(labels_hdb.unique())
        # if labels_hdb.min() == 0 and labels_hdb.sum() == 0:
        #     labels_hdb += 1  # Quick hack
        #     raise Exception("!!!! Labels==0 !!!!")
        if predict:
            labels_pandora = get_labels_pandora(tracks, dic, model_output.device)
            num_clusters_pandora = len(labels_pandora.unique())
        particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
        # current_number_candidates = num_clusters
        # pred_pos_batch = pred_pos[total_number_candidates:total_number_candidates+current_number_candidates]
        # pred_ref_pt_batch = pred_ref_pt[total_number_candidates:total_number_candidates+current_number_candidates]
        # pred_pid_batch = pred_pid[total_number_candidates:total_number_candidates+current_number_candidates]
        # e_corr_batch = e_corr[total_number_candidates:total_number_candidates+current_number_candidates]
        """if predict:
            shower_p_unique = torch.unique(labels_clustering)
            shower_p_unique, row_ind, col_ind, i_m_w, iou_m_c = match_showers(
                labels_clustering,
                dic,
                particle_ids,
                model_output,
                local_rank,
                i,
                path_save,
                tracks=tracks,
            )"""
        shower_p_unique_hdb, row_ind_hdb, col_ind_hdb, i_m_w_hdb, iou_m = match_showers(
            labels_hdb,
            dic,
            particle_ids,
            model_output,
            local_rank,
            i,
            path_save,
            tracks=tracks,
            hdbscan=True,
        )
        if predict:
            (
                shower_p_unique_pandora,
                row_ind_pandora,
                col_ind_pandora,
                i_m_w_pandora,
                iou_m_pandora,
            ) = match_showers(
                labels_pandora,
                dic,
                particle_ids,
                model_output,
                local_rank,
                i,
                path_save,
                pandora=True,
                tracks=tracks,
            )
        # # if len(row_ind_hdb) < len(dic["part_true"]):
        # print(len(row_ind_hdb), len(dic["part_true"]))
        # print("storing event", local_rank, step, i)
        # path_graphs_all_comparing = os.path.join(path_save, "graphs_all_comparing")
        # if not os.path.exists(path_graphs_all_comparing):
        #     os.makedirs(path_graphs_all_comparing)
        '''torch.save(
            dic,
            path_save
            + "/graphs_all_comparing_Gregor/"
            + str(local_rank)
            + "_"
            + str(step)
            + "_"
            + str(i)
            + ".pt",
        )'''
        # torch.save(
        #     dic,
        #     path_save
        #     + "/graphs/"
        #     + str(local_rank)
        #     + "_"
        #     + str(step)
        #     + "_"
        #     + str(i)
        #     + ".pt",
        # )
        if len(shower_p_unique_hdb) > 1:
            # df_event, number_of_showers_total = generate_showers_data_frame(
            #     labels_clustering,
            #     labels_clustering,
            #     dic,
            #     shower_p_unique,
            #     particle_ids,
            #     row_ind,
            #     col_ind,
            #     i_m_w,
            #     e_corr=e_corr,
            #     number_of_showers_total=number_of_showers_total,
            #     step=step,
            #     number_in_batch=i,
            #     tracks=tracks,
            # )
            # if pred_pos is not None:  # Apply temporary correction
            import math
            # phi = math.atan2(pred_pos[:, 1], pred_pos[:, 0])
            # phi = torch.atan2(pred_pos[:, 1], pred_pos[:, 0])
            # theta = torch.acos(pred_pos[:, 2] / torch.norm(pred_pos, dim=1))
            # pred_pos = spherical_to_cartesian(theta, phi, torch.norm(pred_pos, dim=1), normalized=True)
            # pred_pos = pred_pos.to(model_output.device)
            df_event1, number_of_showers_total1, number_of_fake_showers_total1 = generate_showers_data_frame(
                labels_hdb,
                dic,
                shower_p_unique_hdb,
                particle_ids,
                row_ind_hdb,
                col_ind_hdb,
                i_m_w_hdb,
                e_corr=e_corr,
                number_of_showers_total=number_of_showers_total1,
                step=step,
                number_in_batch=total_number_events,
                tracks=tracks,
                ec_x=ec_x,
                shap_vals=shap_vals,
                pred_pos=pred_pos,
                pred_ref_pt=pred_ref_pt,
                pred_pid=pred_pid,
                save_plots_to_folder=path_save + "/ML_Model_evt_plots_debugging",
                number_of_fakes=number_of_fakes,
                number_of_fake_showers_total=number_of_fake_showers_total1,
            )
            if len(df_event1) > 1:
                df_list1.append(df_event1)
            if predict:
                df_event_pandora = generate_showers_data_frame(
                    labels_pandora,
                    dic,
                    shower_p_unique_pandora,
                    particle_ids,
                    row_ind_pandora,
                    col_ind_pandora,
                    i_m_w_pandora,
                    pandora=True,
                    tracking=tracking,
                    step=step,
                    number_in_batch=total_number_events,
                    tracks=tracks,
                    save_plots_to_folder=path_save + "/Pandora_evt_plots_debugging",
                )
                # generate_showers_data_frame returns (df, n, n_fakes) when the
                # counters are tracked, or ([], 0, 0) when no match — hence the
                # tuple guard below.
                if df_event_pandora is not None and type(df_event_pandora) is not tuple:
                    df_list_pandora.append(df_event_pandora)
        total_number_events = total_number_events + 1
        # print("number of showers total", number_of_showers_total)
        # number_of_showers_total = number_of_showers_total + len(shower_p_unique_hdb)
        # print("number of showers total", number_of_showers_total)
    df_batch1 = pd.concat(df_list1)
    if predict:
        df_batch_pandora = pd.concat(df_list_pandora)
    else:
        df_batch = []
        df_batch_pandora = []
    # if store:
    store_at_batch_end(
        path_save,
        df_batch1,
        df_batch_pandora,
        # df_batch,
        local_rank,
        step,
        epoch,
        predict=predict,
        store=store_epoch,
    )
    if predict:
        return df_batch_pandora, df_batch1, total_number_events
    else:
        return df_batch1


def store_at_batch_end(
    path_save,
    df_batch1,
    df_batch_pandora,
    # df_batch,
    local_rank=0,
    step=0,
    epoch=None,
    predict=False,
    store=False,
):
    """Optionally pickle the batch DataFrames and always log efficiency to wandb.

    Files are named ``<rank>_<step>_<epoch>_hdbscan.pt`` (model) and
    ``..._pandora.pt`` (Pandora baseline, only when ``predict``).
    """
    if predict:
        # NOTE(review): this first path_save_ is immediately overwritten below
        # (dead store left over from the commented-out df_batch storage).
        path_save_ = (
            path_save + "/" + str(local_rank) + "_" + str(step) + "_" + str(epoch) + ".pt"
        )
    # if store and predict:
    #     df_batch.to_pickle(path_save_)
    #     log_efficiency(df_batch, clustering=True)
    path_save_ = (
        path_save
        + "/"
        + str(local_rank)
        + "_"
        + str(step)
        + "_"
        + str(epoch)
        + "_hdbscan.pt"
    )
    if store and predict:
        df_batch1.to_pickle(path_save_)
    if predict:
        path_save_pandora = (
            path_save
            + "/"
            + str(local_rank)
            + "_"
            + str(step)
            + "_"
            + str(epoch)
            + "_pandora.pt"
        )
        if store and predict:
            df_batch_pandora.to_pickle(path_save_pandora)
    log_efficiency(df_batch1)
    if predict:
        log_efficiency(df_batch_pandora, pandora=True)


def log_efficiency(df, pandora=False, clustering=False):
    """Log matching efficiency = fraction of true showers (non-NaN reco) with a matched prediction."""
    mask = ~np.isnan(df["reco_showers_E"])
    eff = np.sum(~np.isnan(df["pred_showers_E"][mask].values)) / len(
        df["pred_showers_E"][mask].values
    )
    if pandora:
        wandb.log({"efficiency validation pandora": eff})
    elif clustering:
        wandb.log({"efficiency validation clustering": eff})
    else:
        wandb.log({"efficiency validation": eff})


def generate_showers_data_frame(
    labels,
    dic,
    shower_p_unique,
    particle_ids,
    row_ind,
    col_ind,
    i_m_w,
    pandora=False,
    tracking=False,
    e_corr=None,
    number_of_showers_total=None,
    step=0,
    number_in_batch=0,
    tracks=False,
    shap_vals=None,
    ec_x=None,
    pred_pos=None,
    pred_pid=None,
    save_plots_to_folder="",
    pred_ref_pt=None,
    number_of_fake_showers_total=None,
    number_of_fakes=None
):
    """Build a per-shower DataFrame for one event: matched showers first, then fakes.

    Rows 0..n_true-1 correspond to true particles (NaN prediction columns mean
    "unmatched"); the appended rows are fake (unmatched predicted) showers with
    NaN truth columns. ``number_of_showers_total`` / ``number_of_fake_showers_total``
    are running offsets into the batch-flattened ``e_corr`` / ``pred_pos`` /
    ``pred_pid`` tensors and are returned updated when tracked.

    Label/index conventions: cluster label 0 is noise; ``col_ind`` indexes
    predicted showers *after* dropping the noise cluster (hence ``+ 1`` below),
    and ``row_ind`` is shifted by -1 when a noise particle id 0 exists.
    Returns ``df`` or ``(df, number_of_showers_total, number_of_fake_showers_total)``,
    and ``([], 0, 0)`` when there are no matches at all.
    """
    shap = shap_vals is not None
    # raw (uncalibrated) predicted shower energy: sum of hit energies per cluster
    e_pred_showers = scatter_add(dic["graph"].ndata["e_hits"].view(-1), labels)
    if pandora:
        e_pred_showers_cali = scatter_mean(
            dic["graph"].ndata["pandora_cluster_energy"].view(-1), labels
        )
        e_pred_showers_pfo = scatter_mean(
            dic["graph"].ndata["pandora_pfo_energy"].view(-1), labels
        )
        # px_pred_pfo = scatter_mean(dic["graph"].ndata["hit_px"], labels)
        # py_pred_pfo = scatter_mean(dic["graph"].ndata["hit_py"], labels)
        # pz_pred_pfo = scatter_mean(dic["graph"].ndata["hit_pz"], labels)
        # p_pred_pfo = scatter_mean(dic["graph"].ndata["pos_pxpypz"], labels)  # FIX THIS: the shape of pos_pxpypz is [-1, 3]
        calc_pandora_momentum = "pandora_momentum" in dic["graph"].ndata
        if calc_pandora_momentum:
            px_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_momentum"][:, 0], labels
            )
            py_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_momentum"][:, 1], labels
            )
            pz_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_momentum"][:, 2], labels
            )
            ref_pt_px_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_reference_point"][:, 0], labels
            )
            ref_pt_py_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_reference_point"][:, 1], labels
            )
            ref_pt_pz_pred_pfo = scatter_mean(
                dic["graph"].ndata["pandora_reference_point"][:, 2], labels
            )
            pandora_pid = scatter_mean(
                dic["graph"].ndata["pandora_pid"], labels
            )
            ref_pt_pred_pfo = torch.stack(
                (ref_pt_px_pred_pfo, ref_pt_py_pred_pfo, ref_pt_pz_pred_pfo), dim=1
            )
            # p_pred_pandora = scatter_mean(dic["graph"].ndata["pandora_momentum"], labels)
            p_pred_pandora = torch.stack((px_pred_pfo, py_pred_pfo, pz_pred_pfo), dim=1)
            p_size_pandora = torch.norm(p_pred_pandora, dim=1)
            pxyz_pred_pfo = (
                p_pred_pandora  # / torch.norm(p_pred_pandora, dim=1).view(-1, 1)
            )
    else:
        if e_corr is None:
            # per-shower correction taken from the highest-beta hit of each cluster
            corrections_per_shower = get_correction_per_shower(labels, dic)
            e_pred_showers_cali = e_pred_showers * corrections_per_shower
        else:
            corrections_per_shower = e_corr.view(-1)
            # NOTE(review): raises TypeError if number_of_fakes is None (its
            # default) while e_corr is given — callers must pass it explicitly.
            if number_of_fakes > 0:
                # the last number_of_fakes entries of e_corr belong to fake showers
                corrections_per_shower_fakes = corrections_per_shower[-number_of_fakes:]
                corrections_per_shower = corrections_per_shower[:-number_of_fakes]
    e_reco_showers = scatter_add(
        dic["graph"].ndata["e_hits"].view(-1),
        dic["graph"].ndata["particle_number"].long(),
    )
    row_ind = torch.Tensor(row_ind).to(e_pred_showers.device).long()
    col_ind = torch.Tensor(col_ind).to(e_pred_showers.device).long()
    if torch.sum(particle_ids == 0) > 0:
        # particle id can be 0 because there is noise
        # then row ind 0 in any case corresponds to particle 1.
        # if there is particle_id 0 then row_ind should be +1?
        row_ind_ = row_ind - 1
    else:
        # if there is no zero then index 0 corresponds to particle 1.
        row_ind_ = row_ind
    pred_showers = shower_p_unique
    energy_t = (
        dic["part_true"].E_corrected.view(-1).to(e_pred_showers.device)
    )  # dic["part_true"][:, 3].to(e_pred_showers.device)
    vertex = dic["part_true"].vertex.to(e_pred_showers.device)
    pos_t = dic["part_true"].coord.to(e_pred_showers.device)
    pid_t = dic["part_true"].pid.to(e_pred_showers.device)
    # number of track hits (hit_type == 1) per predicted shower
    is_track_per_shower = scatter_add((dic["graph"].ndata["hit_type"] == 1), labels).int()
    is_track = torch.zeros(energy_t.shape).to(e_pred_showers.device)
    if shap:
        # NaN-initialized; rows stay NaN for unmatched true showers
        matched_shap_vals = torch.zeros((energy_t.shape[0], ec_x.shape[1])) * (
            torch.nan
        )
        matched_shap_vals = matched_shap_vals.numpy()
        matched_ec_x = torch.zeros((energy_t.shape[0], ec_x.shape[1])) * (torch.nan)
        matched_ec_x = matched_ec_x.numpy()
    # col_ind skips the noise cluster (index 0), so shift by +1 into label space
    index_matches = col_ind + 1
    index_matches = index_matches.to(e_pred_showers.device).long()
    matched_es = torch.zeros_like(energy_t) * (torch.nan)
    matched_positions = torch.zeros((energy_t.shape[0], 3)) * (torch.nan)
    matched_positions = matched_positions.to(e_pred_showers.device)
    matched_ref_pt = torch.zeros((energy_t.shape[0], 3)) * (torch.nan)
    matched_ref_pt = matched_ref_pt.to(e_pred_showers.device)
    matched_pid = torch.zeros_like(energy_t) * (torch.nan)
    matched_pid = matched_pid.to(e_pred_showers.device).long()
    matched_positions_pfo = torch.zeros((energy_t.shape[0], 3)) * (torch.nan)
    matched_positions_pfo = matched_positions_pfo.to(e_pred_showers.device)
    matched_pandora_pid = (torch.zeros((energy_t.shape[0])) * (torch.nan)).to(e_pred_showers.device)
    matched_ref_pts_pfo = torch.zeros((energy_t.shape[0], 3)) * (torch.nan)
    matched_ref_pts_pfo = matched_ref_pts_pfo.to(e_pred_showers.device)
    matched_es = matched_es.to(e_pred_showers.device)
    matched_es[row_ind_] = e_pred_showers[index_matches]
    if pandora:
        matched_es_cali = matched_es.clone()
        matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
        matched_es_cali_pfo = matched_es.clone()
        matched_es_cali_pfo[row_ind_] = e_pred_showers_pfo[index_matches]
        matched_pandora_pid[row_ind_] = pandora_pid[index_matches]
        if calc_pandora_momentum:
            matched_positions_pfo[row_ind_] = pxyz_pred_pfo[index_matches]
            matched_ref_pts_pfo[row_ind_] = ref_pt_pred_pfo[index_matches]
        is_track[row_ind_] = is_track_per_shower[index_matches].float()
    else:
        if e_corr is None:
            matched_es_cali = matched_es.clone()
            matched_es_cali[row_ind_] = e_pred_showers_cali[index_matches]
            calibration_per_shower = matched_es.clone()
            calibration_per_shower[row_ind_] = corrections_per_shower[index_matches]
        else:
            matched_es_cali = matched_es.clone()
            number_of_showers = e_pred_showers[index_matches].shape[0]  # DOESN'T INCLUDE THE FAKE SHOWERS
            # number_of_fake_showers = e_pred_showers.shape[0] - number_of_showers
            matched_es_cali[row_ind_] = (
                corrections_per_shower[
                    number_of_showers_total : number_of_showers_total + number_of_showers
                ]
                # * e_pred_showers[index_matches]
            )
            # if len(row_ind) and len(index_matches):
            #     assert row_ind.max() < len(is_track)
            #     assert index_matches.max() < len(is_track_per_shower)
            is_track[row_ind_] = is_track_per_shower[index_matches].float()
            # NOTE(review): the pred_pos/shap/calibration bookkeeping below uses
            # number_of_showers, which only exists in this e_corr branch —
            # placement inside this else reconstructed on that basis; confirm.
            if pred_pos is not None:
                matched_positions[row_ind_] = pred_pos[
                    number_of_showers_total : number_of_showers_total + number_of_showers
                ]
                matched_ref_pt[row_ind_] = pred_ref_pt[
                    number_of_showers_total : number_of_showers_total + number_of_showers
                ]
                matched_pid[row_ind_] = pred_pid[
                    number_of_showers_total : number_of_showers_total + number_of_showers
                ]
            if shap:
                matched_shap_vals[row_ind_.cpu()] = shap_vals[index_matches.cpu()]
                matched_ec_x[row_ind_.cpu()] = ec_x[index_matches.cpu()]
            calibration_per_shower = matched_es.clone()
            calibration_per_shower[row_ind_] = corrections_per_shower[
                number_of_showers_total : number_of_showers_total + number_of_showers
            ]
            # advance the running offset into the batch-flattened tensors
            number_of_showers_total = number_of_showers_total + number_of_showers
    intersection_E = torch.zeros_like(energy_t) * (torch.nan)
    if len(col_ind) > 0:
        ie_e = obtain_intersection_values(i_m_w, row_ind, col_ind, dic)
        intersection_E[row_ind_] = ie_e.to(e_pred_showers.device)
        # mark matched showers (and the noise cluster) so the remainder are fakes
        pred_showers[index_matches] = -1
        pred_showers[
            0
        ] = (
            -1
        )  # This takes into account that the class 0 for pandora and for dbscan is noise
        mask = pred_showers != -1
        number_of_fake_showers = mask.sum()
        fakes_in_event = mask.sum()
        fake_showers_e = e_pred_showers[mask]
        if e_corr is None or pandora:
            fake_showers_e_cali = e_pred_showers_cali[mask]
            # fakes_positions = dic["graph"].ndata["coords"][mask]
        else:
            # fake_showers_e_cali = corrections_per_shower[number_of_showers_total:number_of_showers_total+number_of_showers][mask]  # * (torch.nan)
            # fakes_positions = torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)
            # fake_showers_e_cali = fake_showers_e
            # fakes_pid_pred = torch.zeros((fake_showers_e.shape[0])) * (torch.nan)  # just for now for debugigng
            # fakes_positions = fakes_positions.to(e_pred_showers.device)
            # fakes_pid_pred = fakes_pid_pred.to(e_pred_showers.device)
            # fake-shower predictions live in the tail of the batch-flattened
            # tensors; slice this event's window out of that tail
            fakes_positions = pred_pos[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total+number_of_fake_showers]
            fake_showers_e_cali = e_corr[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total+number_of_fake_showers]
            fakes_pid_pred = pred_pid[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total+number_of_fake_showers]
            fake_showers_e_reco = e_reco_showers[-number_of_fakes:][number_of_fake_showers_total:number_of_fake_showers_total+number_of_fake_showers]
            fakes_positions = fakes_positions.to(e_pred_showers.device)
            fake_showers_e_cali = fake_showers_e_cali.to(e_pred_showers.device)
            fakes_pid_pred = fakes_pid_pred.to(e_pred_showers.device)
            fake_showers_e_reco = fake_showers_e_reco.to(e_pred_showers.device)
            # fakes_pid_pred = pred_pid[number_of_showers_total:number_of_showers_total+number_of_showers][mask]
            # fakes_positions = fakes_positions.to(e_pred_showers.device)
        if pandora:
            # the NaN pre-initializations below are immediately overwritten
            fake_pandora_pid = (torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)).to(e_pred_showers.device)
            fake_pandora_pid = pandora_pid[mask]
            if calc_pandora_momentum:
                fake_positions_pfo = torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)
                fake_positions_pfo = fake_positions_pfo.to(e_pred_showers.device)
                fake_positions_pfo = pxyz_pred_pfo[mask]
                fakes_positions_ref = (torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)).to(e_pred_showers.device)
                fakes_positions_ref = ref_pt_pred_pfo[mask]
        if not pandora:
            if e_corr is None:
                fake_showers_e_cali_factor = corrections_per_shower[mask]
            else:
                fake_showers_e_cali_factor = fake_showers_e_cali
        # truth columns for fake rows are NaN
        fake_showers_showers_e_truw = torch.zeros((fake_showers_e.shape[0])) * (
            torch.nan
        )
        fake_showers_vertex = torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)
        fakes_is_track = (torch.zeros((fake_showers_e.shape[0])) * (torch.nan)).to(e_pred_showers.device)
        fakes_is_track = is_track_per_shower[mask]
        fakes_positions_t = torch.zeros((fake_showers_e.shape[0], 3)) * (torch.nan)
        if not pandora:
            number_of_fake_showers_total = number_of_fake_showers_total + number_of_fake_showers
        """if shap:
            fake_showers_shap_vals = torch.zeros((fake_showers_e.shape[0], shap_vals_t.shape[1])) * (
                torch.nan
            )
            fake_showers_ec_x_t = torch.zeros((fake_showers_e.shape[0], ec_x_t.shape[1])) * (
                torch.nan
            )
            #fake_showers_shap_vals = fake_showers_shap_vals.to(e_pred_showers.device)
            #fake_showers_ec_x_t = fake_showers_ec_x_t.to(e_pred_showers.device)
            shap_vals_t = torch.cat((torch.tensor(shap_vals_t), fake_showers_shap_vals), dim=0)
            ec_x_t = torch.cat((torch.tensor(ec_x_t), fake_showers_ec_x_t), dim=0)
        """
        fake_showers_showers_e_truw = fake_showers_showers_e_truw.to(
            e_pred_showers.device
        )
        fakes_positions_t = fakes_positions_t.to(e_pred_showers.device)
        fake_showers_vertex = fake_showers_vertex.to(e_pred_showers.device)
        energy_t = torch.cat(
            (energy_t, fake_showers_showers_e_truw),
            dim=0,
        )
        vertex = torch.cat((vertex, fake_showers_vertex), dim=0)
        pid_t = torch.cat(
            (pid_t.view(-1), fake_showers_showers_e_truw),
            dim=0,
        )
        pos_t = torch.cat(
            (pos_t, fakes_positions_t),
            dim=0,
        )
        # e_reco_showers[1:] drops the noise "particle" (id 0)
        e_reco = torch.cat((e_reco_showers[1:], fake_showers_showers_e_truw), dim=0)
        e_pred = torch.cat((matched_es, fake_showers_e), dim=0)
        e_pred_cali = torch.cat((matched_es_cali, fake_showers_e_cali), dim=0)
        if pred_pos is not None:
            e_pred_pos = torch.cat((matched_positions, fakes_positions), dim=0)
            e_pred_pid = torch.cat((matched_pid, fakes_pid_pred), dim=0)
            # NOTE(review): fakes_positions (not a ref-point tensor) is appended
            # to matched_ref_pt here — possibly intentional reuse; confirm.
            e_pred_ref_pt = torch.cat((matched_ref_pt, fakes_positions), dim=0)
        if pandora:
            e_pred_cali_pfo = torch.cat(
                (matched_es_cali_pfo, fake_showers_e_cali), dim=0
            )
            positions_pfo = torch.cat((matched_positions_pfo, fake_positions_pfo), dim=0)
            pandora_pid = torch.cat((matched_pandora_pid, fake_pandora_pid), dim=0)
            ref_pts_pfo = torch.cat((matched_ref_pts_pfo, fakes_positions_ref), dim=0)
        if not pandora:
            calibration_factor = torch.cat(
                (calibration_per_shower, fake_showers_e_cali_factor), dim=0
            )
        if shap:
            # pad
            matched_shap_vals = torch.cat(
                (
                    torch.tensor(matched_shap_vals),
                    torch.zeros((fake_showers_e.shape[0], shap_vals.shape[1])),
                ),
                dim=0,
            )
            matched_ec_x = torch.cat(
                (
                    torch.tensor(matched_ec_x),
                    torch.zeros((fake_showers_e.shape[0], ec_x.shape[1])),
                ),
                dim=0,
            )
        e_pred_t = torch.cat(
            (
                intersection_E,
                torch.zeros_like(fake_showers_e) * (torch.nan),
            ),
            dim=0,
        )
        # e_pred_t_pandora = torch.cat(
        #     (
        #         intersection_E,
        #         torch.zeros_like(fake_showers_e) * (-200),
        #         torch.zeros_like(fake_showers_e_pandora) * (-100),
        #     ),
        #     dim=0,
        # )
        is_track = torch.cat((is_track, fakes_is_track.to(is_track.device)), dim=0)
        if pandora:
            d = {
                "true_showers_E": energy_t.detach().cpu(),
                "reco_showers_E": e_reco.detach().cpu(),
                "pred_showers_E": e_pred.detach().cpu(),
                "e_pred_and_truth": e_pred_t.detach().cpu(),
                "pandora_calibrated_E": e_pred_cali.detach().cpu(),
                "pandora_calibrated_pfo": e_pred_cali_pfo.detach().cpu(),
                "pandora_calibrated_pos": positions_pfo.detach().cpu().tolist(),
                "pandora_ref_pt": ref_pts_pfo.detach().cpu().tolist(),
                "pid": pid_t.detach().cpu(),
                "pandora_pid": pandora_pid.detach().cpu(),
                "step": torch.ones_like(energy_t.detach().cpu()) * step,
                "number_batch": torch.ones_like(energy_t.detach().cpu()) * number_in_batch,
                "is_track_in_cluster": is_track.detach().cpu(),
                "vertex": vertex.detach().cpu().tolist()
            }
        else:
            d = {
                "true_showers_E": energy_t.detach().cpu(),
                "reco_showers_E": e_reco.detach().cpu(),
                "pred_showers_E": e_pred.detach().cpu(),
                "e_pred_and_truth": e_pred_t.detach().cpu(),
                "pid": pid_t.detach().cpu(),
                "calibration_factor": calibration_factor.detach().cpu(),
                "calibrated_E": e_pred_cali.detach().cpu(),
                "step": torch.ones_like(energy_t.detach().cpu()) * step,
                "number_batch": torch.ones_like(energy_t.detach().cpu()) * number_in_batch,
                "is_track_in_cluster": is_track.detach().cpu(),
                "vertex": vertex.detach().cpu().tolist()
            }
        if pred_pos is not None:
            pred_pos1 = e_pred_pos.detach().cpu()
            pred_pid1 = e_pred_pid.detach().cpu()
            pred_ref_pt1 = e_pred_ref_pt.detach().cpu()
            d["pred_pos_matched"] = (
                pred_pos1.tolist()
            )  # Otherwise it doesn't work nicely with Pandas DataFrames
            d["pred_pid_matched"] = pred_pid1.tolist()
            d["pred_ref_pt_matched"] = pred_ref_pt1.tolist()
        """if shap:
            print("Adding ec_x and shap_values to the DataFrame")
            d["ec_x"] = ec_x_t
            d["shap_values"] = shap_vals_t"""
        if shap:
            d["shap_values"] = matched_shap_vals.tolist()
            d["ec_x"] = matched_ec_x.tolist()
        d["true_pos"] = pos_t.detach().cpu().tolist()
        df = pd.DataFrame(data=d)
        if save_plots_to_folder:
            event_numbers = np.unique(df.number_batch)
            for evt in event_numbers:
                if len(df[df.number_batch == evt]):
                    # Random string
                    rndstr = generate_random_string(5)
                    plot_event(
                        df[df.number_batch == evt],
                        pandora,
                        save_plots_to_folder + str(evt) + rndstr,
                        graph=dic["graph"].to("cpu"),
                        y=dic["part_true"],
                        labels=dic["graph"].ndata["particle_number"].long(),
                        is_track_in_cluster=df.is_track_in_cluster
                    )
                    '''plot_event(
                        df[df.number_batch == evt],
                        pandora,
                        save_plots_to_folder + "_CLUSTERING_" + str(evt) + rndstr,
                        graph=dic["graph"].to("cpu"),
                        y=dic["part_true"],
                        labels=labels.detach().cpu(),
                        is_track_in_cluster=df.is_track_in_cluster
                    )'''
        if number_of_showers_total is None:
            return df
        else:
            return df, number_of_showers_total, number_of_fake_showers_total
    else:
        return [], 0, 0


def get_correction_per_shower(labels, dic):
    """Return one correction factor per cluster label: the correction of the highest-beta hit.

    If the first unique label is not 0 (i.e. there is no noise cluster), a
    zero entry is prepended so indexing by label stays aligned.
    """
    unique_labels = torch.unique(labels)
    list_corr = []
    for ii, pred_label in enumerate(unique_labels):
        if ii == 0:
            if pred_label != 0:
                list_corr.append(dic["graph"].ndata["correction"][0].view(-1) * 0)
        mask = labels == pred_label
        corrections_E_label = dic["graph"].ndata["correction"][mask]
        betas_label_indmax = torch.argmax(dic["graph"].ndata["beta"][mask])
        list_corr.append(corrections_E_label[betas_label_indmax].view(-1))
    corrections = torch.cat(list_corr, dim=0)
    return corrections


def obtain_intersection_matrix(shower_p_unique, particle_ids, labels, dic, e_hits):
    """Return (hit-count, energy-weighted) intersection matrices [pred shower x true particle]."""
    len_pred_showers = len(shower_p_unique)
    intersection_matrix = torch.zeros((len_pred_showers, len(particle_ids))).to(
        shower_p_unique.device
    )
    intersection_matrix_w = torch.zeros((len_pred_showers, len(particle_ids))).to(
        shower_p_unique.device
    )
    for index, id in enumerate(particle_ids):
        counts = torch.zeros_like(labels)
        mask_p = dic["graph"].ndata["particle_number"] == id
        h_hits = e_hits.clone()
        counts[mask_p] = 1
        h_hits[~mask_p] = 0
        intersection_matrix[:, index] = scatter_add(counts, labels)
        intersection_matrix_w[:, index] = scatter_add(h_hits, labels.to(h_hits.device))
    return intersection_matrix, intersection_matrix_w


def obtain_union_matrix(shower_p_unique, particle_ids, labels, dic):
    """Return the union-size matrix [pred shower x true particle] of hit sets (for IoU)."""
    len_pred_showers = len(shower_p_unique)
    union_matrix = torch.zeros((len_pred_showers, len(particle_ids)))
    for index, id in enumerate(particle_ids):
        counts = torch.zeros_like(labels)
        mask_p = dic["graph"].ndata["particle_number"] == id
        for index_pred, id_pred in enumerate(shower_p_unique):
            mask_pred_p = labels == id_pred
            # boolean OR via + on bool masks; sum = |pred ∪ true|
            mask_union = mask_pred_p + mask_p
            union_matrix[index_pred, index] = torch.sum(mask_union)
    return union_matrix


def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.7, td=0.03):
    """
    Returns a clustering of hits -> cluster_index, based on the GravNet model
    output (predicted betas and cluster space coordinates) and the clustering
    parameters tbeta and td.
    Takes torch.Tensors as input.
    """
    n_points = betas.size(0)
    select_condpoints = betas > tbeta
    # Get indices passing the threshold
    indices_condpoints = select_condpoints.nonzero()
    # Order them by decreasing beta value
    indices_condpoints = indices_condpoints[(-betas[select_condpoints]).argsort()]
    # Assign points to condensation points
    # Only assign previously unassigned points (no overwriting)
    # Points unassigned at the end are bkg (-1)
    unassigned = torch.arange(n_points).to(betas.device)
    clustering = -1 * torch.ones(n_points, dtype=torch.long).to(betas.device)
    while len(indices_condpoints) > 0 and len(unassigned) > 0:
        index_condpoint = indices_condpoints[0]
        d = torch.norm(X[unassigned] - X[index_condpoint][0], dim=-1)
        assigned_to_this_condpoint = unassigned[d < td]
        clustering[assigned_to_this_condpoint] = index_condpoint[0]
        unassigned = unassigned[~(d < td)]
        # calculate indices_codpoints again
        indices_condpoints = find_condpoints(betas, unassigned, tbeta)
    return clustering


def find_condpoints(betas, unassigned, tbeta):
    """Return indices of still-unassigned points with beta > tbeta, sorted by decreasing beta."""
    n_points = betas.size(0)
    select_condpoints = betas > tbeta
    device = betas.device
    mask_unassigned = torch.zeros(n_points).to(device)
    mask_unassigned[unassigned] = True
    select_condpoints = mask_unassigned.to(bool) * select_condpoints
    # Get indices passing the threshold
    indices_condpoints = select_condpoints.nonzero()
    # Order them by decreasing beta value
    indices_condpoints = indices_condpoints[(-betas[select_condpoints]).argsort()]
    return indices_condpoints


def obtain_intersection_values(intersection_matrix_w, row_ind, col_ind, dic):
    """Return the energy-weighted intersection for each matched (true, pred) pair, or 0 if none."""
    list_intersection_E = []
    # intersection_matrix_w = intersection_matrix_w
    particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
    if torch.sum(particle_ids == 0) > 0:
        # removing also the MC particle corresponding to noise
        intersection_matrix_wt = torch.transpose(intersection_matrix_w[1:, 1:], 1, 0)
        row_ind = row_ind - 1
    else:
        intersection_matrix_wt = torch.transpose(intersection_matrix_w[1:, :], 1, 0)
    for i in range(0, len(col_ind)):
        list_intersection_E.append(
            intersection_matrix_wt[row_ind[i], col_ind[i]].view(-1)
        )
    if len(list_intersection_E) > 0:
        return torch.cat(list_intersection_E, dim=0)
    else:
        return 0


def plot_iou_matrix(iou_matrix, image_path, hdbscan=False):
    """Save a heatmap of the IoU matrix (noise row dropped) and log it to wandb."""
    iou_matrix = torch.transpose(iou_matrix[1:, :], 1, 0)
    fig, ax = plt.subplots()
    iou_matrix = iou_matrix.detach().cpu().numpy()
    ax.matshow(iou_matrix, cmap=plt.cm.Blues)
    for i in range(0, iou_matrix.shape[1]):
        for j in range(0, iou_matrix.shape[0]):
            c = np.round(iou_matrix[j, i], 1)
            ax.text(i, j, str(c), va="center", ha="center")
    fig.savefig(image_path, bbox_inches="tight")
    if hdbscan:
        wandb.log({"iou_matrix_hdbscan": wandb.Image(image_path)})
    else:
        wandb.log({"iou_matrix": wandb.Image(image_path)})


def match_showers(
    labels,
    dic,
    particle_ids,
    model_output,
    local_rank,
    i,
    path_save,
    pandora=False,
    tracks=False,
    hdbscan=False,
):
    """Match predicted showers to true particles by IoU-maximal Hungarian assignment.

    IoU entries below 0.25 are zeroed before assignment, and pairs matched at
    zero IoU are discarded afterwards. Returns
    (shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix).
    """
    iou_threshold = 0.25
    shower_p_unique = torch.unique(labels)
    if torch.sum(labels == 0) == 0:
        # ensure a noise cluster (label 0) always occupies index 0
        shower_p_unique = torch.cat(
            (
                torch.Tensor([0]).to(shower_p_unique.device).view(-1),
                shower_p_unique.view(-1),
            ),
            dim=0,
        )
    e_hits = dic["graph"].ndata["e_hits"].view(-1)
    i_m, i_m_w = obtain_intersection_matrix(
        shower_p_unique, particle_ids, labels, dic, e_hits
    )
    i_m = i_m.to(model_output.device)
    i_m_w = i_m_w.to(model_output.device)
    u_m = obtain_union_matrix(shower_p_unique, particle_ids, labels, dic)
    u_m = u_m.to(model_output.device)
    iou_matrix = i_m / u_m
    if torch.sum(particle_ids == 0) > 0:
        # removing also the MC particle corresponding to noise
        iou_matrix_num = (
            torch.transpose(iou_matrix[1:, 1:], 1, 0).clone().detach().cpu().numpy()
        )
    else:
        iou_matrix_num = (
            torch.transpose(iou_matrix[1:, :], 1, 0).clone().detach().cpu().numpy()
        )
    iou_matrix_num[iou_matrix_num < iou_threshold] = 0
    row_ind, col_ind = linear_sum_assignment(-iou_matrix_num)
    # Next three lines remove solutions where there is a shower that is not
    # associated and iou it's zero (or less than threshold)
    mask_matching_matrix = iou_matrix_num[row_ind, col_ind] > 0
    row_ind = row_ind[mask_matching_matrix]
    col_ind = col_ind[mask_matching_matrix]
    if torch.sum(particle_ids == 0) > 0:
        row_ind = row_ind + 1
    if i == 0 and local_rank == 0:
        if path_save is not None:
            if pandora:
                image_path = path_save + "/example_1_clustering_pandora.png"
            else:
                image_path = path_save + "/example_1_clustering.png"
            # plot_iou_matrix(iou_matrix, image_path, hdbscan)
    # row_ind are particles that are matched and col_ind the ind of preds they are matched to
    return shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix


def clustering_obtain_labels(X, betas, device):
    """Run beta-condensation clustering and remap labels to contiguous ids (0 = noise if present)."""
    clustering = get_clustering(betas, X)
    map_from = list(np.unique(clustering.detach().cpu()))
    cluster_id = map(lambda x: map_from.index(x), clustering.detach().cpu())
    clustering_ordered = torch.Tensor(list(cluster_id)).long()
    if torch.unique(clustering)[0] != -1:
        # no background points: shift so 0 stays reserved for noise
        clustering = clustering_ordered + 1
    else:
        clustering = clustering_ordered
    clustering = torch.Tensor(clustering.view(-1)).long().to(device)
    return clustering


def hfdb_obtain_labels(X, device, eps=0.1):
    """Cluster coordinates with HDBSCAN; labels shifted by +1 so 0 is noise."""
    hdb = HDBSCAN(min_cluster_size=8, min_samples=8, cluster_selection_epsilon=eps).fit(
        X.detach().cpu()
    )
    labels_hdb = hdb.labels_ + 1
    labels_hdb = np.reshape(labels_hdb, (-1))
    labels_hdb = torch.Tensor(labels_hdb).long().to(device)
    return labels_hdb


def dbscan_obtain_labels(X, device):
    """Cluster coordinates with DBSCAN using a data-dependent eps; labels shifted so 0 is noise."""
    # eps = 1/30 of the smallest per-dimension coordinate extent
    distance_scale = (
        (torch.min(torch.abs(torch.min(X, dim=0)[0] - torch.max(X, dim=0)[0])) / 30)
        .view(-1)
        .detach()
        .cpu()
        .numpy()[0]
    )
    db = DBSCAN(eps=distance_scale, min_samples=15).fit(X.detach().cpu())
    # DBSCAN has clustering labels -1,0,.., our cluster 0 is noise so we add 1
    labels = db.labels_ + 1
    labels = np.reshape(labels, (-1))
    labels = torch.Tensor(labels).long().to(device)
    return labels


class CachedIndexList:
    """List wrapper memoizing ``index`` lookups.

    NOTE(review): shadows the ``CachedIndexList`` imported from
    ``src.dataset.functions_data`` at the top of this file — confirm which one
    is intended.
    """

    def __init__(self, lst):
        # lst: the wrapped list; cache: value -> first index
        self.lst = lst
        self.cache = {}

    def index(self, value):
        """Return the first index of ``value``, caching the result."""
        if value in self.cache:
            return self.cache[value]
        else:
            idx = self.lst.index(value)
            self.cache[value] = idx
            return idx


def get_labels_pandora(tracks, dic, device):
    """Return Pandora's cluster (or PFO, when ``tracks``) assignment remapped to contiguous labels with 0 = noise."""
    if tracks:
        labels_pandora = dic["graph"].ndata["pandora_pfo"].long()
    else:
        labels_pandora = dic["graph"].ndata["pandora_cluster"].long()
    labels_pandora = labels_pandora + 1
    map_from = list(np.unique(labels_pandora.detach().cpu()))
    map_from = CachedIndexList(map_from)
    cluster_id = map(lambda x: map_from.index(x), labels_pandora.detach().cpu().numpy())
    labels_pandora = torch.Tensor(list(cluster_id)).long().to(device)
    return labels_pandora