# Source: jetclustering/src/layers/obtain_statistics.py
# Author: gregorkrzmanc, commit e75a247 (5.82 kB)
import torch
from torch_scatter import scatter_max, scatter_add, scatter_mean
import numpy as np
import matplotlib.pyplot as plt
import os
def obtain_statistics_graph(stat_dict, y_all, g_all, pf=True):
    """Accumulate per-event histograms of a batch into ``stat_dict``.

    For every graph in the batch this updates three running histograms:

    * ``freq_count_particles`` - incremented at the index equal to the number
      of rows of ``y_all`` belonging to the event;
    * ``freq_count_energy`` - 500-bin ``torch.histc`` over [0.001, 50] of
      column 3 of those rows;
    * ``freq_count_angle`` - 40-bin ``torch.histc`` over [0, 4] of each
      row's distance to its nearest neighbour in the plane spanned by
      columns 0 and 1.

    Args:
        stat_dict: dict holding the three histogram tensors above (the layout
            produced by ``create_stats_dict`` in this file).
        y_all: 2-D tensor with one row per particle. The last column is used
            as the event index within the batch; columns 0, 1 and 3 are read
            as eta, phi and energy (names taken from the original code -
            TODO confirm against the producer of ``y_all``).
        g_all: batched DGL graph; only used to determine the number of events.
        pf: kept for backward compatibility. The original code had two
            identical branches for it, so it has no effect.

    Returns:
        The same ``stat_dict`` object with updated histograms.
    """
    import dgl  # imported lazily so the module can be loaded without dgl

    num_events = len(dgl.unbatch(g_all))
    batch_id = y_all[:, -1].view(-1)

    for event_idx in range(num_events):
        y = y_all[batch_id == event_idx]
        number_of_particles_event = len(y)

        # NOTE(review): raises IndexError when an event has at least as many
        # particles as there are bins in freq_count_particles.
        stat_dict["freq_count_particles"][number_of_particles_event] += 1

        if number_of_particles_event == 0:
            # Nothing to histogram, and the nearest-neighbour reduction below
            # cannot run on an empty 0x0 distance matrix.
            continue

        energy_particles = y[:, 3]
        stat_dict["freq_count_energy"] = stat_dict["freq_count_energy"] + torch.histc(
            energy_particles, bins=500, min=0.001, max=50
        )

        # Pairwise distance sqrt(d_eta^2 + d_phi^2) via broadcasting.
        # (A cosine-angle variant based on hit positions used to be sketched
        # here for pf inputs; only this distance-based version is active.)
        # NOTE(review): d_phi is not wrapped into [-pi, pi] - confirm whether
        # that is intended.
        eta = y[:, 0]
        phi = y[:, 1]
        d_eta = eta.unsqueeze(0) - eta.unsqueeze(1)
        d_phi = phi.unsqueeze(0) - phi.unsqueeze(1)
        dr_matrix = torch.sqrt(torch.square(d_eta) + torch.square(d_phi))

        # Push the (zero) self-distances to 10 so they never win the minimum.
        # For a single-particle event the minimum is therefore 10, which lies
        # outside the histogram range and is ignored by histc.
        self_offset = 10 * torch.eye(
            number_of_particles_event, dtype=dr_matrix.dtype, device=y.device
        )
        dr_matrix = dr_matrix + self_offset

        min_dr_per_particle = torch.min(dr_matrix, dim=1)[0]
        stat_dict["freq_count_angle"] = stat_dict["freq_count_angle"] + torch.histc(
            min_dr_per_particle, bins=40, min=0, max=4
        )

    return stat_dict
def create_stats_dict(device):
    """Create the empty histogram containers used by the statistics helpers.

    Args:
        device: torch device (or device string) on which all tensors are
            allocated.

    Returns:
        dict with the keys

        * ``bins_number_of_particles_event`` - integers 0..199;
        * ``freq_count_particles`` - integer zeros, one per multiplicity bin;
        * ``energy_event`` - values from 0.001 up to (excluding) 50 in steps
          of 0.1;
        * ``freq_count_energy`` - float zeros, one per ``energy_event`` entry;
        * ``angle_distribution`` - bin edges from 0 to 4 in steps of 0.1;
        * ``freq_count_angle`` - float zeros, one per interval between
          consecutive ``angle_distribution`` edges.
    """
    # Allocate directly on the target device instead of creating on the host
    # and copying afterwards.
    bins_number_of_particles_event = torch.arange(0, 200, 1, device=device)
    freq_count_particles = torch.zeros_like(bins_number_of_particles_event)

    # Linear (not logarithmic) energy axis: torch.histc only accepts
    # min/max/number-of-bins, and the histogram variant taking explicit bin
    # edges is not supported on CUDA.
    energy_event = torch.arange(0.001, 50, 0.1, device=device)
    freq_count_energy = torch.zeros(len(energy_event), device=device)

    # One more edge than bins, hence the "- 1" below.
    angle_distribution = torch.arange(0, 4 + 0.1, 0.1, device=device)
    freq_count_angle = torch.zeros(len(angle_distribution) - 1, device=device)

    return {
        "bins_number_of_particles_event": bins_number_of_particles_event,
        "freq_count_particles": freq_count_particles,
        "energy_event": energy_event,
        "freq_count_energy": freq_count_energy,
        "angle_distribution": angle_distribution,
        "freq_count_angle": freq_count_angle,
    }
def save_stat_dict(stat_dict, path):
    """Serialise ``stat_dict`` with ``torch.save`` to ``<path>/stat_dict.pt``.

    Args:
        stat_dict: object to store (typically the dict of histogram tensors).
        path: directory in which ``stat_dict.pt`` is written; a string or any
            ``os.PathLike``.
    """
    # os.path.join avoids doubled separators and, unlike string
    # concatenation, does not resolve an empty ``path`` to the filesystem root.
    torch.save(stat_dict, os.path.join(path, "stat_dict.pt"))
def stacked_hist_plot(lst, lst_pandora, path_store, title, title_no_latex):
    """Plot a vertical stack of step histograms and save them as a PDF.

    One panel is drawn per entry of ``lst`` ("ML", red). When
    ``lst_pandora`` has an entry at the same index it is overlaid
    ("Pandora", blue). Histograms are density-normalised, drawn on a
    logarithmic y-axis, and use 200 bin edges spanning +-0.03 for the first
    panel and +-0.005 for all others.

    Args:
        lst: sequence of arrays, one per panel.
        lst_pandora: sequence of arrays to overlay; may be shorter than
            ``lst``.
        path_store: directory in which the PDF is written.
        title: panel title prefix; the panel's energy range is appended.
        title_no_latex: file-name prefix of the output
            (``<title_no_latex>_angle_distributions.pdf``).
    """
    energy_edges = [0, 5, 15, 35, 51]
    # squeeze=False always yields a 2-D array of axes, so a single panel
    # needs no special casing.
    fig, axes = plt.subplots(len(lst), 1, figsize=(6, 13), squeeze=False)
    try:
        for i, (ax, values) in enumerate(zip(axes[:, 0], lst)):
            half_range = 0.03 if i == 0 else 0.005
            bins = np.linspace(-half_range, half_range, 200)

            ax.hist(
                values, bins, histtype="step", label="ML", color="red", density=True
            )
            if i < len(lst_pandora):
                ax.hist(
                    lst_pandora[i],
                    bins,
                    histtype="step",
                    label="Pandora",
                    color="blue",
                    density=True,
                )

            ax.grid()
            ax.set_yscale("log")
            ax.set_xlabel(r"$\Delta \phi$")

            if i + 1 < len(energy_edges):
                panel_title = title + " [{},{}] GeV".format(
                    energy_edges[i], energy_edges[i + 1]
                )
            else:
                # More panels than predefined energy ranges: keep the bare
                # title instead of failing with an IndexError.
                panel_title = title
            ax.set_title(panel_title)
            ax.title.set_size(15)
            ax.legend(prop={"size": 14})

        fig.tight_layout()
        fig.savefig(
            os.path.join(path_store, title_no_latex + "_angle_distributions.pdf")
        )
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
def _normalised_counts(counts):
    """Return ``counts`` scaled to sum to one, as a CPU NumPy array.

    An all-zero histogram is returned as zeros rather than NaN.
    """
    total = torch.sum(counts)
    if total == 0:
        fractions = torch.zeros_like(counts, dtype=torch.float)
    else:
        fractions = counts / total
    return fractions.detach().cpu().numpy()


def plot_distributions(stat_dict, PATH_store, pf=False):
    """Draw the accumulated histograms and save them to ``stats.png``.

    Three bar charts are placed side by side: the energy histogram, the
    angle histogram (x-axis limited to [0, 1]) and the particle-multiplicity
    histogram. Each is normalised to unit sum before plotting.

    Args:
        stat_dict: dict with the keys ``energy_event``, ``freq_count_energy``,
            ``angle_distribution``, ``freq_count_angle``,
            ``bins_number_of_particles_event`` and ``freq_count_particles``.
        PATH_store: directory in which ``stats.png`` is written.
        pf: unused; kept for backward compatibility.
    """
    print(PATH_store)
    fig, axs = plt.subplots(1, 3, figsize=(9, 3))
    try:
        # Energy per particle.
        energy_x = stat_dict["energy_event"].detach().cpu().numpy()
        axs[0].bar(energy_x, _normalised_counts(stat_dict["freq_count_energy"]), width=0.2)
        axs[0].set_title("Energy distribution")

        # angle_distribution holds bin edges; drop the last one so there is
        # exactly one x position per bin.
        angle_x = stat_dict["angle_distribution"][:-1].detach().cpu().numpy()
        axs[1].bar(angle_x, _normalised_counts(stat_dict["freq_count_angle"]), width=0.02)
        axs[1].set_xlim([0, 1])
        axs[1].set_title("Angle distribution")

        # Number of particles per event.
        multiplicity_x = (
            stat_dict["bins_number_of_particles_event"].detach().cpu().numpy()
        )
        axs[2].bar(multiplicity_x, _normalised_counts(stat_dict["freq_count_particles"]))
        axs[2].set_title("number of particles")

        fig.savefig(os.path.join(PATH_store, "stats.png"), bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)