gregorkrzmanc's picture
.
e75a247
raw
history blame
4.88 kB
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
matplotlib.rc("font", size=35)
import matplotlib.pyplot as plt
import torch
from torch_scatter import scatter_sum, scatter_mean
def calculate_event_energy_resolution(df, pandora=False, full_vector=False):
    """Return the per-event ratio of predicted to true invariant mass.

    Rows of ``df`` are grouped by ``number_batch``. For every group the
    energies and the 3-vectors are summed, the norm of the summed vector is
    used as the momentum magnitude, and the mass is formed as
    ``sqrt(|E**2 - |p|**2|)`` for both the true and the predicted
    quantities.

    Args:
        df: Table providing the columns ``true_showers_E``, ``true_pos`` and
            ``number_batch``, plus ``pandora_calibrated_pfo`` and
            ``pandora_calibrated_pos`` when ``pandora`` is true, or
            ``calibrated_E`` and ``pred_pos_matched`` otherwise. The table is
            not modified.
        pandora: Use the Pandora columns instead of the model columns.
        full_vector: Unused; kept so existing callers keep working.

    Returns:
        1-D tensor of ``mass_pred / mass_true`` indexed by ``number_batch``.
        Entries are ``inf``/``nan`` where the true mass evaluates to zero.
    """
    if pandora:
        energy_column = "pandora_calibrated_pfo"
        vector_column = "pandora_calibrated_pos"
    else:
        energy_column = "calibrated_E"
        vector_column = "pred_pos_matched"

    # Copy before zeroing the NaNs: ``.values`` can be a view of the frame's
    # own storage, so writing into it would silently change the caller's data
    # (and fails outright when pandas hands back a read-only array).
    true_energy_values = df.true_showers_E.values.copy()
    true_missing = np.isnan(true_energy_values)
    true_energy_values[true_missing] = 0

    pred_energy_values = df[energy_column].values.copy()
    pred_missing = np.isnan(pred_energy_values)
    pred_energy_values[pred_missing] = 0

    # float32 reproduces the original ``torch.Tensor(...)`` construction.
    # NOTE(review): E**2 - |p|**2 cancels heavily when E is close to |p|;
    # consider keeping the true energy in float64 - confirm with the authors.
    true_energy = torch.as_tensor(true_energy_values, dtype=torch.float32)
    pred_energy = torch.as_tensor(pred_energy_values)

    # Rows whose energy is missing must not contribute to the summed vector.
    # assumes every entry of the vector columns is a length-3 sequence
    # - TODO confirm
    true_vectors = np.array(df.true_pos.values.tolist())
    true_vectors[true_missing] = 0
    pred_vectors = np.array(df[vector_column].values.tolist())
    pred_vectors[pred_missing] = 0

    event_index = torch.tensor(df.number_batch.values).long()

    true_momentum = torch.norm(
        scatter_sum(torch.as_tensor(true_vectors), event_index, dim=0), dim=1
    )
    pred_momentum = torch.norm(
        scatter_sum(torch.as_tensor(pred_vectors), event_index, dim=0), dim=1
    )
    true_energy_sum = scatter_sum(true_energy, event_index)
    pred_energy_sum = scatter_sum(pred_energy, event_index)

    # abs() guards the square root against slightly negative E**2 - |p|**2.
    mass_true = torch.sqrt(torch.abs(true_energy_sum**2 - true_momentum**2))
    mass_pred = torch.sqrt(torch.abs(pred_energy_sum**2 - pred_momentum**2))
    return mass_pred / mass_true
def get_response_for_event_energy(matched_pandora, matched_):
    """Collect the per-event mass ratios and decay types in one dictionary.

    Args:
        matched_pandora: Table evaluated with the Pandora columns; also the
            table from which the decay type of each event is determined.
        matched_: Table evaluated with the model columns.

    Returns:
        Dict with the keys ``"mass_over_true_model"``,
        ``"mass_over_true_pandora"`` and ``"decay_type"``.
    """
    pandora_ratio = calculate_event_energy_resolution(
        matched_pandora, pandora=True, full_vector=True
    )
    event_decay_type = get_decay_type(matched_pandora)
    model_ratio = calculate_event_energy_resolution(
        matched_, pandora=False, full_vector=True
    )
    return {
        "mass_over_true_model": model_ratio,
        "mass_over_true_pandora": pandora_ratio,
        "decay_type": event_decay_type,
    }
