Spaces:
Sleeping
Sleeping
import os | |
from tqdm import tqdm | |
import argparse | |
import pickle | |
from src.plotting.eval_matrix import matrix_plot, scatter_plot | |
from src.utils.paths import get_path | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# This script plots the F1 score, precision and recall as a function of the pt cutoff.
# Evaluation runs to compare.  Each key is the value placed on the x-axis
# ("pt cutoff") of the curves; each value is a run directory that is later
# resolved with get_path(..., "results").
#
# Several alternative mappings used to be kept in this file and were switched
# by letting one assignment overwrite another (or by wrapping one in a string
# literal).  They are now named presets, so nothing is silently overwritten.


def _pt_scan(dataset, run_name, cutoffs=(30, 40, 50, 60, 70, 80, 90), last_key=100):
    """Return ``{key: run directory}`` for one pt-cutoff scan.

    Every key in *cutoffs* maps to
    ``<dataset>/batch_eval_2k/<run_name>_pt_<key>.0``; *last_key* maps to the
    run directory without any ``_pt_`` suffix.
    """
    base = f"{dataset}/batch_eval_2k/{run_name}"
    scan = {cutoff: f"{base}_pt_{float(cutoff)}" for cutoff in cutoffs}
    # NOTE(review): presumably the un-suffixed run was evaluated with a cutoff
    # of 100 -- confirm against how the evaluation was launched.
    scan[last_key] = base
    return scan


_INPUT_PRESETS = {
    "full_dataset": _pt_scan(
        "Delphes_020425_test_PU_PFfix_part0",
        "DelphesPFfix_FullDataset",
    ),
    "train_ds_study": _pt_scan(
        "Delphes_020425_test_PU_PFfix_part0",
        "DelphesPFfix_FullDataset_TrainDSstudy",
    ),
    "train_ds_study_qcd": _pt_scan(
        "QCD_test_part0",
        "DelphesPFfix_FullDataset_TrainDSstudy_QCD",
    ),
}

# Name of the preset that is actually plotted.
_ACTIVE_INPUTS = "train_ds_study"

if _ACTIVE_INPUTS == "train_ds_study_qcd":
    print("PLOTTING QCD")

inputs = _INPUT_PRESETS[_ACTIVE_INPUTS]
def _load_precision_recall(run_name):
    """Unpickle ``precision_recall.pkl`` from the results directory of *run_name*."""
    pkl_path = os.path.join(get_path(run_name, "results"), "precision_recall.pkl")
    # The context manager closes the handle as soon as loading is done; a bare
    # open() inside the comprehension never closed it explicitly.
    # NOTE: pickle.load can execute arbitrary code from the file -- only point
    # this at results written by trusted evaluation runs.
    with open(pkl_path, "rb") as handle:
        return pickle.load(handle)


# pt-cutoff key -> unpickled precision/recall results of the matching run.
files = {key: _load_precision_recall(value) for key, value in inputs.items()}
# For every run, collect the names of the result entries (top-level keys) it holds.
titles = {}
for cutoff, results in files.items():
    titles[cutoff] = set(results.keys())

# Entry names that are present in every run, in sorted order.
shared_titles = set.intersection(*titles.values())
intersections = sorted(shared_titles)
# Which result entries to draw, and how.  Each preset maps the entry name used
# as key in the loaded results to [legend label, line colour].
#
# The presets used to be three consecutive assignments to the same name, of
# which only the last one took effect; they are now selectable by name.
_TITLE_PRESETS = {
    # To plot different variations of the model
    "model_variants": {
        "AK, R=0.8": ["AK8", "gray"],
        "GT_R=0.8 LGATr_GP_IRC_S_50k_s12900, sc. (aug)": ["LGATr_GP_IRC_S", "red"],
        "GT_R=0.8 LGATr_GP_50k_s25020, sc. (aug)": ["LGATr_GP", "purple"],
        "GT_R=0.8 base_LGATr_s50000, sc.": ["LGATr", "orange"],
    },
    "base_lgatr_s50000": {
        "AK, R=0.8": ["AK8", "gray"],
        "GT_R=0.8 base_LGATr_s50000, sc.": ["LGATr_900_03", "orange"],
        "GT_R=0.8 LGATr_QCD_s50000, sc.": ["LGATr_QCD", "purple"],
        "GT_R=0.8 LGATr_700_07_s50000, sc.": ["LGATr_700_07", "red"],
        "GT_R=0.8 LGATr_700_07+900_03_s50000, sc.": ["LGATr_700_07+900_03", "blue"],
        "GT_R=0.8 LGATr_700_07+900_03+QCD_s50000, sc.": ["LGATr_700_07+900_03+QCD", "green"],
    },
    "gp_irc_s_aug": {
        "AK, R=0.8": ["AK8", "gray"],
        "GT_R=0.8 LGATr_GP_IRC_S_50k_s12900, sc. (aug)": ["LGATr_900_03", "orange"],
        "GT_R=0.8 LGATr_GP_IRC_S_QCD_s24000, sc. (aug)": ["LGATr_QCD", "purple"],
        "GT_R=0.8 LGATr_GP_IRC_S_700_07_s24000, sc. (aug)": ["LGATr_700_07", "red"],
        "GT_R=0.8 LGATr_GP_IRC_S_700_07+900_03_s24000, sc. (aug)": ["LGATr_700_07+900_03", "blue"],
        "GT_R=0.8 LGATr_GP_IRC_S_700_07+900_03+QCD_s24000, sc. (aug)": ["LGATr_700_07+900_03+QCD", "green"],
    },
}

