import dgl
import torch
import os
from sklearn.cluster import DBSCAN
from torch_scatter import scatter_max, scatter_add, scatter_mean
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
import pandas as pd
import wandb
from src.layers.inference_oc import hfdb_obtain_labels


def evaluate_efficiency_tracks(
    batch_g,
    model_output,
    embedded_outputs,
    y,
    local_rank,
    step,
    epoch,
    path_save,
    store=False,
    predict=False,
):
    """Cluster the model output of each event in the batch, match the predicted
    showers to the true particles and return the per-shower statistics as a
    dataframe."""
    number_of_showers_total = 0
    batch_g.ndata["coords"] = model_output[:, 0:3]
    batch_g.ndata["beta"] = model_output[:, 3]
    batch_g.ndata["embedded_outputs"] = embedded_outputs
    graphs = dgl.unbatch(batch_g)
    batch_id = y[:, -1].view(-1)
    df_list = []
    for i in range(0, len(graphs)):
        mask = batch_id == i
        dic = {}
        dic["graph"] = graphs[i]
        dic["part_true"] = y[mask]
        betas = torch.sigmoid(dic["graph"].ndata["beta"])
        X = dic["graph"].ndata["coords"]
        clustering_mode = "dbscan"
        if clustering_mode == "clustering_normal":
            clustering = get_clustering(betas, X)
        elif clustering_mode == "dbscan":
            labels = hfdb_obtain_labels(X, betas.device, eps=0.05)
        particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
        shower_p_unique = torch.unique(labels)
        shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix = match_showers(
            labels,
            dic,
            particle_ids,
            model_output,
            local_rank,
            i,
            path_save,
        )
        if len(row_ind) > 1:
            df_event, number_of_showers_total = generate_showers_data_frame(
                labels,
                dic,
                shower_p_unique,
                particle_ids,
                row_ind,
                col_ind,
                i_m_w,
                number_of_showers_total=number_of_showers_total,
                step=step,
                number_in_batch=i,
            )
            # if len(shower_p_unique) < len(particle_ids):
            #     print("storing event", local_rank, step, i)
            #     torch.save(
            #         dic,
            #         path_save
            #         + "/graphs_all_hdb/"
            #         + str(local_rank)
            #         + "_"
            #         + str(step)
            #         + "_"
            #         + str(i)
            #         + ".pt",
            #     )
            df_list.append(df_event)
    if len(df_list) > 0:
        df_batch = pd.concat(df_list)
    else:
        df_batch = []
    if store:
        store_at_batch_end(
            path_save, df_batch, local_rank, step, epoch, predict=predict
        )
    return df_batch


def store_at_batch_end(
    path_save,
    df_batch,
    local_rank=0,
    step=0,
    epoch=None,
    predict=False,
):
    """Persist the per-batch dataframe to disk (in predict mode) and log the
    matching efficiency."""
    path_save_ = (
        path_save
        + "/"
        + str(local_rank)
        + "_"
        + str(step)
        + "_"
        + str(epoch)
        + ".pt"
    )
    if predict:
        print("STORING")
        df_batch = pd.concat(df_batch)
        df_batch.to_pickle(path_save_)
    log_efficiency(df_batch)


def log_efficiency(df):
    # efficiency: fraction of (non-NaN) true showers that have a matched prediction
    if len(df) > 0:
        mask = ~np.isnan(df["reco_showers_E"])
        eff = np.sum(~np.isnan(df["pred_showers_E"][mask].values)) / len(
            df["pred_showers_E"][mask].values
        )
        wandb.log({"efficiency validation": eff})