def get_decay_type(sd_hgb1):
    """Return the decay-type code of every event as one 1-D tensor.

    Events are the integers from 0 through the largest ``number_batch``
    value in ``sd_hgb1``; each one is classified by
    ``determine_decay_type`` and the results are concatenated in that order.
    """
    last_event = int(np.max(sd_hgb1.number_batch.values))
    per_event_codes = [
        determine_decay_type(sd_hgb1, event_id)
        for event_id in range(last_event + 1)
    ]
    return torch.cat(per_event_codes)
def determine_decay_type(sd_hgb1, i):
    """Classify event ``i`` by the rows it has in ``sd_hgb1``.

    Args:
        sd_hgb1: Table with the columns ``number_batch`` and ``pid``.
        i: Value of ``number_batch`` selecting the event.

    Returns:
        Float tensor of shape ``(1,)`` holding the decay-type code:
        ``0`` when the event has exactly two rows, ``1`` when it has exactly
        four rows that all satisfy ``|pid| == 22``, and ``2`` otherwise
        (including events without any row).
    """
    abs_pid = np.abs(sd_hgb1[sd_hgb1.number_batch == i].pid.values)
    row_count = len(abs_pid)

    # NOTE(review): two-row events are labelled 0 whatever their PIDs are.
    # The previous code computed an (unused) comparison against 211 here -
    # confirm whether a |pid| == 211 requirement was intended.
    if row_count == 2:
        decay_type = 0
    elif row_count == 4 and np.count_nonzero(abs_pid == 22.0) == 4:
        decay_type = 1
    else:
        decay_type = 2
    return torch.Tensor([decay_type])
def _save_mass_ratio_histogram(event_res_dic, event_mask, output_path):
    """Draw the ML and Pandora M_pred/M_true histograms and save them."""
    fig, ax = plt.subplots()
    try:
        ax.set_xlabel("M_pred/M_true")
        for key, label, color in (
            ("mass_over_true_model", "ML", "red"),
            ("mass_over_true_pandora", "Pandora", "blue"),
        ):
            ax.hist(
                event_res_dic[key][event_mask],
                bins=100,
                histtype="step",
                label=label,
                color=color,
                density=True,
            )
        ax.grid()
        ax.legend()
        ax.set_xlim([0, 10])
        fig.tight_layout()
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)


def plot_mass_resolution(event_res_dic, PATH_store):
    """Save the mass-ratio histograms for decay types 0 and 1.

    Args:
        event_res_dic: Mapping with the entries ``"mass_over_true_model"``,
            ``"mass_over_true_pandora"`` and ``"decay_type"``; the first two
            are indexed with a boolean mask built from the third.
        PATH_store: String prefix the file names are appended to. It is
            concatenated as-is, so a directory needs its trailing separator.

    Writes ``mass_resolution_charged.pdf`` (decay type 0) and
    ``mass_resolution_neutral.pdf`` (decay type 1).
    """
    decay_type = event_res_dic["decay_type"]
    for decay_code, file_name in (
        (0, "mass_resolution_charged.pdf"),
        (1, "mass_resolution_neutral.pdf"),
    ):
        _save_mass_ratio_histogram(
            event_res_dic, decay_type == decay_code, PATH_store + file_name
        )
def mass_Ks(matched_pandora, matched_, PATH_store):
    """Compute the per-event mass ratios and save the resolution plots.

    Args:
        matched_pandora: Table evaluated with the Pandora columns.
        matched_: Table evaluated with the model columns.
        PATH_store: String prefix for the output PDF files.
    """
    plot_mass_resolution(
        get_response_for_event_energy(matched_pandora, matched_), PATH_store
    )