# NOTE(review): this message is printed unconditionally, whichever runs are
# selected -- confirm that it is still wanted.
print("QCD")
# Earlier colour configuration, kept for reference:
# colors= [{"base_LGATr": "orange", "LGATr_700_07": "red", "LGATr_QCD": "purple", "LGATr_700_07+900_03": "blue", "LGATr_700_07+900_03+QCD": "green", "AK8": "gray"}, {"base_LGATr": "LGATr_900_03"}],

# Name of the preset that is actually plotted.
_ACTIVE_TITLES = "gp_irc_s_aug"
titles_to_plot = _TITLE_PRESETS[_ACTIVE_TITLES]

# Only the configured entries are drawn, in sorted order of their names.
intersections = sorted(titles_to_plot)
# Results directory of every run, in the order the runs are listed in `inputs`.
output_dirs = [get_path(run_name, "results") for run_name in inputs.values()]
# The run stored under key 100 and the first plotted entry serve as the
# reference for which signal points exist.
result = files[100][intersections[0]]
mediator_masses = sorted(result)
# result is nested as result[mMed][mDark][rInv]; collect every rInv that occurs anywhere.
r_invs = sorted({rinv for mMed in result for mDark in result[mMed] for rinv in result[mMed][mDark]})

sz = 4  # edge length of one subplot, in inches
n_rows, n_cols = len(mediator_masses), len(r_invs)
# squeeze=False makes plt.subplots always return a 2-D array of axes.  With
# the default squeezing a grid with a single row or a single column comes back
# 1-D, and indexing it as ax[i, j] fails; only the 1x1 case was handled before.
grid_kwargs = dict(figsize=(sz * n_cols, sz * n_rows), squeeze=False)
fig, ax = plt.subplots(n_rows, n_cols, **grid_kwargs)    # F1 score
figp, axp = plt.subplots(n_rows, n_cols, **grid_kwargs)  # precision
figr, axr = plt.subplots(n_rows, n_cols, **grid_kwargs)  # recall
def _f1(precision, recall):
    """Return the harmonic mean of *precision* and *recall*.

    Returns 0.0 when both are zero instead of dividing by zero.
    """
    denominator = precision + recall
    if denominator == 0:
        return 0.0
    return 2 * precision * recall / denominator


# The results are indexed as [mMed][mDark][rInv]; only this mDark value is plotted.
M_DARK = 20
# x-axis values: the keys of `inputs`, identical for every curve.
pts = sorted(inputs)

# One subplot per (mediator mass, r_inv) pair, one curve per plotted entry.
for i, mMed in enumerate(mediator_masses):
    for j, rInv in enumerate(r_invs):
        for title in intersections:
            label, color = titles_to_plot[title]
            precisions = []
            recalls = []
            f1_scores = []
            for pt in pts:
                # NOTE(review): r_invs is collected over all masses, so this
                # lookup raises KeyError if a combination is missing.
                precision, recall = files[pt][title][mMed][M_DARK][rInv]
                precisions.append(precision)
                recalls.append(recall)
                f1_scores.append(_f1(precision, recall))
            for axes, values in ((ax, f1_scores), (axp, precisions), (axr, recalls)):
                axes[i, j].plot(pts, values, ".-", label=label, color=color)

        # Decorate each subplot once, after all of its curves exist, so the
        # legend sees every label and the grid is switched on explicitly
        # rather than toggled.
        panel_title = f"$m_{{Z'}} = {mMed}$ GeV, $r_{{inv.}}$ = {rInv}"
        for axes, y_label in ((ax, "$F_1$ score"), (axp, "Precision"), (axr, "Recall")):
            panel = axes[i, j]
            panel.set_title(panel_title)
            panel.set_xlabel("$p_T^{cutoff}$")
            panel.set_ylabel(y_label)
            panel.legend()
            panel.grid(True)
# Lay out every figure once up front.  The layout does not depend on the
# output directory, so repeating it per directory (and twice for the F1
# figure) was redundant.
for figure in (fig, figp, figr):
    figure.tight_layout()

# Write the same three PDFs into the results directory of every run.
for out_dir in output_dirs:
    fname = os.path.join(out_dir, "pt_cutoff_vs_f1_score.pdf")
    fig.savefig(fname)
    print("saved to", fname)
    figp.savefig(os.path.join(out_dir, "pt_cutoff_vs_precision.pdf"))
    figr.savefig(os.path.join(out_dir, "pt_cutoff_vs_recall.pdf"))