def generate_showers_data_frame(
    labels,
    dic,
    shower_p_unique,
    particle_ids,
    row_ind,
    col_ind,
    i_m_w,
    number_of_showers_total=None,
    step=0,
    number_in_batch=0,
):
    """Assemble the per-event dataframe of reco, true and predicted shower
    quantities: matched predictions are aligned with their true shower,
    unmatched predictions are appended as fakes with NaN truth."""
    # hits per predicted cluster and per true particle (all hits weighted equally)
    e_pred_showers = 1.0 * scatter_add(
        torch.ones_like(labels).view(-1),
        labels.long(),
    )
    e_reco_showers = scatter_add(
        torch.ones_like(labels).view(-1),
        dic["graph"].ndata["particle_number"].long(),
    )
    e_reco_showers = e_reco_showers[1:]
    e_true_showers = dic["part_true"][:, 5]
    row_ind = torch.Tensor(row_ind).to(e_pred_showers.device).long()
    col_ind = torch.Tensor(col_ind).to(e_pred_showers.device).long()
    pred_showers = shower_p_unique
    index_matches = col_ind + 1  # shift by one to skip the noise cluster (label 0)
    index_matches = index_matches.to(e_pred_showers.device).long()
    matched_es = torch.zeros_like(e_reco_showers) * (torch.nan)
    matched_es = matched_es.to(e_pred_showers.device)
    matched_es[row_ind] = e_pred_showers[index_matches]
    intersection_E = torch.zeros_like(e_reco_showers) * (torch.nan)
    ie_e = obtain_intersection_values(i_m_w, row_ind, col_ind)
    intersection_E[row_ind] = ie_e.to(e_pred_showers.device)
    # mark matched predictions (and the noise cluster) so the remaining ones are fakes;
    # label 0 is noise for both Pandora and DBSCAN
    pred_showers[index_matches] = -1
    pred_showers[0] = -1
    mask = pred_showers != -1
    fake_showers_e = e_pred_showers[mask]
    fake_showers_e_true = torch.zeros((fake_showers_e.shape[0])) * (torch.nan)
    fake_showers_e_true = fake_showers_e_true.to(e_pred_showers.device)
    e_reco = torch.cat((e_reco_showers, fake_showers_e_true), dim=0)
    e_true = torch.cat((e_true_showers, fake_showers_e_true), dim=0)
    e_pred = torch.cat((matched_es, fake_showers_e), dim=0)
    e_pred_t = torch.cat(
        (
            intersection_E,
            torch.zeros_like(fake_showers_e) * (torch.nan),
        ),
        dim=0,
    )
    # print(e_reco.shape, e_pred.shape, e_pred_t.shape)
    d = {
        "reco_showers_E": e_reco.detach().cpu(),
        "true_showers_E": e_true.detach().cpu(),
        "pred_showers_E": e_pred.detach().cpu(),
        "e_pred_and_truth": e_pred_t.detach().cpu(),
    }
    df = pd.DataFrame(data=d)
    if number_of_showers_total is None:
        return df
    else:
        return df, number_of_showers_total


def obtain_intersection_matrix(shower_p_unique, particle_ids, labels, dic, e_hits):
    """Count the hits (and their weights) shared between each predicted cluster
    and each true particle."""
    len_pred_showers = len(shower_p_unique)
    intersection_matrix = torch.zeros((len_pred_showers, len(particle_ids))).to(
        shower_p_unique.device
    )
    intersection_matrix_w = torch.zeros((len_pred_showers, len(particle_ids))).to(
        shower_p_unique.device
    )
    for index, id in enumerate(particle_ids):
        counts = torch.zeros_like(labels)
        mask_p = dic["graph"].ndata["particle_number"] == id
        h_hits = e_hits.clone()
        counts[mask_p] = 1
        h_hits[~mask_p] = 0
        intersection_matrix[:, index] = scatter_add(counts, labels)
        intersection_matrix_w[:, index] = scatter_add(h_hits, labels.to(h_hits.device))
    return intersection_matrix, intersection_matrix_w


def obtain_union_matrix(shower_p_unique, particle_ids, labels, dic):
    """Count the hits in the union of each predicted cluster and each true particle."""
    len_pred_showers = len(shower_p_unique)
    union_matrix = torch.zeros((len_pred_showers, len(particle_ids)))
    for index, id in enumerate(particle_ids):
        counts = torch.zeros_like(labels)
        mask_p = dic["graph"].ndata["particle_number"] == id
        for index_pred, id_pred in enumerate(shower_p_unique):
            mask_pred_p = labels == id_pred
            mask_union = mask_pred_p + mask_p
            union_matrix[index_pred, index] = torch.sum(mask_union)
    return union_matrix


