# jetclustering/src/layers/inference_oc_tracks.py
import dgl
import torch
import os
from sklearn.cluster import DBSCAN
from torch_scatter import scatter_max, scatter_add, scatter_mean
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
import pandas as pd
import wandb
from src.layers.inference_oc import hfdb_obtain_labels
def evaluate_efficiency_tracks(
batch_g,
model_output,
embedded_outputs,
y,
local_rank,
step,
epoch,
path_save,
store=False,
predict=False,
):
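    """Cluster the model output of each event with DBSCAN, match the
    predicted showers to the true particles and collect per-shower
    statistics.  Returns a DataFrame with one row per matched or fake
    shower (or an empty list if no event yielded more than one match).
    """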
number_of_showers_total = 0
batch_g.ndata["coords"] = model_output[:, 0:3]
batch_g.ndata["beta"] = model_output[:, 3]
batch_g.ndata["embedded_outputs"] = embedded_outputs
graphs = dgl.unbatch(batch_g)
batch_id = y[:, -1].view(-1)
df_list = []
for i in range(0, len(graphs)):
mask = batch_id == i
dic = {}
dic["graph"] = graphs[i]
dic["part_true"] = y[mask]
betas = torch.sigmoid(dic["graph"].ndata["beta"])
X = dic["graph"].ndata["coords"]
clustering_mode = "dbscan"
if clustering_mode == "clustering_normal":
clustering = get_clustering(betas, X)
elif clustering_mode == "dbscan":
labels = hfdb_obtain_labels(X, betas.device, eps=0.05)
particle_ids = torch.unique(dic["graph"].ndata["particle_number"])
shower_p_unique = torch.unique(labels)
shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix = match_showers(
labels,
dic,
particle_ids,
model_output,
local_rank,
i,
path_save,
)
if len(row_ind) > 1:
df_event, number_of_showers_total = generate_showers_data_frame(
labels,
dic,
shower_p_unique,
particle_ids,
row_ind,
col_ind,
i_m_w,
number_of_showers_total=number_of_showers_total,
step=step,
number_in_batch=i,
)
# if len(shower_p_unique) < len(particle_ids):
# print("storing event", local_rank, step, i)
# torch.save(
# dic,
# path_save
# + "/graphs_all_hdb/"
# + str(local_rank)
# + "_"
# + str(step)
# + "_"
# + str(i)
# + ".pt",
# )
df_list.append(df_event)
if len(df_list) > 0:
df_batch = pd.concat(df_list)
else:
df_batch = []
if store:
store_at_batch_end(
path_save, df_batch, local_rank, step, epoch, predict=predict
)
return df_batch
def store_at_batch_end(
path_save,
df_batch,
local_rank=0,
step=0,
epoch=None,
predict=False,
):
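    """Pickle the per-shower DataFrame of this batch to `path_save`
    (in predict mode) and log the matching efficiency to wandb."""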
path_save_ = (
path_save + "/" + str(local_rank) + "_" + str(step) + "_" + str(epoch) + ".pt"
)
if predict:
print("STORING")
        if isinstance(df_batch, list):
            df_batch = pd.concat(df_batch)
df_batch.to_pickle(path_save_)
log_efficiency(df_batch)
def log_efficiency(df):
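    """Log the efficiency (fraction of reco showers that have a matched
    predicted shower) to wandb."""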
    # efficiency = fraction of (non-NaN) reco showers with a matched predicted shower
if len(df) > 0:
mask = ~np.isnan(df["reco_showers_E"])
eff = np.sum(~np.isnan(df["pred_showers_E"][mask].values)) / len(
df["pred_showers_E"][mask].values
)
wandb.log({"efficiency validation": eff})
def generate_showers_data_frame(
labels,
dic,
shower_p_unique,
particle_ids,
row_ind,
col_ind,
i_m_w,
number_of_showers_total=None,
step=0,
number_in_batch=0,
):
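    """Build a per-event DataFrame with, for every true particle, its hit
    count ("reco_showers_E"), the truth value from column 5 of `part_true`
    ("true_showers_E"), the hit count of the matched predicted shower
    ("pred_showers_E", NaN if unmatched) and the number of shared hits
    ("e_pred_and_truth").  Unmatched predicted showers are appended as
    fake showers with NaN truth entries."""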
e_pred_showers = 1.0 * scatter_add(
torch.ones_like(labels).view(-1),
labels.long(),
)
e_reco_showers = scatter_add(
torch.ones_like(labels).view(-1),
dic["graph"].ndata["particle_number"].long(),
)
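    # drop the bin at index 0 so that entry i corresponds to particle_number i + 1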
e_reco_showers = e_reco_showers[1:]
e_true_showers = dic["part_true"][:, 5]
row_ind = torch.Tensor(row_ind).to(e_pred_showers.device).long()
col_ind = torch.Tensor(col_ind).to(e_pred_showers.device).long()
pred_showers = shower_p_unique
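    # +1 offset: the noise cluster (label 0) was dropped from the IoU matrix
    # before running the assignment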
index_matches = col_ind + 1
index_matches = index_matches.to(e_pred_showers.device).long()
matched_es = torch.zeros_like(e_reco_showers) * (torch.nan)
matched_es = matched_es.to(e_pred_showers.device)
matched_es[row_ind] = e_pred_showers[index_matches]
intersection_E = torch.zeros_like(e_reco_showers) * (torch.nan)
ie_e = obtain_intersection_values(i_m_w, row_ind, col_ind)
intersection_E[row_ind] = ie_e.to(e_pred_showers.device)
    pred_showers[index_matches] = -1
    # label 0 is noise for both Pandora and DBSCAN, so it is never counted as a fake shower
    pred_showers[0] = -1
mask = pred_showers != -1
fake_showers_e = e_pred_showers[mask]
    fake_showers_e_true = torch.zeros((fake_showers_e.shape[0])) * (torch.nan)
    fake_showers_e_true = fake_showers_e_true.to(e_pred_showers.device)
    e_reco = torch.cat((e_reco_showers, fake_showers_e_true), dim=0)
    e_true = torch.cat((e_true_showers, fake_showers_e_true), dim=0)
e_pred = torch.cat((matched_es, fake_showers_e), dim=0)
e_pred_t = torch.cat(
(
intersection_E,
torch.zeros_like(fake_showers_e) * (torch.nan),
),
dim=0,
)
# print(e_reco.shape, e_pred.shape, e_pred_t.shape)
d = {
"reco_showers_E": e_reco.detach().cpu(),
"true_showers_E": e_true.detach().cpu(),
"pred_showers_E": e_pred.detach().cpu(),
"e_pred_and_truth": e_pred_t.detach().cpu(),
}
df = pd.DataFrame(data=d)
if number_of_showers_total is None:
return df
else:
return df, number_of_showers_total
def obtain_intersection_matrix(shower_p_unique, particle_ids, labels, dic, e_hits):
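    """For every (predicted shower, true particle) pair count the shared hits,
    unweighted (`intersection_matrix`) and weighted by `e_hits`
    (`intersection_matrix_w`)."""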
len_pred_showers = len(shower_p_unique)
intersection_matrix = torch.zeros((len_pred_showers, len(particle_ids))).to(
shower_p_unique.device
)
intersection_matrix_w = torch.zeros((len_pred_showers, len(particle_ids))).to(
shower_p_unique.device
)
for index, id in enumerate(particle_ids):
counts = torch.zeros_like(labels)
mask_p = dic["graph"].ndata["particle_number"] == id
h_hits = e_hits.clone()
counts[mask_p] = 1
h_hits[~mask_p] = 0
intersection_matrix[:, index] = scatter_add(counts, labels)
intersection_matrix_w[:, index] = scatter_add(h_hits, labels.to(h_hits.device))
return intersection_matrix, intersection_matrix_w
def obtain_union_matrix(shower_p_unique, particle_ids, labels, dic):
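    """For every (predicted shower, true particle) pair count the hits that
    belong to either of them (the union)."""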
len_pred_showers = len(shower_p_unique)
union_matrix = torch.zeros((len_pred_showers, len(particle_ids)))
for index, id in enumerate(particle_ids):
counts = torch.zeros_like(labels)
mask_p = dic["graph"].ndata["particle_number"] == id
for index_pred, id_pred in enumerate(shower_p_unique):
mask_pred_p = labels == id_pred
mask_union = mask_pred_p + mask_p
union_matrix[index_pred, index] = torch.sum(mask_union)
return union_matrix
def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.1, td=0.5):
"""
Returns a clustering of hits -> cluster_index, based on the GravNet model
output (predicted betas and cluster space coordinates) and the clustering
parameters tbeta and td.
Takes torch.Tensors as input.
"""
n_points = betas.size(0)
select_condpoints = betas > tbeta
# Get indices passing the threshold
indices_condpoints = select_condpoints.nonzero()
# Order them by decreasing beta value
indices_condpoints = indices_condpoints[(-betas[select_condpoints]).argsort()]
# Assign points to condensation points
# Only assign previously unassigned points (no overwriting)
# Points unassigned at the end are bkg (-1)
unassigned = torch.arange(n_points)
clustering = -1 * torch.ones(n_points, dtype=torch.long)
for index_condpoint in indices_condpoints:
d = torch.norm(X[unassigned] - X[index_condpoint][0], dim=-1)
assigned_to_this_condpoint = unassigned[d < td]
clustering[assigned_to_this_condpoint] = index_condpoint[0]
unassigned = unassigned[~(d < td)]
return clustering
def obtain_intersection_values(intersection_matrix_w, row_ind, col_ind):
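    """Gather the weighted intersection entries of the matched pairs given by
    `row_ind` (true particles) and `col_ind` (predicted showers, noise row
    excluded)."""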
list_intersection_E = []
intersection_matrix_wt = torch.transpose(intersection_matrix_w[1:, :], 1, 0)
for i in range(0, len(col_ind)):
list_intersection_E.append(
intersection_matrix_wt[row_ind[i], col_ind[i]].view(-1)
)
return torch.cat(list_intersection_E, dim=0)
def plot_iou_matrix(iou_matrix, image_path):
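    """Plot the (true particle x predicted shower) IoU matrix, save it to
    `image_path` and log the image to wandb."""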
iou_matrix = torch.transpose(iou_matrix[1:, :], 1, 0)
fig, ax = plt.subplots()
iou_matrix = iou_matrix.detach().cpu().numpy()
ax.matshow(iou_matrix, cmap=plt.cm.Blues)
for i in range(0, iou_matrix.shape[1]):
for j in range(0, iou_matrix.shape[0]):
c = np.round(iou_matrix[j, i], 1)
ax.text(i, j, str(c), va="center", ha="center")
fig.savefig(image_path, bbox_inches="tight")
wandb.log({"iou_matrix": wandb.Image(image_path)})
def match_showers(
labels,
dic,
particle_ids,
model_output,
local_rank,
i,
path_save,
):
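    """Match predicted showers to true particles with the Hungarian algorithm
    on the hit-count IoU matrix, keeping only pairs with IoU above
    `iou_threshold`."""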
iou_threshold = 0.1
shower_p_unique = torch.unique(labels)
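    # make sure the noise label 0 is present so that row 0 of the intersection /
    # IoU matrices can consistently be treated as the noise cluster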
if torch.sum(labels == 0) == 0:
shower_p_unique = torch.cat(
(
torch.Tensor([0]).to(shower_p_unique.device).view(-1),
shower_p_unique.view(-1),
),
dim=0,
)
    # all hits are weighted equally
e_hits = torch.ones_like(labels)
i_m, i_m_w = obtain_intersection_matrix(
shower_p_unique, particle_ids, labels, dic, e_hits
)
i_m = i_m.to(model_output.device)
i_m_w = i_m_w.to(model_output.device)
u_m = obtain_union_matrix(shower_p_unique, particle_ids, labels, dic)
u_m = u_m.to(model_output.device)
iou_matrix = i_m / u_m
iou_matrix_num = (
torch.transpose(iou_matrix[1:, :], 1, 0).clone().detach().cpu().numpy()
)
iou_matrix_num[iou_matrix_num < iou_threshold] = 0
row_ind, col_ind = linear_sum_assignment(-iou_matrix_num)
    # the next three lines drop assignments whose IoU is below the threshold,
    # i.e. predicted showers that are effectively unmatched
mask_matching_matrix = iou_matrix_num[row_ind, col_ind] > 0
row_ind = row_ind[mask_matching_matrix]
col_ind = col_ind[mask_matching_matrix]
if i == 0 and local_rank == 0:
if path_save is not None:
image_path = path_save + "/example_1_clustering.png"
plot_iou_matrix(iou_matrix, image_path)
    # row_ind indexes the matched true particles; col_ind gives the indices of the
    # predicted showers they are matched to (noise row excluded)
return shower_p_unique, row_ind, col_ind, i_m_w, iou_matrix