import torch
from torch_scatter import scatter_max, scatter_add, scatter_mean
import numpy as np
import matplotlib.pyplot as plt
import os

def obtain_statistics_graph(stat_dict, y_all, g_all, pf=True):
    """Accumulate per-event statistics (particle multiplicity, particle energy,
    nearest-neighbour dR) into stat_dict for a batched graph g_all with
    per-particle targets y_all, whose last column is the event index in the batch."""
    import dgl

    graphs = dgl.unbatch(g_all)
    batch_id = y_all[:, -1].view(-1)
    for i in range(0, len(graphs)):
        mask = batch_id == i
        y = y_all[mask]
        g = graphs[i]
        number_of_particles_event = len(y)
        # the energy is stored in column 3 for both pf and non-pf targets
        energy_particles = y[:, 3]
        # particle multiplicity and particle energies of the event
        # (note: an event with 200 or more particles would overflow this counter)
        stat_dict["freq_count_particles"][number_of_particles_event] = (
            stat_dict["freq_count_particles"][number_of_particles_event] + 1
        )
        stat_dict["freq_count_energy"] = stat_dict["freq_count_energy"] + torch.histc(
            energy_particles, bins=500, min=0.001, max=50
        )
        # angular statistics
        # if pf:
        #     cluster_space_coords = g.ndata["pos_hits_xyz"]
        #     object_index = g.ndata["particle_number"].view(-1)
        #     x_alpha_sum = scatter_mean(cluster_space_coords, object_index.long(), dim=0)
        #     nVs = x_alpha_sum[1:] / torch.norm(
        #         x_alpha_sum[1:], p=2, dim=-1, keepdim=True
        #     )
        #     # compute cosine of the angles using dot product
        #     cos_ij = torch.einsum("ij,pj->ip", nVs, nVs)
        #     min_cos_per_particle = torch.min(torch.abs(cos_ij), dim=0)[0]
        #     stat_dict["freq_count_angle"] = stat_dict["freq_count_angle"] + torch.histc(
        #         min_cos_per_particle, bins=10, min=0, max=1.1
        #     )
        # else:
        eta = y[:, 0]
        phi = y[:, 1]
        len_y = len(eta)
        # pairwise dR_ij = sqrt((eta_i - eta_j)^2 + (phi_i - phi_j)^2)
        # (phi differences are not wrapped to [-pi, pi] here)
        dr_matrix = torch.sqrt(
            torch.square(
                torch.tile(eta.view(1, -1), (len_y, 1))
                - torch.tile(eta.view(-1, 1), (1, len_y))
            )
            + torch.square(
                torch.tile(phi.view(1, -1), (len_y, 1))
                - torch.tile(phi.view(-1, 1), (1, len_y))
            )
        )
        device = y.device
        # mask the diagonal (self-distance 0) with a large value before the row-wise minimum
        dr_matrix = dr_matrix + torch.eye(len_y, len_y).to(device) * 10
        min_dr_per_particle = torch.min(dr_matrix, dim=1)[0]
        stat_dict["freq_count_angle"] = stat_dict["freq_count_angle"] + torch.histc(
            min_dr_per_particle, bins=40, min=0, max=4
        )
    return stat_dict

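# Illustrative sketch (not part of the original pipeline): the nearest-neighbour
# distance computation used in obtain_statistics_graph, shown on a small toy
# (eta, phi) tensor. The function name and the toy values are placeholders.
def _example_nearest_dr():
    eta = torch.tensor([0.0, 0.1, 2.0])
    phi = torch.tensor([0.0, 0.0, 1.0])
    n = len(eta)
    # pairwise dR_ij = sqrt((eta_i - eta_j)^2 + (phi_i - phi_j)^2), via broadcasting
    dr = torch.sqrt(
        torch.square(eta.view(1, -1) - eta.view(-1, 1))
        + torch.square(phi.view(1, -1) - phi.view(-1, 1))
    )
    # mask the diagonal (self-distance 0) with a large value before the row-wise minimum
    dr = dr + torch.eye(n) * 10
    return torch.min(dr, dim=1)[0]  # distance to the closest other particle
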
def create_stats_dict(device):
    """Allocate the bin grids and zeroed counters used by obtain_statistics_graph."""
    bins_number_of_particles_event = torch.arange(0, 200, 1).to(device)
    freq_count_particles = torch.zeros_like(bins_number_of_particles_event)
    # Linear energy bins: torch.histc only takes min, max and the number of bins,
    # and histogramming with explicit (e.g. logarithmic) bin edges is not supported on CUDA.
    energy_event = torch.arange(0.001, 50, 0.1).to(
        device
    )  # torch.exp(torch.arange(np.log(0.001), np.log(50), 0.1))
    freq_count_energy = torch.zeros(len(energy_event)).to(device)
    angle_distribution = torch.arange(0, 4 + 0.1, 0.1).to(device)
    freq_count_angle = torch.zeros(len(angle_distribution) - 1).to(device)
    stat_dict = {}
    stat_dict["bins_number_of_particles_event"] = bins_number_of_particles_event
    stat_dict["freq_count_particles"] = freq_count_particles
    stat_dict["energy_event"] = energy_event
    stat_dict["freq_count_energy"] = freq_count_energy
    stat_dict["angle_distribution"] = angle_distribution
    stat_dict["freq_count_angle"] = freq_count_angle
    return stat_dict

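# Illustrative check (hypothetical helper, not from the original file): the
# counters allocated in create_stats_dict line up with the torch.histc calls in
# obtain_statistics_graph, since torch.histc only understands a uniform grid
# defined by (bins, min, max).
def _example_histc_alignment():
    energies = torch.tensor([0.5, 1.2, 1.25, 30.0])
    counts = torch.histc(energies, bins=500, min=0.001, max=50)
    # 500 uniform bins, matching stat_dict["freq_count_energy"]
    assert counts.shape[0] == 500
    return counts
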
def save_stat_dict(stat_dict, path):
    path = os.path.join(path, "stat_dict.pt")
    torch.save(stat_dict, path)

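# Illustrative usage sketch (hypothetical): accumulate statistics over a data
# loader and save them. The loader, its (g_all, y_all) batch layout and the
# device/path arguments are assumptions, not part of the original file; the last
# column of y_all must hold the event index, as obtain_statistics_graph expects.
def _example_accumulate_stats(loader, device, path_store):
    stat_dict = create_stats_dict(device)
    for g_all, y_all in loader:
        stat_dict = obtain_statistics_graph(
            stat_dict, y_all.to(device), g_all.to(device), pf=True
        )
    save_stat_dict(stat_dict, path_store)
    return stat_dict
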
def stacked_hist_plot(lst, lst_pandora, path_store, title, title_no_latex):
    # lst is a list of arrays (one per energy bin); plot them as stacked panels
    # that share the same x-axis definition.
    fig, ax = plt.subplots(len(lst), 1, figsize=(6, 13))
    if len(lst) == 1:
        ax = [ax]
    # energy-bin edges used for the panel titles (expects at most four arrays in lst)
    binsE = [0, 5, 15, 35, 51]
    for i in range(len(lst)):
        if i == 0:
            bins = np.linspace(-0.03, 0.03, 200)
        else:
            bins = np.linspace(-0.005, 0.005, 200)
        ax[i].hist(lst[i], bins, histtype="step", label="ML", color="red", density=True)
        if i < len(lst_pandora):
            ax[i].hist(
                lst_pandora[i], bins, histtype="step", label="Pandora", color="blue", density=True
            )
        ax[i].grid()
        ax[i].set_yscale("log")
        ax[i].set_xlabel(r"$\Delta \phi$")
        ax[i].set_title(title + " [{},{}] GeV".format(binsE[i], binsE[i + 1]))
        ax[i].title.set_size(15)
        # legend with an enlarged font
        ax[i].legend(prop={"size": 14})
    # fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(os.path.join(path_store, title_no_latex + "_angle_distributions.pdf"))

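# Illustrative usage sketch (synthetic data, hypothetical names): stacked_hist_plot
# expects one residual array per energy range defined by binsE = [0, 5, 15, 35, 51],
# i.e. up to four arrays, plus a matching list of Pandora residuals.
def _example_stacked_hist(path_store="."):
    rng = np.random.default_rng(0)
    lst = [rng.normal(0.0, 0.002, 1000) for _ in range(4)]
    lst_pandora = [rng.normal(0.0, 0.003, 1000) for _ in range(4)]
    stacked_hist_plot(
        lst,
        lst_pandora,
        path_store,
        title=r"$\Delta \phi$ response",
        title_no_latex="delta_phi_response",
    )
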
def plot_distributions(stat_dict, PATH_store, pf=False):
    # plot the aggregated per-event statistics: particle energy, nearest-neighbour
    # angle (dR) and particle multiplicity
    print(PATH_store)
    fig, axs = plt.subplots(1, 3, figsize=(9, 3))
    # normalised energy distribution
    b = stat_dict["freq_count_energy"] / torch.sum(stat_dict["freq_count_energy"])
    a = stat_dict["energy_event"]
    a = a.detach().cpu()
    b = b.detach().cpu()
    axs[0].bar(a, b, width=0.2)
    axs[0].set_title("Energy distribution")
    # normalised nearest-neighbour dR distribution (bin left edges on the x-axis)
    b = stat_dict["freq_count_angle"] / torch.sum(stat_dict["freq_count_angle"])
    a = stat_dict["angle_distribution"][:-1]
    a = a.detach().cpu()
    b = b.detach().cpu()
    axs[1].bar(a, b, width=0.02)
    axs[1].set_xlim([0, 1])
    axs[1].set_title("Angle distribution")
    # axs[1].set_ylim([0, 1])
    # normalised particle-multiplicity distribution
    b = stat_dict["freq_count_particles"] / torch.sum(stat_dict["freq_count_particles"])
    a = stat_dict["bins_number_of_particles_event"]
    a = a.detach().cpu()
    b = b.detach().cpu()
    axs[2].bar(a, b)
    axs[2].set_title("Number of particles")
    # fig.suptitle("Stats event")
    fig.savefig(
        PATH_store + "/stats.png",
        bbox_inches="tight",
    )

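# Illustrative usage sketch (hypothetical): reload a stat dict written by
# save_stat_dict and re-plot the aggregated distributions. The path argument is
# a placeholder.
def _example_plot_saved_stats(path_store):
    stat_dict = torch.load(os.path.join(path_store, "stat_dict.pt"))
    plot_distributions(stat_dict, path_store)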