def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.1, td=0.5):
    """
    Returns a clustering of hits -> cluster_index, based on the GravNet model
    output (predicted betas and cluster space coordinates) and the clustering
    parameters tbeta and td.
    Takes torch.Tensors as input.
""" n_points = betas.size(0) select_condpoints = betas > tbeta # Get indices passing the threshold indices_condpoints = select_condpoints.nonzero() # Order them by decreasing beta value indices_condpoints = indices_condpoints[(-betas[select_condpoints]).argsort()] # Assign points to condensation points # Only assign previously unassigned points (no overwriting) # Points unassigned at the end are bkg (-1) unassigned = torch.arange(n_points) clustering = -1 * torch.ones(n_points, dtype=torch.long) for index_condpoint in indices_condpoints: d = torch.norm(X[unassigned] - X[index_condpoint][0], dim=-1) assigned_to_this_condpoint = unassigned[d < td] clustering[assigned_to_this_condpoint] = index_condpoint[0] unassigned = unassigned[~(d < td)] return clustering def obtain_intersection_values(intersection_matrix_w, row_ind, col_ind): list_intersection_E = [] # intersection_matrix_w = intersection_matrix_w intersection_matrix_wt = torch.transpose(intersection_matrix_w[1:, :], 1, 0) for i in range(0, len(col_ind)): list_intersection_E.append( intersection_matrix_wt[row_ind[i], col_ind[i]].view(-1) ) return torch.cat(list_intersection_E, dim=0) def plot_iou_matrix(iou_matrix, image_path): iou_matrix = torch.transpose(iou_matrix[1:, :], 1, 0) fig, ax = plt.subplots() iou_matrix = iou_matrix.detach().cpu().numpy() ax.matshow(iou_matrix, cmap=plt.cm.Blues) for i in range(0, iou_matrix.shape[1]): for j in range(0, iou_matrix.shape[0]): c = np.round(iou_matrix[j, i], 1) ax.text(i, j, str(c), va="center", ha="center") fig.savefig(image_path, bbox_inches="tight") wandb.log({"iou_matrix": wandb.Image(image_path)}) def match_showers( labels, dic, particle_ids, model_output, local_rank, i, path_save, ): iou_threshold = 0.1 shower_p_unique = torch.unique(labels) if torch.sum(labels == 0) == 0: shower_p_unique = torch.cat( ( torch.Tensor([0]).to(shower_p_unique.device).view(-1), shower_p_unique.view(-1), ), dim=0, ) # all hits weight the same e_hits = torch.ones_like(labels) i_m, i_m_w = obtain_intersection_matrix( shower_p_unique, particle_ids, labels, dic, e_hits ) i_m = i_m.to(model_output.device) i_m_w = i_m_w.to(model_output.device) u_m = obtain_union_matrix(shower_p_unique, particle_ids, labels, dic) u_m = u_m.to(model_output.device) iou_matrix = i_m / u_m iou_matrix_num = ( torch.transpose(iou_matrix[1:, :], 1, 0).clone().detach().cpu().numpy() ) iou_matrix_num[iou_matrix_num < iou_threshold] = 0 row_ind, col_ind = linear_sum_assignment(-iou_matrix_num) # next three lines remove solutions where there is a shower that is not associated and iou it's zero (or less than threshold) mask_matching_matrix = iou_matrix_num[row_ind, col_ind] > 0 row_ind = row_ind[mask_matching_matrix] col_ind = col_ind[mask_matching_matrix] if i == 0 and local_rank == 0: if path_save is not None: image_path = path_save + "/example_1_clustering.png" plot_iou_matrix(iou_matrix, image_path) # row_ind are particles that are matched and col_ind the ind of preds they are matched to return shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix