# jetclustering/scripts/plot_eval_count_matched_quarks.py
import torch
from itertools import chain, combinations
import os
from tqdm import tqdm
import argparse
import pickle
from src.plotting.eval_matrix import matrix_plot, scatter_plot, multiple_matrix_plot, ax_tiny_histogram
from src.utils.paths import get_path
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict
#### Plotting functions
from matplotlib_venn import venn3
from copy import copy
def plot_venn3_from_index_dict(ax, data_dict, set_labels=('Set 0', 'Set 1', 'Set 2'), set_colors=("orange", "purple", "gray"), remove_max=True):
"""
Generate a 3-set Venn diagram from a dictionary where keys are strings of '0', '1', and '2'
indicating set membership, and values are counts.
Parameters:
- data_dict (dict): Dictionary with keys like '0', '01', '012', etc.
- set_labels (tuple): Labels for the three sets.
- remove_max: if true, it will remove
"""
# Mapping of set index combinations to venn3 region codes
index_to_region = {
'100': '100', # Only in Set 0
'010': '010', # Only in Set 1
'001': '001', # Only in Set 2
'110': '110', # In Set 0 and Set 1
'101': '101', # In Set 0 and Set 2
'011': '011', # In Set 1 and Set 2
'111': '111', # In all three
}
# Initialize region counts
venn_counts = {region: 0 for region in index_to_region.values()}
max_value = 0
for key in data_dict:
if data_dict[key] > max_value and key != "":
max_value = data_dict[key]
print("Max val", max_value)
data_dict = copy(data_dict)
new_data_dict = {}
for key in data_dict:
if remove_max and data_dict[key] == max_value:
# drop the largest region here; its count is written back below as the text of the '111' label
continue
else:
new_data_dict[key] = data_dict[key]
data_dict = new_data_dict
print("data dict", data_dict)
# Convert data_dict keys to binary keys for region mapping
for k, v in data_dict.items():
binary_key = ''.join(['1' if str(i) in k else '0' for i in range(3)])
if binary_key in index_to_region:
venn_counts[index_to_region[binary_key]] += v
# Plotting
#plt.figure(figsize=(8, 8))
del venn_counts['111']
venn = venn3(subsets=venn_counts, set_labels=set_labels, set_colors=set_colors, alpha=0.5, ax=ax)
# write the removed (largest) count back as the text of the central region, if present
label_111 = venn.get_label_by_id("111")
if label_111 is not None: label_111.set_text(max_value)
#plt.title("3-Set Venn Diagram from Index Dictionary")
#plt.show()
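# Illustrative (hypothetical) usage, assuming membership counts have already been
# accumulated per index string ('' = missed by all three sets):
#   counts = {"": 120, "0": 30, "1": 25, "2": 40, "01": 10, "02": 5, "12": 7, "012": 300}
#   fig, ax = plt.subplots()
#   plot_venn3_from_index_dict(ax, counts, set_labels=("model A", "AK8", "model B"))
# With remove_max=True (the default) the dominant '012' count is only written as the text
# of the central region instead of being drawn to scale.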
### Change this to make custom plots highlighting differences between different models (the histograms of pt_pred/pt_true, eta_pred-eta_true, and phi_pred-phi_true)
histograms_dict = {
"": [{"base_LGATr": 50000, "base_Tr": 50000 , "base_GATr": 50000, "AK8": 50000}, {"base_LGATr": "orange", "base_Tr": "blue", "base_GATr": "green", "AK8": "gray"}],
"LGATr_comparison": [{"base_LGATr": 50000, "LGATr_GP_IRC_S_50k": 9960, "LGATr_GP_50k": 9960, "AK8": 50000, "LGATr_GP_IRC_SN_50k": 24000}, {"base_LGATr": "orange", "LGATr_GP_IRC_S_50k": "red", "LGATr_GP_50k": "purple", "LGATr_GP_IRC_SN_50k": "pink", "AK8": "gray"}],
"LGATr_comparsion_DifferentTrainingDS": [{"base_LGATr": 50000, "LGATr_700_07": 50000, "LGATr_QCD": 50000, "LGATr_700_07+900_03": 50000, "LGATr_700_07+900_03+QCD": 50000, "AK8": 50000}, {"base_LGATr": "orange", "LGATr_700_07": "red", "LGATr_QCD": "purple", "LGATr_700_07+900_03": "blue", "LGATr_700_07+900_03+QCD": "green", "AK8": "gray"}]
}
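# histograms_dict: plot-group name -> [ {model: checkpoint step to plot}, {model: line color} ].
# Each group produces its own set of pt/eta/phi response histograms further below; the group
# name is used as the suffix of the saved histogram_<quantity>_<group>.pdf files.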
# Models and their plotting colors for the precision/recall/F1 summary plots below; the second dict in each list renames models for display.
results_dict = {
"LGATr_comparison_DifferentTrainingDS":
[{"base_LGATr": "orange", "LGATr_700_07": "red", "LGATr_QCD": "purple", "LGATr_700_07+900_03": "blue",
"LGATr_700_07+900_03+QCD": "green", "AK8": "gray"}, {"base_LGATr": "LGATr_900_03"}], # 2nd dict in list is rename dict
"LGATr_comparison": [{"base_LGATr": "orange", "LGATr_GP_IRC_S_50k": "red", "LGATr_GP_50k": "purple", "LGATr_GP_IRC_SN_50k": "pink", "AK8": "gray"},
{"base_LGATr": "LGATr", "LGATr_GP_IRC_S_50k": "LGATr_GP_IRC_S", "LGATr_GP_50k": "LGATr_GP", "LGATr_GP_IRC_SN_50k": "LGATr_GP_IRC_SN"}], # 2nd dict in list is rename dict
"LGATr_comparison_QCDtrain": [{"LGATr_QCD": "orange", "LGATr_GP_IRC_S_QCD": "red", "LGATr_GP_QCD": "purple", "LGATr_GP_IRC_SN_QCD": "pink", "AK8": "gray"},
{"LGATr_QCD": "LGATr", "LGATr_GP_IRC_S_QCD": "LGATr_GP_IRC_S", "LGATr_GP_QCD": "LGATr_GP", "LGATr_GP_IRC_SN_QCD": "LGATr_GP_IRC_SN"}], # 2nd dict in list is rename dict
"LGATr_comparison_GP_training": [
{"LGATr_GP_QCD": "purple", "LGATr_GP_700_07": "red", "LGATr_GP_700_07+900_03": "blue", "LGATr_GP_700_07+900_03+QCD": "green", "LGATr_GP_50k": "orange", "AK8": "gray"},
{"LGATr_GP_QCD": "QCD", "LGATr_GP_700_07": "700_07", "LGATr_GP_700_07+900_03": "700_07+900_03" , "LGATr_GP_50k": "900_03", "LGATr_GP_700_07+900_03+QCD": "700_07+900_03+QCD"} # 2nd dict in list is rename dict
],
"LGATr_comparison_GP_IRC_S_training": [
{"LGATr_GP_IRC_S_QCD": "purple", "LGATr_GP_IRC_S_700_07": "red", "LGATr_GP_IRC_S_700_07+900_03": "blue", "LGATr_GP_IRC_S_700_07+900_03+QCD": "green", "LGATr_GP_IRC_S_50k": "orange", "AK8": "gray"},
{"LGATr_GP_IRC_S_QCD": "QCD", "LGATr_GP_IRC_S_700_07": "700_07", "LGATr_GP_IRC_S_700_07+900_03": "700_07+900_03", "LGATr_GP_IRC_S_50k": "900_03", "LGATr_GP_IRC_S_700_07+900_03+QCD": "700_07+900_03+QCD"} # 2nd dict in list is rename dict
],
"LGATr_comparison_GP_IRC_SN_training": [
{"LGATr_GP_IRC_SN_QCD": "purple", "LGATr_GP_IRC_SN_700_07": "red", "LGATr_GP_IRC_SN_700_07+900_03": "blue", "LGATr_GP_IRC_SN_700_07+900_03+QCD": "green", "LGATr_GP_IRC_SN_50k": "orange", "AK8": "gray"},
{"LGATr_GP_IRC_SN_QCD": "QCD", "LGATr_GP_IRC_SN_700_07": "700_07", "LGATr_GP_IRC_SN_700_07+900_03": "700_07+900_03", "LGATr_GP_IRC_SN_50k": "900_03", "LGATr_GP_IRC_SN_700_07+900_03+QCD": "700_07+900_03+QCD"} # 2nd dict in list is rename dict
]
}
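# results_dict: comparison name -> [ {model: color}, {model: display name} ].
# These groups drive the per-level precision/recall/F1 panels (ax_tiny_histogram) and the
# grid_stack_F1_<level>_<comparison>.pdf matrix plots further below.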
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, required=False, default="scouting_PFNano_signals2/SVJ_hadronic_std/batch_eval/objectness_score")
parser.add_argument("--threshold-obj-score", "-os-threshold", type=float, default=-1)
thresholds = np.linspace(0.1, 1, 20)
# also add 100 points between 0 and 0.1 at the beginning
thresholds = np.concatenate([np.linspace(0, 0.1, 100), thresholds])
args = parser.parse_args()
path = get_path(args.input, "results")
models = sorted([x for x in os.listdir(path) if not os.path.isfile(os.path.join(path, x))])
models = [x for x in models if "AKX" not in x]
figures_all = {}  # run-config title -> precision/recall results used for the F1 matrix plots
figures_all_sorted = {}  # "R=<GT radius> <training dataset>" -> level index -> precision/recall results
print("Models:", models)
radius = {
"LGATr_R10": 1.0,
"LGATr_R09": 0.9,
"LGATr_rinv_03_m_900": 0.8,
"LGATr_R06": 0.6,
"LGATr_R07": 0.7,
"LGATr_R11": 1.1,
"LGATr_R12": 1.2,
"LGATr_R13": 1.3,
"LGATr_R14": 1.4,
"LGATr_R20": 2.0,
"LGATr_R25": 2.5
}
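# Note: this per-model radius map appears unused below; `radius` is reassigned to the plain
# list [0.8] before the AK comparison plots.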
comments = {
"Eval_params_study_2025_02_17_13_30_50": ", tr. on 07_700",
"Eval_objectness_score_2025_02_12_15_34_33": ", tr. on 03_900, GT=all",
"Eval_objectness_score_2025_02_18_08_48_13": ", tr. on 03_900, GT=closest",
"Eval_objectness_score_2025_02_14_11_10_14": ", tr. on 03_900, GT=closest",
"Eval_objectness_score_2025_02_21_14_51_07": ", tr. on 07_700",
"Eval_objectness_score_2025_02_10_14_59_49": ", tr. on 03_900, GT=all",
"Eval_objectness_score_2025_02_23_19_26_25": ", tr. on all, GT=closest",
"Eval_objectness_score_2025_02_23_21_04_33": ", tr. on 03_900, GT=closest"
}
out_file_PR = os.path.join(get_path(args.input, "results"), "precision_recall.pdf")
out_file_PRf1 = os.path.join(get_path(args.input, "results"), "f1_score_sorted.pdf")
out_file_PG = os.path.join(get_path(args.input, "results"), "PLoverGL.pdf")
if args.threshold_obj_score != -1:
out_file_PR_OS = os.path.join(get_path(args.input, "results"), f"precision_recall_with_obj_score.pdf")
out_file_avg_number_matched_quarks = os.path.join(get_path(args.input, "results"), "avg_number_matched_quarks.pdf")
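# result_PR_thresholds[mMed][mDark][rInv][i] appears to hold a triple of counts at
# thresholds[i]: (N matched dark quarks, N predicted jets, N dark quarks), so that
# precision = [0]/[1] and recall = [0]/[2] in get_plots_for_params below.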
def get_plots_for_params(mMed, mDark, rInv, result_PR_thresholds):
precisions = []
recalls = []
f1_scores = []
for i in range(len(thresholds)):
if result_PR_thresholds[mMed][mDark][rInv][i][1] == 0:
precisions.append(0)
else:
precisions.append(
result_PR_thresholds[mMed][mDark][rInv][i][0] / result_PR_thresholds[mMed][mDark][rInv][i][1])
if result_PR_thresholds[mMed][mDark][rInv][i][2] == 0:
recalls.append(0)
else:
recalls.append(
result_PR_thresholds[mMed][mDark][rInv][i][0] / result_PR_thresholds[mMed][mDark][rInv][i][2])
for i in range(len(thresholds)):
if precisions[i] + recalls[i] == 0:
f1_scores.append(0)
else:
f1_scores.append(2 * precisions[i] * recalls[i] / (precisions[i] + recalls[i]))
return precisions, recalls, f1_scores
sz = 5
nplots = 9
# Now make 3 plots, one for mMed=700,r_inv=0.7; one for mMed=700,r_inv=0.5; one for mMed=700,r_inv=0.3
###fig, ax = plt.subplots(3, 3, figsize=(3 * sz, 3 * sz))
'''fig, ax = plt.subplots(3, nplots, figsize=(nplots*sz, 3*sz))
for mi, mass in enumerate([700, 900, 1500]):
start_idx = mi*3
for i0, rinv in enumerate([0.3, 0.5, 0.7]):
i = start_idx + i0
# 0 is precision, 1 is recall, 2 is f1 score
ax[0, i].set_title(f"r_inv={rinv}, m_med={mass} GeV")
ax[1, i].set_title(f"r_inv={rinv}, m_med={mass} GeV")
ax[2, i].set_title(f"r_inv={rinv}, m_med={mass} GeV")
ax[0, i].set_ylabel("Precision")
ax[1, i].set_ylabel("Recall")
ax[2, i].set_ylabel("F1 score")
ax[0, i].grid()
ax[1, i].grid()
ax[2, i].grid()
ylims = {} # key: j and i
default_ylims = [1, 0]
for j, model in enumerate(models):
result_PR_thresholds = os.path.join(path, model, "count_matched_quarks", "result_PR_thresholds.pkl")
if not os.path.exists(result_PR_thresholds):
continue
run_config = pickle.load(open(os.path.join(path, model, "run_config.pkl"), "rb"))
result_PR_thresholds = pickle.load(open(result_PR_thresholds, "rb"))
if mass not in result_PR_thresholds:
continue
if rinv not in result_PR_thresholds[mass]:
continue
precisions, recalls, f1_scores = get_plots_for_params(mass, 20, rinv, result_PR_thresholds)
if not run_config["gt_radius"] == 0.8:
continue
label = "R={} gl.f.={} {}".format(run_config["gt_radius"], run_config.get("global_features_obj_score", False), comments.get(run_config["run_name"], run_config["run_name"]))
scatter_plot(ax[0, i], thresholds, precisions, label=label)
scatter_plot(ax[1, i], thresholds, recalls, label=label)
scatter_plot(ax[2, i], thresholds, f1_scores, label=label)
#ylims[0] = [min(ylims[0][0], min(precisions)), max(ylims[0][1], max(precisions))]
#ylims[1] = [min(ylims[1][0], min(recalls)), max(ylims[1][1], max(recalls))]
#ylims[2] = [min(ylims[2][0], min(f1_scores)), max(ylims[2][1], max(f1_scores))]
filt = thresholds < 0.2
precisions = np.array(precisions)[filt]
recalls = np.array(recalls)[filt]
f1_scores = np.array(f1_scores)[filt]
if (i, 0) not in ylims:
ylims[(i, 0)] = default_ylims
upper_factor = 1.01
lower_factor = 0.99
ylims[(i, 0)] = [min(ylims[(i, 0)][0], min(precisions)*lower_factor), max(ylims[(i, 0)][1], max(precisions)*upper_factor)]
if (i, 1) not in ylims:
ylims[(i, 1)] = default_ylims
ylims[(i, 1)] = [min(ylims[(i, 1)][0], min(recalls)*lower_factor), max(ylims[(i, 1)][1], max(recalls)*upper_factor)]
if (i, 2) not in ylims:
ylims[(i, 2)] = default_ylims
ylims[(i, 2)] = [min(ylims[(i, 2)][0], min(f1_scores)*lower_factor), max(ylims[(i, 2)][1], max(f1_scores)*upper_factor)]
for j in range(3):
ax[j, i].set_ylim(ylims[(i, j)])
ax[j, i].legend()
ax[j, i].set_xlim([0, 0.2])
ax[j, i].set_xlim([0, 0.2])
ax[j, i].set_xlim([0, 0.2])
# now adjust the ylim so that the plots are more readable
fig.tight_layout()
fig.savefig(os.path.join(get_path(args.input, "results"), "precision_recall_thresholds.pdf"))
print("Saved to", os.path.join(get_path(args.input, "results"), "precision_recall_thresholds.pdf"))'''
import wandb
api = wandb.Api()
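# get_run_by_name: looks up a W&B run in fcc_ml/svj_clustering by display name, after
# stripping a recognised clustering suffix ("FT", "FT1", "10_5") from the name; returns
# (run, suffix), or (None, suffix) if the name does not match exactly one run.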
def get_run_by_name(name):
clust_suffix = ""
if name.endswith("FT"):
#remove FT from the end
name = name[:-2]
clust_suffix = "FT"
if name.endswith("FT1"):
#remove FT from the end # min-samples 1 min-cluster-size 2 epsilon 0.3
name = name[:-3]
clust_suffix = "FT1"
if name.endswith("10_5"):
name = name[:-4]
clust_suffix = "10_5"
runs = api.runs(
path="fcc_ml/svj_clustering",
filters={"display_name": {"$eq": name.strip()}}
)
if runs.length != 1:
return None, clust_suffix
return runs[0], clust_suffix
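# get_run_config: turns the W&B config of an evaluation run into a short plot title and an
# info dict with keys level / level_idx (PL=0, scouting=1, GL=2), GT_R, training_dataset
# (display name + "_s<ckpt step>" + clustering suffix), training_dataset_nostep and
# ckpt_step; returns (None, None) if the run could not be found.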
def get_run_config(run_name):
r, clust_suffix = get_run_by_name(run_name)
if r is None:
print("Getting info from run", run_name, "failed")
return None, None
config = r.config
result = {}
if config["parton_level"]:
prefix = "PL"
result["level"] = "PL"
result["level_idx"] = 0
elif config["gen_level"]:
prefix = "GL"
result["level"] = "GL"
result["level_idx"] = 2
else:
prefix = "sc."
result["level"] = "scouting"
result["level_idx"] = 1
if config["augment_soft_particles"]:
result["ghosts"] = True
#result["level"] += "+ghosts"
gt_r = config["gt_radius"]
if config.get("augment_soft_particles", False):
prefix += " (aug)" # ["LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_vanishing_momentum_QCap05_2025_03_28_17_12_25_820", "LGATr_training_NoPID_10_16_64_2.0_Aug_Finetune_vanishing_momentum_QCap05_2025_03_28_17_12_26_400"]
training_datasets = {
"LGATr_training_NoPID_10_16_64_0.8_AllData_2025_02_28_13_42_59": "all",
"LGATr_training_NoPID_10_16_64_0.8_2025_02_28_12_42_59": "900_03",
"LGATr_training_NoPID_10_16_64_2.0_2025_02_28_12_48_58": "900_03",
"LGATr_training_NoPID_10_16_64_0.8_700_07_2025_02_28_13_01_59": "700_07",
"LGATr_training_NoPIDGL_10_16_64_0.8_2025_03_17_20_05_04": "900_03_GenLevel",
"LGATr_training_NoPIDGL_10_16_64_2.0_2025_03_17_20_05_04": "900_03_GenLevel",
"Transformer_training_NoPID_10_16_64_2.0_2025_03_03_17_00_38": "900_03_T",
"Transformer_training_NoPID_10_16_64_0.8_2025_03_03_15_55_50": "900_03_T",
"LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_2025_03_27_12_46_12_740": "900_03+SoftAug",
"LGATr_training_NoPID_10_16_64_2.0_Aug_Finetune_vanishing_momentum_2025_03_28_10_43_36_81": "900_03+SoftAugVM",
"LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_vanishing_momentum_2025_03_28_10_43_37_44": "900_03+SoftAugVM",
"LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_vanishing_momentum_QCap05_2025_03_28_17_12_25_820": "900_03+qcap05",
"LGATr_training_NoPID_10_16_64_2.0_Aug_Finetune_vanishing_momentum_QCap05_2025_03_28_17_12_26_400": "900_03+qcap05",
"LGATr_training_NoPID_10_16_64_2.0_Aug_Finetune_vanishing_momentum_QCap05_1e-2_2025_03_29_14_58_38_650": "pt 1e-2",
"LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_vanishing_momentum_QCap05_1e-2_2025_03_29_14_58_36_446": "pt 1e-2",
"LGATr_pt_1e-2_500part_2025_04_01_16_49_08_406": "500_pt_1e-2_PLFT",
"LGATr_pt_1e-2_500part_2025_04_01_21_14_07_350": "500_pt_1e-2_PLFT",
"LGATr_pt_1e-2_500part_NoQMin_2025_04_03_23_15_17_745": "500_1e-2_scFT",
"LGATr_pt_1e-2_500part_NoQMin_2025_04_03_23_15_35_810": "500_1e-2_scFT",
"LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_2025_04_04_12_57_51_536": "10_1000_1e-2_scFT",
"LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_2025_04_04_12_57_47_788": "10_1000_1e-2_scFT",
"LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_CW0_2025_04_04_15_30_16_839": "10_1000_1e-2_CW0",
"LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_CW0_2025_04_04_15_30_20_113": "10_1000_1e-2_CW0",
"debug_IRC_loss_weighted100_plus_ghosts_2025_04_08_22_40_33_972": "IRC_short_debug",
"debug_IRC_loss_weighted100_plus_ghosts_2025_04_09_13_48_55_569": "IRC",
"debug_IRC_loss_weighted100_plus_ghosts_Qmin05_2025_04_09_14_45_51_381": "IRC_qmin05",
"LGATr_500part_NOQMin_2025_04_09_21_53_37_210": "500part_NOQMin_reprod",
"IRC_loss_Split_and_Noise_alternate_Aug_2025_04_14_11_10_21_788": "IRC_Aug_S+N",
"IRC_loss_Split_and_Noise_alternate_NoAug_2025_04_11_16_15_48_955": "IRC_S+N",
"LGATr_training_NoPID_Delphes_10_16_64_0.8_2025_04_17_18_07_38_405": "DelphesTrain",
"Delphes_IRC_aug_2025_04_19_11_16_17_130": "DelphesTrain+IRC",
"LGATr_500part_NOQMin_Delphes_2025_04_19_11_15_24_417": "DelphesTrain+ghosts",
"Delphes_IRC_aug_SplitOnly_2025_04_20_15_50_33_553": "DelphesTrain+IRC_SplitOnly",
"Delphes_IRC_NOAug_SplitOnly_2025_04_21_12_58_36_99": "Delphes_IRC_NoAug_SplitOnly",
"Delphes_IRC_NOAug_SplitAndNoise_2025_04_21_19_32_08_865": "Delphes_IRC_NoAug_S+N",
"CONT_Delphes_IRC_aug_SplitOnly_2025_04_21_12_53_27_730": "IRC_aug_SplitOnly_ContFrom14k",
"Transformer_training_NoPID_Delphes_PU_10_16_64_0.8_2025_05_03_18_37_01_188": "base_Tr_Old",
"LGATr_training_NoPID_Delphes_PU_PFfix_10_16_64_0.8_2025_05_03_18_35_53_134": "base_LGATr",
"GATr_training_NoPID_Delphes_PU_10_16_64_0.8_2025_05_03_18_35_48_163": "base_GATr_Old",
"Transformer_training_NoPID_Delphes_PU_CoordFix_10_16_64_0.8_2025_05_05_13_05_20_755": "base_Tr",
"GATr_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_05_16_24_13_579": "base_GATr_SD",
"GATr_training_NoPID_Delphes_PU_CoordFix_10_16_64_0.8_2025_05_05_13_06_27_898": "base_GATr",
"LGATr_Aug_2025_05_06_10_08_05_956": "LGATr_GP",
"Delphes_Aug_IRCSplit_CONT_2025_05_07_11_00_18_422": "LGATr_GP_IRC_S",
"Delphes_Aug_IRC_Split_and_Noise_2025_05_07_14_43_13_968": "LGATr_GP_IRC_SN",
"Transformer_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_05_16_24_19_936": "base_Tr_SD",
"LGATr_training_NoPID_Delphes_PU_PFfix_SmallDS_10_16_64_0.8_2025_05_05_16_24_16_127": "base_LGATr_SD",
"Delphes_Aug_IRCSplit_2025_05_06_10_09_00_567": "LGATr_GP_IRC_S",
"GATr_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_09_15_34_13_531": "base_GATr_SD",
"Transformer_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_09_15_56_50_216": "base_Tr_SD",
"LGATr_training_NoPID_Delphes_PU_PFfix_SmallDS_10_16_64_0.8_2025_05_09_15_56_50_875": "base_LGATr_SD",
"Delphes_Aug_IRCSplit_50k_from10k_2025_05_11_14_08_49_675": "LGATr_GP_IRC_S_50k",
"LGATr_Aug_50k_2025_05_09_15_25_32_34": "LGATr_GP_50k",
"Delphes_Aug_IRCSplit_50k_2025_05_09_15_22_38_956": "LGATr_GP_IRC_S_50k",
"LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_16_21_04_26_937": "LGATr_700_07+900_03+QCD",
"LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_16_21_04_26_991": "LGATr_700_07+900_03",
"LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_16_19_46_57_48": "LGATr_QCD",
"LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_16_19_44_46_795": "LGATr_700_07",
"Delphes_Aug_IRCSplit_50k_SN_from3kFT_2025_05_16_14_07_29_474": "LGATr_GP_IRC_SN_50k",
"GP_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_19_21_29_06_946": "LGATr_GP_QCD",
"GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_19_21_38_20_376": "LGATr_GP_700_07",
"GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_20_13_12_54_359": "LGATr_GP_700_07+900_03+QCD",
"GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_20_13_13_00_503": "LGATr_GP_700_07+900_03",
"GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_20_15_29_30_29": "LGATr_GP_IRC_S_700_07+900_03",
"GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_20_15_29_28_959": "LGATr_GP_IRC_S_700_07+900_03+QCD",
"GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_20_15_11_35_476": "LGATr_GP_IRC_S_700_07",
"GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_20_15_11_20_735": "LGATr_GP_IRC_S_QCD",
"GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_24_23_00_54_948": "LGATr_GP_IRC_SN_QCD",
"GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_24_23_00_56_910": "LGATr_GP_IRC_SN_700_07+900_03+QCD",
"GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_24_23_01_01_212": "LGATr_GP_IRC_SN_700_07+900_03",
"GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_24_23_01_07_703": "LGATr_GP_IRC_SN_700_07",
}
train_name = config["load_from_run"]
ckpt_step = config["ckpt_step"]
print("train name", train_name)
if train_name not in training_datasets:
print("!! unknown run", train_name)
training_dataset = training_datasets.get(train_name, train_name) + "_s" + str(ckpt_step) + clust_suffix
if "plptfilt01" in run_name.lower():
training_dataset += "_PLPtFiltMinPt01" # min pt 0.1
elif "noplfilter" in run_name.lower():
training_dataset += "_noPLFilter"
elif "noplptfilter" in run_name.lower():
training_dataset += "_noPLPtFilter" # actually there was a 0.5 pt cut in the ntuplizer, removed by plptfilt01
elif "nopletafilter" in run_name.lower():
training_dataset += "_noPLEtaFilter"
result["GT_R"] = gt_r
result["training_dataset"] = training_dataset
result["training_dataset_nostep"] = training_datasets.get(train_name, train_name) + clust_suffix
result["ckpt_step"] = ckpt_step
return f"GT_R={gt_r} {training_dataset}, {prefix}", result
def flatten_list(lst):# lst is like [[0,0],[1,1]...]
#return [item for sublist in lst for item in sublist]
return list(chain.from_iterable(lst))
sz = 5
ak_path = os.path.join(path, "AKX", "count_matched_quarks")
result_PR_AKX = pickle.load(open(os.path.join(ak_path, "result_PR_AKX.pkl"), "rb"))
result_jet_props_akx = pickle.load(open(os.path.join(ak_path, "result_jet_properties_AKX.pkl"), "rb"))
result_qj_akx = pickle.load(open(os.path.join(ak_path, "result_quark_to_jet.pkl"), "rb"))
result_dq_pt_akx = pickle.load(open(os.path.join(ak_path, "result_pt_dq.pkl"), "rb"))
result_dq_mc_pt_akx = pickle.load(open(os.path.join(ak_path, "result_pt_mc_gt.pkl"), "rb"))
result_dq_props_akx = pickle.load(open(os.path.join(ak_path, "result_props_dq.pkl"), "rb"))
try:
result_PR_AKX_PL = pickle.load(open(os.path.join(os.path.join(path, "AKX_PL", "count_matched_quarks"), "result_PR_AKX.pkl"), "rb"))
result_qj_akx_PL = pickle.load(open(os.path.join(os.path.join(path, "AKX_PL", "count_matched_quarks"), "result_quark_to_jet.pkl"), "rb"))
result_dq_mc_pt_akx_PL = pickle.load(open(os.path.join(os.path.join(path, "AKX_PL", "count_matched_quarks"), "result_pt_mc_gt.pkl"), "rb"))
result_dq_pt_akx_PL = pickle.load(open(os.path.join(os.path.join(path, "AKX_PL", "count_matched_quarks"), "result_pt_dq.pkl"), "rb"))
result_dq_props_akx_PL = pickle.load(open(os.path.join(os.path.join(path, "AKX_PL", "count_matched_quarks"), "result_props_dq.pkl"), "rb"))
except FileNotFoundError:
print("FileNotFoundError")
result_PR_AKX_PL = result_PR_AKX
try:
result_PR_AKX_GL = pickle.load(open(os.path.join(os.path.join(path, "AKX_GL", "count_matched_quarks"), "result_PR_AKX.pkl"), "rb"))
result_qj_akx_GL = pickle.load(open(os.path.join(os.path.join(path, "AKX_GL", "count_matched_quarks"), "result_quark_to_jet.pkl"), "rb"))
result_dq_mc_pt_akx_GL = pickle.load(
open(os.path.join(os.path.join(path, "AKX_GL", "count_matched_quarks"), "result_pt_mc_gt.pkl"), "rb"))
result_dq_pt_akx_GL = pickle.load(
open(os.path.join(os.path.join(path, "AKX_GL", "count_matched_quarks"), "result_pt_dq.pkl"), "rb"))
result_dq_props_akx_GL = pickle.load(
open(os.path.join(os.path.join(path, "AKX_GL", "count_matched_quarks"), "result_props_dq.pkl"), "rb"))
except FileNotFoundError:
print("FileNotFoundError")
result_PR_AKX_GL = result_PR_AKX
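# Note: only the precision/recall dicts get a fallback when the AKX_PL / AKX_GL results are
# missing; the corresponding *_QJ / *_pt_dq / *_MCpt_dq / *_props variables are only defined
# if those pickles exist, so the parton/gen-level comparisons below assume they do.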
#plot_only = ["LGATr_GP", "LGATr_GP_IRC_S", "LGATr_GP_IRC_SN", "LGATr_GP_50k", "LGATr_GP_IRC_S_50k"]
plot_only = []
radius = [0.8]
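# select_radius picks a fixed AK radius out of a nested result dict; with the default
# depth=3 it maps d[mMed][mDark][rInv][R] -> d[mMed][mDark][rInv], and depth=4 is used for
# the per-property dark-quark dicts, which carry one extra leading key. Illustrative
# (hypothetical) values:
#   d = {700: {20: {0.3: {0.8: (0.5, 0.6), 1.5: (0.4, 0.5)}}}}
#   select_radius(d, 0.8)  # -> {700: {20: {0.3: (0.5, 0.6)}}}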
def select_radius(d, radius, depth=3):
# from the dictionary, select radius at the level
if depth == 0:
return d[radius]
return {key: select_radius(d[key], radius, depth - 1) for key in d}
if len(models): # temporarily do not plot this one
#fig, ax = plt.subplots(3, len(plot_only) + len(radius)*2, figsize=(sz * (len(plot_only)+len(radius)*2), sz * 3))
# three columns: PL, GL, scouting for each model
for i, model in tqdm(enumerate(sorted(models))):
output_path = os.path.join(path, model, "count_matched_quarks")
if not os.path.exists(os.path.join(output_path, "result.pkl")):
print("Result not exists for model", model)
continue
result = pickle.load(open(os.path.join(output_path, "result.pkl"), "rb"))
#result_unmatched = pickle.load(open(os.path.join(output_path, "result_unmatched.pkl"), "rb"))
#result_fakes = pickle.load(open(os.path.join(output_path, "result_fakes.pkl"), "rb"))
result_bc = pickle.load(open(os.path.join(output_path, "result_bc.pkl"), "rb"))
result_PR = pickle.load(open(os.path.join(output_path, "result_PR.pkl"), "rb"))
#matrix_plot(result, "Blues", "Avg. matched dark quarks / event").savefig(os.path.join(output_path, "avg_matched_dark_quarks.pdf"), ax=ax[0, i])
#matrix_plot(result_fakes, "Greens", "Avg. unmatched jets / event").savefig(os.path.join(output_path, "avg_unmatched_jets.pdf"), ax=ax[1, i])
#matrix_plot(result_PR, "Oranges", "Precision (N matched dark quarks / N predicted jets)", metric_comp_func = lambda r: r[0], ax=ax[0, i])
#matrix_plot(result_PR, "Reds", "Recall (N matched dark quarks / N dark quarks)", metric_comp_func = lambda r: r[1], ax=ax[1, i])
#matrix_plot(result_PR, "Purples", r"$F_1$ score", metric_comp_func = lambda r: 2 * r[0] * r[1] / (r[0] + r[1]), ax=ax[2, i])
print("Getting run config for model", model)
run_config_title, run_config = get_run_config(model)
print("RC title", run_config_title)
if run_config is None:
print("Skipping", model)
continue
#ax[0, i].set_title(run_config_title)
#ax[1, i].set_title(run_config_title)
#ax[2, i].set_title(run_config_title)
li = run_config["level_idx"]
#ax_f1[i, li].set_title(run_config_title)
#matrix_plot(result_PR, "Purples", r"$F_1$ score", metric_comp_func = lambda r: 2 * r[0] * r[1] / (r[0] + r[1]), ax=ax_f1[i, li])
figures_all[run_config_title] = result_PR
print(model, run_config_title)
td, gtr, level, tdns = run_config["training_dataset"], run_config["GT_R"], run_config["level_idx"], run_config["training_dataset_nostep"]
if tdns in plot_only or not len(plot_only):
td = "R=" + str(gtr) + " " + td
if td not in figures_all_sorted:
figures_all_sorted[td] = {}
figures_all_sorted[td][level] = figures_all[run_config_title]
result_AKX_current = select_radius(result_PR_AKX, 0.8)
result_AKX_PL = select_radius(result_PR_AKX_PL, 0.8)
result_AKX_GL = select_radius(result_PR_AKX_GL, 0.8)
figures_all_sorted["AK8"]: {
0: result_AKX_PL,
1: result_AKX_current,
2: result_AKX_GL
}
for i, R in enumerate(radius):
result_PR_AKX_current = select_radius(result_PR_AKX, R)
#matrix_plot(result_PR_AKX_current, "Oranges", "Precision (N matched dark quarks / N predicted jets)",
# metric_comp_func=lambda r: r[0], ax=ax[0, i+len(models)])
#matrix_plot(result_PR_AKX_current, "Reds", "Recall (N matched dark quarks / N dark quarks)",
# metric_comp_func=lambda r: r[1], ax=ax[1, i+len(models)])
#matrix_plot(result_PR_AKX_current, "Purples", r"$F_1$ score", metric_comp_func=lambda r: 2 * r[0] * r[1] / (r[0] + r[1]),
# ax=ax[2, i+len(models)])
#ax[0, i+len(models)].set_title(f"AK, R={R}")
#ax[1, i+len(models)].set_title(f"AK, R={R}")
#ax[2, i+len(models)].set_title(f"AK, R={R}")
t = f"AK, R={R}"
figures_all[t] = result_PR_AKX_current
for i, R in enumerate(radius):
result_PR_AKX_current = select_radius(result_PR_AKX_PL, R)
#matrix_plot(result_PR_AKX_current, "Oranges", "Precision (N matched dark quarks / N predicted jets)",
# metric_comp_func=lambda r: r[0], ax=ax[0, i+len(models)+len(radius)])
#matrix_plot(result_PR_AKX_current, "Reds", "Recall (N matched dark quarks / N dark quarks)",
# metric_comp_func=lambda r: r[1], ax=ax[1, i+len(models)+len(radius)])
#matrix_plot(result_PR_AKX_current, "Purples", r"$F_1$ score", metric_comp_func=lambda r: 2 * r[0] * r[1] / (r[0] + r[1]),
# ax=ax[2, i+len(models)+len(radius)])
#ax[0, i+len(models)+len(radius)].set_title(f"AK PL, R={R}")
#ax[1, i+len(models)+len(radius)].set_title(f"AK PL, R={R}")
#ax[2, i+len(models)+len(radius)].set_title(f"AK PL, R={R}")
figures_all[f"AK PL, R={R}"] = result_PR_AKX_current
for i, R in enumerate(radius):
result_PR_AKX_current = select_radius(result_PR_AKX_GL, R)
figures_all[f"AK GL, R={R}"] = result_PR_AKX_current
#fig.tight_layout()
#fig.savefig(out_file_PR)
#print("Saved to", out_file_PR)
#fig_f1.tight_layout()
#fig_f1.savefig(out_file_PRf1)
pickle.dump(figures_all, open(out_file_PR.replace(".pdf", ".pkl"), "wb"))
figures_all_sorted["AK8"] = {
0: select_radius(result_PR_AKX_PL, 0.8),
1: select_radius(result_PR_AKX, 0.8),
2: select_radius(result_PR_AKX_GL, 0.8)
}
text_level = ["PL", "PFCands", "GL"]
fig_f1, ax_f1 = plt.subplots(len(figures_all_sorted), 3, figsize=(sz * 2.5, sz * len(figures_all_sorted)))
if len(figures_all_sorted) == 1:
ax_f1 = np.array([ax_f1])
for i in range(len(figures_all_sorted)):
model = list(figures_all_sorted.keys())[i]
renames = {
"R=0.8 base_LGATr_s50000": "LGATr",
"R=0.8 LGATr_GP_50k_s25020": "LGATr_GP",
"R=0.8 LGATr_GP_IRC_S_50k_s12900": "LGATr_GP_IRC_S",
"AK8": "AK8",
"R=0.8 LGATr_GP_IRC_SN_50k_s22020": "LGATr_GP_IRC_SN"
}
for j in range(3):
if j in figures_all_sorted[model]:
matrix_plot(figures_all_sorted[model][j], "Purples", r"$F_1$ score",
metric_comp_func=lambda r: 2 * r[0] * r[1] / (r[0] + r[1]), ax=ax_f1[i, j], is_qcd="qcd" in path.lower())
ax_f1[i, j].set_title(renames.get(model, model) + " "+ text_level[j])
ax_f1[i, j].set_xlabel("$m_{Z'}$")
ax_f1[i, j].set_ylabel("$r_{inv.}$")
fig_f1.tight_layout()
fig_f1.savefig(out_file_PRf1)
import pandas as pd
# plot QCD results:
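# get_qcd_results(i) returns a DataFrame with one row per model and one column per level;
# i selects precision (0), recall (1) or F1 (2). The [0][0][0] indexing below appears to
# pick the single QCD point stored under mMed=0, mDark=0, rInv=0.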
def get_qcd_results(i):
# i=0: precision, i=1: recall, i=2: f1 score
qcd_results = {}
for model in figures_all_sorted:
qcd_results[model] = {}
for level in figures_all_sorted[model]:
r = figures_all_sorted[model][level][0][0][0]
r = [float(x) for x in r]
r.append(r[0] * 2 * r[1] / (r[0] + r[1]))  # append the F1 score
qcd_results[model][text_level[level]] = r[i]
return pd.DataFrame(qcd_results).T
if "qcd" in path.lower():
print("Precision:")
print(get_qcd_results(0))
print("----------------")
print("Recall:")
print(get_qcd_results(1))
print("----------------")
print("F1 score:")
print(get_qcd_results(2))
## Now do the GT R vs metrics plots
oranges = plt.get_cmap("Oranges")
reds = plt.get_cmap("Reds")
purples = plt.get_cmap("Purples")
mDark = 20
if "qcd" in path.lower():
print("QCD events")
mDark=0
to_plot = {} # training dataset -> rInv -> mMed -> level -> "f1score" -> value
to_plot_steps = {} # training dataset -> rInv -> mMed -> level -> step -> value
to_plot_v2 = {} # level -> rInv -> mMed -> {"model": [P,R]}
quark_to_jet = {} # level -> rInv -> mMed -> model -> quark to jet assignment list
mc_gt_pt_of_dq = {}
pt_of_dq = {}
props_of_dq = {"eta": {}, "phi": {}} # Properties of dark quarks: eta and phi
results_all = {}
results_all_ak = {}
jet_properties = {} # training dataset -> rInv -> mMed -> level -> step -> jet property dict
jet_properties_ak = {} # rInv -> mMed -> level -> radius
plotting_hypotheses = [[700, 0.7], [700, 0.5], [700, 0.3], [900, 0.3], [900, 0.7]]
if "qcd" in path.lower():
plotting_hypotheses = [[0,0]]
sz_small = 5
for j, model in enumerate(models):
_, rc = get_run_config(model)
if rc is None or model in ["Eval_eval_19March2025_pt1e-2_500particles_FT_PL_2025_04_02_14_28_33_421FT", "Eval_eval_19March2025_pt1e-2_500particles_FT_PL_2025_04_02_14_47_23_671FT", "Eval_eval_19March2025_small_aug_vanishing_momentum_2025_03_28_11_45_16_582", "Eval_eval_19March2025_small_aug_vanishing_momentum_2025_03_28_11_46_26_326"]:
print("Skipping", model)
continue
td = rc["training_dataset"]
td_raw = rc["training_dataset_nostep"]
level = rc["level"]
r = rc["GT_R"]
output_path = os.path.join(path, model, "count_matched_quarks")
if not os.path.exists(os.path.join(output_path, "result_PR.pkl")):
continue
result_PR = pickle.load(open(os.path.join(output_path, "result_PR.pkl"), "rb"))
result_QJ = pickle.load(open(os.path.join(output_path, "result_quark_to_jet.pkl"), "rb"))
result_jet_props = pickle.load(open(os.path.join(output_path, "result_jet_properties.pkl"), "rb"))
result_MC_PT = pickle.load(open(os.path.join(output_path, "result_pt_mc_gt.pkl"), "rb"))
result_PT_DQ = pickle.load(open(os.path.join(output_path, "result_pt_dq.pkl"), "rb"))
result_DQ_props = pickle.load(open(os.path.join(output_path, "result_props_dq.pkl"), "rb"))
print(level)
if td not in to_plot:
to_plot[td] = {}
results_all[td] = {}
if td_raw not in to_plot_steps:
to_plot_steps[td_raw] = {}
jet_properties[td_raw] = {}
level_idx = ["PL", "scouting", "GL"].index(level)
if level_idx not in to_plot_v2:
to_plot_v2[level_idx] = {}
quark_to_jet[level_idx] = {}
pt_of_dq[level_idx] = {}
mc_gt_pt_of_dq[level_idx] = {}
for prop in props_of_dq:
props_of_dq[prop][level_idx] = {}
for mMed_h in result_PR:
if mMed_h not in to_plot_v2[level_idx]:
to_plot_v2[level_idx][mMed_h] = {}
quark_to_jet[level_idx][mMed_h] = {}
pt_of_dq[level_idx][mMed_h] = {}
mc_gt_pt_of_dq[level_idx][mMed_h] = {}
for prop in props_of_dq:
props_of_dq[prop][level_idx][mMed_h] = {}
if mMed_h not in to_plot_steps[td_raw]:
to_plot_steps[td_raw][mMed_h] = {}
jet_properties[td_raw][mMed_h] = {}
if mMed_h not in results_all[td]:
results_all[td][mMed_h] = {mDark: {}}
for rInv_h in result_PR[mMed_h][mDark]:
if rInv_h not in to_plot_v2[level_idx][mMed_h]:
to_plot_v2[level_idx][mMed_h][rInv_h] = {}
quark_to_jet[level_idx][mMed_h][rInv_h] = {}
pt_of_dq[level_idx][mMed_h][rInv_h] = {}
mc_gt_pt_of_dq[level_idx][mMed_h][rInv_h] = {}
for prop in props_of_dq:
props_of_dq[prop][level_idx][mMed_h][rInv_h] = {}
if rInv_h not in to_plot_steps[td_raw][mMed_h]:
to_plot_steps[td_raw][mMed_h][rInv_h] = {}
jet_properties[td_raw][mMed_h][rInv_h] = {}
if level not in to_plot_steps[td_raw][mMed_h][rInv_h]:
to_plot_steps[td_raw][mMed_h][rInv_h][level] = {}
jet_properties[td_raw][mMed_h][rInv_h][level] = {}
if rInv_h not in results_all[td][mMed_h][mDark]:
results_all[td][mMed_h][mDark][rInv_h] = {}
#for level in ["PL+ghosts", "GL+ghosts", "scouting+ghosts"]:
if level not in results_all[td][mMed_h][mDark][rInv_h]:
results_all[td][mMed_h][mDark][rInv_h][level] = {}
precision = result_PR[mMed_h][mDark][rInv_h][0]
recall = result_PR[mMed_h][mDark][rInv_h][1]
f1score = 2 * precision * recall / (precision + recall)
to_plot_v2[level_idx][mMed_h][rInv_h][td_raw] = [precision, recall]
quark_to_jet[level_idx][mMed_h][rInv_h][td_raw] = result_QJ[mMed_h][mDark][rInv_h]
pt_of_dq[level_idx][mMed_h][rInv_h][td_raw] = flatten_list(result_PT_DQ[mMed_h][mDark][rInv_h])
mc_gt_pt_of_dq[level_idx][mMed_h][rInv_h][td_raw] = flatten_list(result_MC_PT[mMed_h][mDark][rInv_h])
for prop in props_of_dq:
props_of_dq[prop][level_idx][mMed_h][rInv_h][td_raw] = flatten_list(result_DQ_props[prop][mMed_h][mDark][rInv_h])
#print("qj", quark_to_jet[level_idx][mMed_h][rInv_h][td_raw])
if r not in results_all[td][mMed_h][mDark][rInv_h][level]:
results_all[td][mMed_h][mDark][rInv_h][level][r] = f1score
ckpt_step = rc["ckpt_step"]
to_plot_steps[td_raw][mMed_h][rInv_h][level][ckpt_step] = f1score
jet_properties[td_raw][mMed_h][rInv_h][level][ckpt_step] = result_jet_props[mMed_h][mDark][rInv_h]
m_Meds = []
r_invs = []
for key in to_plot_steps:
m_Meds += list(to_plot_steps[key].keys())
for key2 in to_plot_steps[key]:
r_invs += list(to_plot_steps[key][key2].keys())
m_Meds = sorted(list(set(m_Meds)))
r_invs = sorted(list(set(r_invs)))
result_AKX_current = select_radius(result_PR_AKX, 0.8)
result_AKX_PL = select_radius(result_PR_AKX_PL, 0.8)
result_AKX_GL = select_radius(result_PR_AKX_GL, 0.8)
result_AKX_jet_properties = select_radius(result_jet_props_akx, 0.8)
jet_properties["AK8"] = {}
result_AKX_current_QJ = select_radius(result_qj_akx, 0.8)
result_AKX_PL_QJ = select_radius(result_qj_akx_PL, 0.8)
result_AKX_GL_QJ = select_radius(result_qj_akx_GL, 0.8)
result_AKX_current_pt_dq = select_radius(result_dq_pt_akx, 0.8)
result_AKX_PL_pt_dq = select_radius(result_dq_pt_akx_PL, 0.8)
result_AKX_GL_pt_dq = select_radius(result_dq_pt_akx_GL, 0.8)
result_AKX_current_MCpt_dq = select_radius(result_dq_mc_pt_akx, 0.8)
result_AKX_PL_MCpt_dq = select_radius(result_dq_mc_pt_akx_PL, 0.8)
result_AKX_GL_MCpt_dq = select_radius(result_dq_mc_pt_akx_GL, 0.8)
result_AKX_current_props_dq = select_radius(result_dq_props_akx, 0.8, depth=4)
result_AKX_PL_props_dq = select_radius(result_dq_props_akx_PL, 0.8, depth=4)
result_AKX_GL_props_dq = select_radius(result_dq_props_akx_GL, 0.8, depth=4)
from tqdm import tqdm
for mMed_h in result_AKX_jet_properties:
for rInv_h in result_AKX_jet_properties[mMed_h][mDark]:
if 0 in to_plot_v2:
to_plot_v2[0][mMed_h][rInv_h]["AK8"] = result_AKX_PL[mMed_h][mDark][rInv_h]
to_plot_v2[1][mMed_h][rInv_h]["AK8"] = result_AKX_current[mMed_h][mDark][rInv_h]
to_plot_v2[2][mMed_h][rInv_h]["AK8"] = result_AKX_GL[mMed_h][mDark][rInv_h]
quark_to_jet[0][mMed_h][rInv_h]["AK8"] = result_AKX_PL_QJ[mMed_h][mDark][rInv_h]
quark_to_jet[1][mMed_h][rInv_h]["AK8"] = result_AKX_current_QJ[mMed_h][mDark][rInv_h]
quark_to_jet[2][mMed_h][rInv_h]["AK8"] = result_AKX_GL_QJ[mMed_h][mDark][rInv_h]
pt_of_dq[0][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_PL_pt_dq[mMed_h][mDark][rInv_h])
pt_of_dq[1][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_current_pt_dq[mMed_h][mDark][rInv_h])
pt_of_dq[2][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_GL_pt_dq[mMed_h][mDark][rInv_h])
mc_gt_pt_of_dq[0][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_PL_MCpt_dq[mMed_h][mDark][rInv_h])
mc_gt_pt_of_dq[1][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_current_MCpt_dq[mMed_h][mDark][rInv_h])
mc_gt_pt_of_dq[2][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_GL_MCpt_dq[mMed_h][mDark][rInv_h])
for k in props_of_dq:
props_of_dq[k][0][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_PL_props_dq[k][mMed_h][mDark][rInv_h])
props_of_dq[k][1][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_current_props_dq[k][mMed_h][mDark][rInv_h])
props_of_dq[k][2][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_GL_props_dq[k][mMed_h][mDark][rInv_h])
if mMed_h not in jet_properties["AK8"]:
jet_properties["AK8"][mMed_h] = {}
if rInv_h not in jet_properties["AK8"][mMed_h]:
jet_properties["AK8"][mMed_h][rInv_h] = {}
jet_properties["AK8"][mMed_h][rInv_h] = {"scouting": {50000: result_AKX_jet_properties[mMed_h][mDark][rInv_h]}}
rename_results_dict = {
"LGATr_comparison_DifferentTrainingDS": "base",
"LGATr_comparison_GP_training": "GP",
"LGATr_comparison_GP_IRC_S_training": "GP_IRC_S",
"LGATr_comparison_GP_IRC_SN_training": "GP_IRC_SN"
}
hypotheses_to_plot = [[0,0],[700,0.7],[700,0.5],[700,0.3]]
def powerset(iterable):
"powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
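# get_label_from_superset maps a Venn membership string (e.g. "02") back to a readable
# legend label, with special cases for quarks missed by all, found by all, and found by the
# two models (QCD- and 900_03-trained) but not by AK.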
def get_label_from_superset(lbl, labels_rename, labels):
if lbl == '':
return "Missed by all"
r = [labels[int(i)] for i in lbl]
r = [labels_rename.get(l,l) for l in r]
if len(r) == 2 and "QCD" in r and "900_03" in r:
return "Found by both models but not AK"
if len(r) == 3:
return "Found by all"
return ", ".join(r)
for hyp_m, hyp_rinv in hypotheses_to_plot:
if 0 not in to_plot_v2:
continue # Not for the lower-pt thresholds, where only GL and PL are available
if hyp_m not in to_plot_v2[0] or hyp_rinv not in to_plot_v2[0][hyp_m]:
continue
# plot here the venn diagrams
labels = ["LGATr_GP_IRC_S_QCD", "AK8", "LGATr_GP_IRC_S_50k"]
labels_global = ["LGATr_GP_IRC_S_QCD", "AK8", "LGATr_GP_IRC_S_50k"]
labels_rename = {"LGATr_GP_IRC_S_QCD": "QCD", "LGATr_GP_IRC_S_50k": "900_03"}
fig_venn, ax_venn = plt.subplots(6, 3, figsize=(5 * 3, 5 * 6)) # the bottom ones are for pt of the DQ, pt of the MC GT, pt of MC GT / pt of DQ, eta, and phi distributions
fig_venn1, ax_venn1 = plt.subplots(6, 2, figsize=(5*2, 5*6)) # Only the PFCands-level, with full histogram on the left and density on the right
for level in range(3):
#labels = list(results_dict["LGATr_comparison_GP_IRC_S_training"][0].keys())
label_combination_to_number = {} # fill it with all possible label combinations e.g. if there are 3 labels: "NA", "0", "1", "2", "01", "012", "12", "02"
powerset_str = ["".join([str(x) for x in sorted(list(a))]) for a in powerset(range(len(labels)))]
set_to_count = {key: 0 for key in powerset_str}
set_to_stats = {key: {"pt_dq": [], "pt_mc_t": [], "pt_mc_t_dq_ratio": [], "eta": [], "phi": []} for key in powerset_str}
label_to_result = {}
#label_to_stats = {"pt_dq": , "pt_mc_t": [], "pt_mc_t_dq_ratio": [], "eta": [], "phi": []}
n_dq = 999999999
for j, label in enumerate(labels):
r = flatten_list(quark_to_jet[level][hyp_m][hyp_rinv][label])
n_dq = min(n_dq, len(r)) # Find the minimum number of dark quarks in all labels
for j, label in enumerate(labels):
r = torch.tensor(flatten_list(quark_to_jet[level][hyp_m][hyp_rinv][label]))
r = (r != -1) # Whether quark no. X is caught or not
label_to_result[j] = r.tolist()[:n_dq]
#r = torch.tensor(flatten_list(pt_of_dq[level][hyp_m][hyp_rinv][label]))
#r = r[:n_dq]
#label_to_stats["pt_dq"].append(r.tolist())
#r1 = torch.tensor(flatten_list(mc_gt_pt_of_dq[level][hyp_m][hyp_rinv][label]))
#r1 = r1[:n_dq]
#label_to_stats["pt_mc_t"].append(r1.tolist())
#r2 = r1 / r
#r2 = r2[:n_dq]
#label_to_stats["pt_mc_t_dq_ratio"].append(r2.tolist())
#r_eta = torch.tensor(flatten_list(props_of_dq["eta"][level][hyp_m][hyp_rinv][label]))
#r_eta = r_eta[:n_dq]
#label_to_stats["eta"].append(r_eta.tolist())
##r_phi = torch.tensor(flatten_list(props_of_dq["phi"][level][hyp_m][hyp_rinv][label]))
#r_phi = r_phi[:n_dq]
#label_to_stats["phi"].append(r_phi.tolist())
assert len(label_to_result[j]) == n_dq, f"Label {label} has different number of quarks than others {n_dq} != {len(label_to_result[j])}"
#n_dq = min(n_dq, len(r))
#for j, label in enumerate(labels):
# assert len(label_to_result[j]) == n_dq, f"Label {label} has different number of quarks than others {n_dq} != {len(label_to_result[j])}"
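# For each dark quark (index c), build a string of the label indices that matched it
# (e.g. "02" = matched by labels[0] and labels[2], "" = missed by all) and record its
# pt/eta/phi under that Venn region in set_to_stats.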
for c in tqdm(range(n_dq)):
belonging_to_set = ""
for j, label in enumerate(labels):
if label_to_result[j][c] == 1:
belonging_to_set += str(j)
set_to_count[belonging_to_set] += 1
#for key in label_to_stats:
# for idx in belonging_to_set:
# idx_int = int(idx) # e.g. "0", "1" etc.
# set_to_stats[belonging_to_set]
for j, label in enumerate(labels):
current_dq_pt = pt_of_dq[level][hyp_m][hyp_rinv][label][c]
current_mc_gt_pt = mc_gt_pt_of_dq[level][hyp_m][hyp_rinv][label][c]
current_dq_eta = props_of_dq["eta"][level][hyp_m][hyp_rinv][label][c]
current_dq_phi = props_of_dq["phi"][level][hyp_m][hyp_rinv][label][c]
set_to_stats[belonging_to_set]["pt_dq"].append(current_dq_pt)
set_to_stats[belonging_to_set]["pt_mc_t"].append(current_mc_gt_pt)
set_to_stats[belonging_to_set]["pt_mc_t_dq_ratio"].append(current_mc_gt_pt/current_dq_pt)
set_to_stats[belonging_to_set]["eta"].append(current_dq_eta)
set_to_stats[belonging_to_set]["phi"].append(current_dq_phi)
#print("set_to_count for level", level, ":", set_to_count, "labels:", labels)
title = f"$m_{{Z'}}={hyp_m}$ GeV, $r_{{inv.}}={hyp_rinv}$, {text_level[level]} (missed by all: {set_to_count['']}) "
if hyp_m == 0 and hyp_rinv == 0:
title = f"QCD, {text_level[level]} (missed by all: {set_to_count['']})"
ax_venn[0, level].set_title(title)
plot_venn3_from_index_dict(ax_venn[0, level], set_to_count, set_labels=[labels_rename.get(l,l) for l in labels], set_colors=["orange", "gray", "red"])
if level == 1: #reco-level
plot_venn3_from_index_dict(ax_venn1[0, 1], set_to_count,
set_labels=[labels_rename.get(l,l) for l in labels],
set_colors=["orange", "gray", "red"])
bins = {
"pt_dq": np.linspace(90, 250, 50),
"pt_mc_t": np.linspace(0, 200, 50),
"pt_mc_t_dq_ratio": np.linspace(0, 1.3, 30),
"eta": np.linspace(-4, 4, 20),
"phi": np.linspace(-np.pi, np.pi, 20)
}
# 10 random colors
clrs = ["green", "red", "orange", "pink", "blue", "purple", "cyan", "magenta"]
key_rename_dict = {"pt_dq": "$p_T$ of quark", "pt_mc_t": "$p_T$ of particles within radius of R=0.8 of quark", "pt_mc_t_dq_ratio": "$p_T$ (part. within R=0.8 of quark) / $p_T$ (quark) ", "eta": "$\eta$ of quark", "phi": "$\phi$ of quark" }
for k, key in enumerate(["pt_dq", "pt_mc_t", "pt_mc_t_dq_ratio", "eta", "phi"]):
for s_idx, s in enumerate(sorted(set_to_stats.keys())):
if len(set_to_stats[s][key]) == 0:
continue
lbl = s
#if s == "":
# lbl = "none"
lbl1 = get_label_from_superset(lbl, labels_rename, labels)
if lbl1 not in ["Missed by all", "Found by both models but not AK", "AK8", "Found by all"]:
continue
if level == 1:
ax_venn1[k + 1, 1].hist(set_to_stats[s][key], bins=bins[key], histtype="step",
label=lbl1, color=clrs[s_idx], density=True)
ax_venn1[k + 1, 0].set_title(f"{key_rename_dict[key]}")
ax_venn1[k+1, 1].set_title(f"{key_rename_dict[key]}")
ax_venn1[k + 1, 1].set_ylabel("Density")
if lbl not in ["", "012"]:
# We are only interested in the differences...
ax_venn[k+1, level].hist(set_to_stats[s][key], bins=bins[key], histtype="step", label=lbl1, color=clrs[s_idx])
ax_venn[k+1, level].set_title(f"{key_rename_dict[key]}")
if level == 1:
ax_venn1[k + 1, 0].hist(set_to_stats[s][key], bins=bins[key], histtype="step",
label=lbl1,
color=clrs[s_idx])
#ax_venn[k+1, level].set_xlabel(key)
#ax_venn[k+1, level].set_ylabel("Count")
for k in range(5):
ax_venn[k+1, level].legend()
fig_venn.tight_layout()
for k in range(5):
ax_venn1[k+1, 0].legend()
ax_venn1[k+1, 1].legend()
fig_venn1.tight_layout()
f = os.path.join(get_path(args.input, "results"), f"venn_diagram_{hyp_m}_{hyp_rinv}.pdf")
fig_venn.savefig(f)
f1 = os.path.join(get_path(args.input, "results"), f"venn_diagram_{hyp_m}_{hyp_rinv}_reco_level_only.pdf")
fig_venn1.savefig(f1)
for i, lbl in enumerate(["precision", "recall", "F1"]): # 0=precision, 1=recall, 2=F1
sz_small1 = 2.5
fig, ax = plt.subplots(len(rename_results_dict), 3, figsize=(sz_small1 * 3, sz_small1 * len(rename_results_dict)))
for i1, key in enumerate(list(rename_results_dict.keys())):
for level in range(3):
level_text = text_level[level]
labels = list(results_dict[key][0].keys())
colors = [results_dict[key][0][l] for l in labels]
res_precision = np.array([to_plot_v2[level][hyp_m][hyp_rinv][l][0] for l in labels])
res_recall = np.array([to_plot_v2[level][hyp_m][hyp_rinv][l][1] for l in labels])
res_f1 = 2 * res_precision * res_recall / (res_precision + res_recall)
if i == 0:
values = res_precision
elif i == 1:
values = res_recall
else:
values = res_f1
rename_dict = results_dict[key][1]
labels_renamed = [rename_dict.get(l,l) for l in labels]
print(i1, level)
ax_tiny_histogram(ax[i1, level], labels_renamed, colors, values)
ax[i1, level].set_title(f"{rename_results_dict[key]} {level_text}")
fig.tight_layout()
fig.savefig(os.path.join(get_path(args.input, "results"), f"{lbl}_results_by_level_{hyp_m}_{hyp_rinv}_{key}.pdf"))
for hyp_m, hyp_rinv in hypotheses_to_plot:
if 0 not in to_plot_v2:
continue # Not for the lower-pt thresholds, where only GL and PL are available
if hyp_m not in to_plot_v2[0] or hyp_rinv not in to_plot_v2[0][hyp_m]:
continue
# plot here the venn diagrams
labels = ["LGATr_GP_IRC_S_QCD", "AK8", "LGATr_GP_IRC_S_50k"]
labels_global = ["LGATr_GP_IRC_S_QCD", "AK8", "LGATr_GP_IRC_S_50k"]
labels_rename = {"LGATr_GP_IRC_S_QCD": "QCD", "LGATr_GP_IRC_S_50k": "900_03"}
fig_venn2, ax_venn2 = plt.subplots(1, len(labels), figsize=(4*len(labels), 4)) # one Venn diagram per model, showing agreement between the three levels (PL / PFCands / GL)
for j, label in enumerate(labels):
#labels = list(results_dict["LGATr_comparison_GP_IRC_S_training"][0].keys())
label_combination_to_number = {} # fill it with all possible label combinations e.g. if there are 3 labels: "NA", "0", "1", "2", "01", "012", "12", "02"
powerset_str = ["".join([str(x) for x in sorted(list(a))]) for a in powerset(range(3))]
set_to_count = {key: 0 for key in powerset_str}
label_to_result = {}
n_dq = 99999999 # Sometimes, the last batch gets cut off etc. ...
for level in range(3):
r = flatten_list(quark_to_jet[level][hyp_m][hyp_rinv][label])
n_dq = min(n_dq, len(r))
for level in range(3):
r = torch.tensor(flatten_list(quark_to_jet[level][hyp_m][hyp_rinv][label]))
r = (r != -1)
label_to_result[level] = r.tolist()[:n_dq]
assert len(label_to_result[level]) == n_dq, f"Label {label} has different number of quarks than others {n_dq} != {len(label_to_result[level])}"
for c in tqdm(range(n_dq)):
belonging_to_set = ""
for lvl in range(3):
if label_to_result[lvl][c] == 1:
belonging_to_set += str(lvl)
set_to_count[belonging_to_set] += 1
if hyp_m == 0 and hyp_rinv == 0:
title = f"QCD, {label} (missed by all: {set_to_count['']}) "
else:
title = f"$m_{{Z'}}={hyp_m}$ GeV, $r_{{inv.}}={hyp_rinv}$, {label} (miss: {set_to_count['']}) "
ax_venn2[j].set_title(title)
plot_venn3_from_index_dict(ax_venn2[j], set_to_count, set_labels=text_level, set_colors=["orange", "gray", "red"], remove_max=True)
fig_venn2.tight_layout()
f = os.path.join(get_path(args.input, "results"), f"venn_diagram_{hyp_m}_{hyp_rinv}_Agreement_between_levels.pdf")
fig_venn2.savefig(f)
for key in results_dict:
for level in range(3):
level_text = text_level[level]
labels = list(results_dict[key][0].keys())
if level in to_plot_v2:
f, a = multiple_matrix_plot(to_plot_v2[level], labels=labels, colors=[results_dict[key][0][l] for l in labels], rename_dict=results_dict[key][1])
if f is None:
print("No figure for", key, level)
continue
#f.suptitle(f"{level_text} $F_1$ score")
out_file = f"grid_stack_F1_{level_text}_{key}.pdf"
out_file = os.path.join(get_path(args.input, "results"), out_file)
f.savefig(out_file)
print("Saved to", out_file)
from matplotlib.lines import Line2D
# Define custom legend handles
custom_lines = [
Line2D([0], [0], color='orange', linestyle='-', label='LGATr'),
Line2D([0], [0], color='green', linestyle='-', label='GATr'),
Line2D([0], [0], color='blue', linestyle='-', label='Transformer'),
Line2D([0], [0], color='gray', linestyle='-', label='AK8'),
Line2D([0], [0], color='black', linestyle='-', label='reco'),
Line2D([0], [0], color='black', linestyle=':', label='gen'),
Line2D([0], [0], color='black', linestyle='--', label='parton'),
]
if len(models):
fig_steps, ax_steps = plt.subplots(len(m_Meds), len(r_invs), figsize=(sz_small * len(r_invs), sz_small * len(m_Meds)))
if len(m_Meds) == 1 and len(r_invs) == 1:
ax_steps = np.array([[ax_steps]])
histograms = {}
for key in histograms_dict:
if key not in histograms:
histograms[key] = {}
for i in ["pt", "eta", "phi"]:
f, a = plt.subplots(len(m_Meds), len(r_invs), figsize=(sz_small * len(r_invs), sz_small * len(m_Meds)))
if len(r_invs) == 1 and len(m_Meds) == 1:
a = np.array([[a]])
histograms[key][i] = f, a
colors = {"base_LGATr": "orange", "base_Tr": "blue", "base_GATr": "green", "AK8": "gray"} # THE COLORS FOR THE STEP VS. F1 SCORE
#colors_small_dataset = {"base_LGATr_SD": "orange", "base_Tr_SD": "blue", "base_GATr_SD": "green", "AK8": "gray"}
#colors = colors_small_dataset
level_styles = {"scouting": "solid", "PL": "dashed", "GL": "dotted"}
#step_to_plot_histograms = 50000 # phi, eta, pt histograms...
level_to_plot_histograms = "scouting"
for i, mMed_h in enumerate(m_Meds):
for j, rInv_h in enumerate(r_invs):
ax_steps[i, j].set_title("$m_{{Z'}} = {}$ GeV, $r_{{inv.}} = {}$".format(mMed_h, rInv_h))
ax_steps[i, j].set_xlabel("Training step")
ax_steps[i, j].set_ylabel("Test $F_1$ score")
#if j == 0:
#ax_steps[i, j].set_ylabel("$m_{{Z'}} = {}$".format(mMed_h))
#for subset in histograms:
#for key in histograms[subset]:
#histograms[subset][key][1][i, j].set_ylabel("$m_{{Z'}} = {}$".format(mMed_h))
if i == len(m_Meds)-1:
ax_steps[i, j].set_xlabel("$r_{{inv.}} = {}$".format(rInv_h))
for subset in histograms:
for key in histograms[subset]:
histograms[subset][key][1][i, j].set_xlabel("$r_{{inv.}} = {}$".format(rInv_h))
for model in jet_properties:
if level_to_plot_histograms not in jet_properties[model][mMed_h][rInv_h]:
print("Skipping", model, level_to_plot_histograms, " - levels:", jet_properties[model][mMed_h][rInv_h].keys())
continue
for subset in histograms:
for key in histograms[subset]:
if model not in histograms_dict[subset][1]:
continue
step_to_plot_histograms = histograms_dict[subset][0][model]
if step_to_plot_histograms not in jet_properties[model][mMed_h][rInv_h][level_to_plot_histograms]:
print("Swapping the step to plot histograms", jet_properties[model][mMed_h][rInv_h][level_to_plot_histograms].keys())
step_to_plot_histograms = sorted(list(jet_properties[model][mMed_h][rInv_h][level_to_plot_histograms].keys()))[0]
pred = np.array(jet_properties[model][mMed_h][rInv_h][level_to_plot_histograms][step_to_plot_histograms][key + "_pred"])
truth = np.array(jet_properties[model][mMed_h][rInv_h][level_to_plot_histograms][step_to_plot_histograms][key + "_gen_particle"])
if key.startswith("pt"):
q = pred/truth
symbol = "/" # division instead of subtraction symbol for pt
quantity = "p_{T,pred}/p_{T,true}"
bins = np.linspace(0, 2.5, 100)
elif key.startswith("eta"):
q = (pred - truth)
symbol = "-"
quantity="\eta_{pred}-\eta_{true}"
bins = np.linspace(-0.8, 0.8, 50)
elif key.startswith("phi"):
q = pred - truth
symbol = "-"
quantity = "\phi_{pred}-\phi_{true}"
q[q > np.pi] -= 2 * np.pi
q[q< -np.pi] += 2 * np.pi
bins = np.linspace(-0.8, 0.8, 50)
print("Max", np.max(q), "Min", np.min(q))
rename = {"base_LGATr": "LGATr",
"LGATr_GP_IRC_S_50k": "LGATr_GP_IRC_S",
"AK8": "AK8",
"LGATr_GP_50k": "LGATr_GP"}
histograms[subset][key][1][i, j].hist(q, histtype="step", color=histograms_dict[subset][1][model], label=rename.get(model, model), bins=bins, density=True)
if mMed_h > 0:
histograms[subset][key][1][i, j].set_title(f"${quantity}$ $m_{{Z'}}={mMed_h}$ GeV, $r_{{inv.}}={rInv_h}$")
else:
histograms[subset][key][1][i, j].set_title(f"${quantity}$")
histograms[subset][key][1][i, j].legend()
histograms[subset][key][1][i, j].grid(True)
for model in to_plot_steps:
for lvl in to_plot_steps[model][mMed_h][rInv_h]:
if model not in colors:
print("Skipping", model)
continue
print(model)
ls = level_styles[lvl]
plt_dict = to_plot_steps[model][mMed_h][rInv_h][lvl]
x_pts = sorted(list(plt_dict.keys()))
y_pts = [plt_dict[k] for k in x_pts]
if ls == "solid":
ax_steps[i, j].plot(x_pts, y_pts, label=model, marker=".", linestyle=ls, color=colors[model])
else:
# No label
ax_steps[i, j].plot(x_pts, y_pts, marker=".", linestyle=ls, color=colors[model])
ax_steps[i, j].legend(handles=custom_lines)
# now plot a horizontal line for the AKX same level
if lvl == "scouting":
rc = result_AKX_current
elif lvl == "PL":
rc = result_AKX_PL
elif lvl == "GL":
rc = result_AKX_GL
else:
raise ValueError(f"Unknown level: {lvl}")
pr = rc[mMed_h][mDark][rInv_h][0]
rec = rc[mMed_h][mDark][rInv_h][1]
f1ak = 2 * pr * rec / (pr + rec)
ax_steps[i, j].axhline(f1ak, color="gray", linestyle=ls, alpha=0.5)
ax_steps[i, j].grid(True)
path_steps_fig = os.path.join(get_path(args.input, "results"), "score_vs_step_plots.pdf")
fig_steps.tight_layout()
fig_steps.savefig(path_steps_fig)
for subset in histograms:
for key in histograms[subset]:
fig = histograms[subset][key][0]
fig.tight_layout()
fig.savefig(os.path.join(get_path(args.input, "results"), "histogram_{}_{}.pdf".format(key, subset)))
print("Saved to", path_steps_fig)
'''for i, h in enumerate(plotting_hypotheses):
mMed_h, rInv_h = h
if rInv_h not in to_plot[td]:
to_plot[td][rInv_h] = {}
print("Model", model)
if mMed_h not in to_plot[td][rInv_h]:
to_plot[td][rInv_h][mMed_h] = {} # level
if level not in to_plot[td][rInv_h][mMed_h]:
to_plot[td][rInv_h][mMed_h][level] = {"precision": [], "recall": [], "f1score": [], "R": []}
precision = result_PR[mMed_h][mDark][rInv_h][0]
recall = result_PR[mMed_h][mDark][rInv_h][1]
f1score = 2 * precision * recall / (precision + recall)
to_plot[td][rInv_h][mMed_h][level]["precision"].append(precision)
to_plot[td][rInv_h][mMed_h][level]["recall"].append(recall)
to_plot[td][rInv_h][mMed_h][level]["f1score"].append(f1score)
to_plot[td][rInv_h][mMed_h][level]["R"].append(r)
'''
to_plot_ak = {} # level ("scouting"/"GL"/"PL") -> rInv -> mMed -> {"f1score": [], "R": []}
for j, model in enumerate(["AKX", "AKX_PL", "AKX_GL"]):
print(model)
if os.path.exists(os.path.join(path, model, "count_matched_quarks", "result_PR_AKX.pkl")):
result_PR_AKX = pickle.load(open(os.path.join(path, model, "count_matched_quarks", "result_PR_AKX.pkl"), "rb"))
else:
print("Skipping", model)
continue
level = "scouting"
if "PL" in model:
level = "PL"
elif "GL" in model:
level = "GL"
if level not in to_plot_ak:
to_plot_ak[level] = {}
for mMed_h in result_PR_AKX:
if mMed_h not in results_all_ak:
results_all_ak[mMed_h] = {mDark: {}}
for rInv_h in result_PR_AKX[mMed_h][mDark]:
if rInv_h not in results_all_ak[mMed_h][mDark]:
results_all_ak[mMed_h][mDark][rInv_h] = {}
if level not in results_all_ak[mMed_h][mDark][rInv_h]:
results_all_ak[mMed_h][mDark][rInv_h][level] = {}
for ridx, R in enumerate(result_PR_AKX[mMed_h][mDark][rInv_h]):
if R not in results_all_ak[mMed_h][mDark][rInv_h][level]:
precision = result_PR_AKX[mMed_h][mDark][rInv_h][R][0]
recall = result_PR_AKX[mMed_h][mDark][rInv_h][R][1]
f1score = 2 * precision * recall / (precision + recall)
results_all_ak[mMed_h][mDark][rInv_h][level][R] = f1score
for i, h in enumerate(plotting_hypotheses):
mMed_h, rInv_h = h
if rInv_h not in to_plot_ak[level]:
to_plot_ak[level][rInv_h] = {}
print("Model", model)
if mMed_h not in to_plot_ak[level][rInv_h]:
to_plot_ak[level][rInv_h][mMed_h] = {"precision": [], "recall": [], "f1score": [], "R": []}
rs = sorted(result_PR_AKX[mMed_h][mDark][rInv_h].keys())
precision = np.array([result_PR_AKX[mMed_h][mDark][rInv_h][rr][0] for rr in rs])
recall = np.array([result_PR_AKX[mMed_h][mDark][rInv_h][rr][1] for rr in rs])
f1score = 2 * precision * recall / (precision + recall)
to_plot_ak[level][rInv_h][mMed_h]["precision"] = precision
to_plot_ak[level][rInv_h][mMed_h]["recall"] = recall
to_plot_ak[level][rInv_h][mMed_h]["f1score"] = f1score
to_plot_ak[level][rInv_h][mMed_h]["R"] = rs
print("AK:", to_plot_ak)
fig, ax = plt.subplots(len(to_plot) + 1, len(plotting_hypotheses), figsize=(sz_small * len(plotting_hypotheses), sz_small * (len(to_plot) + 1))) # also add AKX as last plot
if len(to_plot) == 0:
ax = np.array([ax])
colors = {
#"PL": "green",
#"GL": "blue",
#"scouting": "red",
"PL+ghosts": "green",
"GL+ghosts": "blue",
"scouting+ghosts": "red"
}
ak_colors = {
"PL": "green",
"GL": "blue",
"scouting": "red",
}
'''
for i, td in enumerate(to_plot):
# for each training dataset
for j, h in enumerate(plotting_hypotheses):
ax[i, j].set_title(f"r_inv={h[1]}, m={h[0]}, tr. on {td}")
ax[i, j].set_ylabel("F1 score")
ax[i, j].set_xlabel("GT R")
ax[i, j].grid()
for level in sorted(list(to_plot[td][h[1]][h[0]].keys())):
print("level", level)
print("Plotting", td, h[1], h[0], level)
if level in colors:
ax[i, j].plot(to_plot[td][h[1]][h[0]][level]["R"], to_plot[td][h[1]][h[0]][level]["f1score"], ".-", label=level, color=colors[level])
ax[i, j].legend()
for j, h in enumerate(plotting_hypotheses): # for to_plot_AK
ax[-1, j].set_title(f"r_inv={h[1]}, m={h[0]}, AK baseline")
ax[-1, j].set_ylabel("F1 score")
ax[-1, j].set_xlabel("GT R")
ax[-1, j].grid()
for i, ak_level in enumerate(sorted(list(to_plot_ak.keys()))):
mMed_h, rInv_h = h
if ak_level in ak_colors:
ax[-1, j].plot(to_plot_ak[ak_level][rInv_h][mMed_h]["R"], to_plot_ak[ak_level][rInv_h][mMed_h]["f1score"], ".-", label=ak_level, color=ak_colors[ak_level])
ax[-1, j].legend()
fig.tight_layout()
fig.savefig(os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_1.pdf"))
print("Saved to", os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_1.pdf"))
'''
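# The matrix_plot calls below compare two working points by dividing their F1 values, assuming
# metric_comp_func receives the per-hypothesis leaf dict keyed by level and then by R (as the
# lambdas suggest). A defensive version of that ratio (sketch only, hypothetical helper) would
# return NaN instead of raising when an entry is missing or the denominator is zero:
def f1_ratio(leaf, num_level, den_level, R):
    num = leaf.get(num_level, {}).get(R)
    den = leaf.get(den_level, {}).get(R)
    return num / den if num is not None and den else float("nan")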
fig, ax = plt.subplots(1, len(results_all)*len(radius) + len(radius), figsize=(7 * (len(results_all)*len(radius) + len(radius)), 5))
for i, model in enumerate(results_all):
for j, R in enumerate(radius):
#if r not in results_all[model][700][20][0.3]["scouting"]:
# continue
# for each training dataset
index = len(radius)*i + j
ax[index].set_title(model + " R={}".format(R))
matrix_plot(results_all[model], "Greens", r"PL/GL F1 score", ax=ax[index], metric_comp_func=lambda r: r["PL+ghosts"][R]/r["scouting+ghosts"][R])
for i, R in enumerate(radius):
index = len(radius)*len(results_all) + i
ax[index].set_title("AK R={}".format(R))
matrix_plot(results_all_ak, "Greens", r"PL/GL F1 score", ax=ax[index], metric_comp_func=lambda r: r["PL"][R]/r["GL"][R])
fig.tight_layout()
fig.savefig(out_file_PG)
print("Saved to", out_file_PG)
raise SystemExit(0)  # intentional early exit: the remaining plots below are currently unused
#print("Saved to", os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_AK.pdf"))
#print("Saved to", os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_AK_ratio.pdf"))
########### Now save the above plot with objectness score applied
if args.threshold_obj_score != -1:
fig, ax = plt.subplots(3, len(models), figsize=(sz * len(models), sz * 3))
for i, model in tqdm(enumerate(models)):
output_path = os.path.join(path, model, "count_matched_quarks")
if not os.path.exists(os.path.join(output_path, "result.pkl")):
continue
result = pickle.load(open(os.path.join(output_path, "result.pkl"), "rb"))
#result_unmatched = pickle.load(open(os.path.join(output_path, "result_unmatched.pkl"), "rb"))
result_fakes = pickle.load(open(os.path.join(output_path, "result_fakes.pkl"), "rb"))
result_bc = pickle.load(open(os.path.join(output_path, "result_bc.pkl"), "rb"))
result_PR = pickle.load(open(os.path.join(output_path, "result_PR.pkl"), "rb"))
result_PR_thresholds = pickle.load(open(os.path.join(output_path, "result_PR_thresholds.pkl"), "rb"))
#thresholds = sorted(list(result_PR_thresholds[900][20][0.3].keys()))
#thresholds = np.array(thresholds)
# pick the index j of the stored threshold closest to args.threshold_obj_score (nearest neighbour, no interpolation)
j = np.argmin(np.abs(thresholds - args.threshold_obj_score))
print("Thresholds", thresholds)
print("Chose j=", j, "for threshold", args.threshold_obj_score, "(effectively it's", thresholds[j], ")")
def wrap(r):
# compute [precision, recall] array from [n_relevant_retrieved, all_retrieved, all_relevant]
if r[1] == 0 or r[2] == 0:
return [0, 0]
return [r[0] / r[1], r[0] / r[2]]
matrix_plot(result_PR_thresholds, "Oranges", "Precision (N matched dark quarks / N predicted jets)", metric_comp_func = lambda r: wrap(r[j])[0], ax=ax[0, i])
matrix_plot(result_PR_thresholds, "Reds", "Recall (N matched dark quarks / N dark quarks)", metric_comp_func = lambda r: wrap(r[j])[1], ax=ax[1, i])
matrix_plot(result_PR_thresholds, "Purples", r"$F_1$ score", metric_comp_func = lambda r: 2 * wrap(r[j])[0] * wrap(r[j])[1] / (wrap(r[j])[0] + wrap(r[j])[1]), ax=ax[2, i])
ax[0, i].set_title(model)
ax[1, i].set_title(model)
ax[2, i].set_title(model)
fig.tight_layout()
fig.savefig(out_file_PR_OS)
print("Saved to", out_file_PR_OS)
################
# UNUSED PLOTS #
################
'''fig, ax = plt.subplots(2, len(models), figsize=(sz * len(models), sz * 2))
for i, model in tqdm(enumerate(models)):
output_path = os.path.join(path, model, "count_matched_quarks")
if not os.path.exists(os.path.join(output_path, "result.pkl")):
continue
result = pickle.load(open(os.path.join(output_path, "result.pkl"), "rb"))
#result_unmatched = pickle.load(open(os.path.join(output_path, "result_unmatched.pkl"), "rb"))
result_fakes = pickle.load(open(os.path.join(output_path, "result_fakes.pkl"), "rb"))
result_bc = pickle.load(open(os.path.join(output_path, "result_bc.pkl"), "rb"))
result_PR = pickle.load(open(os.path.join(output_path, "result_PR.pkl"), "rb"))
matrix_plot(result, "Blues", "Avg. matched dark quarks / event", ax=ax[0, i])
matrix_plot(result_fakes, "Greens", "Avg. unmatched jets / event", ax=ax[1, i])
ax[0, i].set_title(model)
ax[1, i].set_title(model)
fig.tight_layout()
fig.savefig(out_file_avg_number_matched_quarks)
print("Saved to", out_file_avg_number_matched_quarks)'''
rinvs = [0.3, 0.5, 0.7]
sz = 4
fig, ax = plt.subplots(len(rinvs), 3, figsize=(3*sz, sz*len(rinvs)))
fig_AK, ax_AK = plt.subplots(len(rinvs), 3, figsize=(3*sz, sz*len(rinvs)))
fig_AK_ratio, ax_AK_ratio = plt.subplots(len(rinvs), 3, figsize=(3*sz, sz*len(rinvs)))
to_plot = {} # r_inv -> m_med -> precision, recall, R
to_plot_ak = {} # plotting for the AK baseline
### Plotting the score vs GT R plots
# One panel per (metric, r_inv) pair: precision, recall and F1 score versus the ground-truth R,
# with one curve per mediator mass (colour-coded via the colormaps below).
oranges = plt.get_cmap("Oranges")
reds = plt.get_cmap("Reds")
purples = plt.get_cmap("Purples")
mDark = 20
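# The loops below hand-normalise mMed into [0, 1] before calling the colormaps above. matplotlib's
# Normalize expresses the same mapping with the endpoints kept in one place (sketch only, not used
# by this script; the 500-3000 GeV range mirrors the inline (mMed - 500) / (3000 - 500) formula):
from matplotlib.colors import Normalize
mass_norm = Normalize(vmin=500, vmax=3000)  # e.g. oranges(mass_norm(900)) gives the colour for mMed = 900 GeV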
for i, rinv in enumerate(rinvs):
if rinv not in to_plot:
to_plot[rinv] = {}
to_plot_ak[rinv] = {}
for j, model in enumerate(models):
print("Model", model)
if model not in radius:
continue
r = radius[model]
output_path = os.path.join(path, model, "count_matched_quarks")
if not os.path.exists(os.path.join(output_path, "result_PR.pkl")):
continue
result_PR = pickle.load(open(os.path.join(output_path, "result_PR.pkl"), "rb"))
#if radius not in to_plot[rinv]:
# to_plot[rinv][radius] = {}
for k, mMed in enumerate(sorted(result_PR.keys())):
if mMed not in to_plot[rinv]:
to_plot[rinv][mMed] = {"precision": [], "recall": [], "f1score": [], "R": []}
precision = result_PR[mMed][mDark][rinv][0]
recall = result_PR[mMed][mDark][rinv][1]
f1score = 2 * precision * recall / (precision + recall)
to_plot[rinv][mMed]["precision"].append(precision)
to_plot[rinv][mMed]["recall"].append(recall)
to_plot[rinv][mMed]["f1score"].append(f1score)
to_plot[rinv][mMed]["R"].append(r)
for mMed in sorted(to_plot[rinv].keys()):
# map mMed onto [0, 1] for the colormap (the 500 GeV lower edge keeps the lightest masses away
# from the near-white end of the colormap; masses run up to 3000 GeV)
mmed = (mMed - 500) / (3000 - 500)
r = to_plot[rinv][mMed]
print("Model R", r["R"])
scatter_plot(ax[0, i], r["R"], r["precision"], label="m={} GeV".format(round(mMed)), color=oranges(mmed))
scatter_plot(ax[1, i], r["R"], r["recall"], label="m={} GeV".format(round(mMed)), color=reds(mmed))
scatter_plot(ax[2, i], r["R"], r["f1score"], label="m={} GeV".format(round(mMed)), color=purples(mmed))
if not os.path.exists(os.path.join(ak_path, "result_PR_AKX.pkl")):
continue
result_PR_AKX = pickle.load(open(os.path.join(ak_path, "result_PR_AKX.pkl"), "rb"))
result_jet_props_akx = pickle.load(open(os.path.join(ak_path, "result_jet_properties_AKX.pkl"), "rb"))
#if radius not in to_plot[rinv]:
# to_plot[rinv][radius] = {}
for k, mMed in enumerate(sorted(result_PR_AKX.keys())):
if mMed not in to_plot_ak[rinv]:
to_plot_ak[rinv][mMed] = {"precision": [], "recall": [], "f1score": [], "R": []}
rs = sorted(result_PR_AKX[mMed][mDark][rinv].keys())
precision = np.array([result_PR_AKX[mMed][mDark][rinv][rr][0] for rr in rs])
recall = np.array([result_PR_AKX[mMed][mDark][rinv][rr][1] for rr in rs])
f1score = 2 * precision * recall / (precision + recall)
to_plot_ak[rinv][mMed]["precision"] += list(precision)
to_plot_ak[rinv][mMed]["recall"] += list(recall)
to_plot_ak[rinv][mMed]["f1score"] += list(f1score)
to_plot_ak[rinv][mMed]["R"] += rs
for mMed in sorted(to_plot_ak[rinv].keys()):
# map mMed onto [0, 1] for the colormap (same 500-3000 GeV range as for the model curves above)
mmed = (mMed - 500) / (3000 - 500)
r = to_plot_ak[rinv][mMed]
r_model = to_plot[rinv][mMed]
print("AK R", r["R"])
scatter_plot(ax_AK[0, i], r["R"], r["precision"], label="m={} GeV AK".format(round(mMed)), color=oranges(mmed), pattern=".--")
scatter_plot(ax_AK[1, i], r["R"], r["recall"], label="m={} GeV AK".format(round(mMed)), color=reds(mmed), pattern=".--")
scatter_plot(ax_AK[2, i], r["R"], r["f1score"], label="m={} GeV AK".format(round(mMed)), color=purples(mmed), pattern=".--")
# r["R"] has more points than r_model["R"] - pick those from r["R"] that are in r_model["R"]
r["R"] = np.array(r["R"])
r["precision"] = np.array(r["precision"])
r["recall"] = np.array(r["recall"])
r["f1score"] = np.array(r["f1score"])
filt = np.isin(r["R"], r_model["R"])
r["R"] = r["R"][filt]
r["precision"] = r["precision"][filt]
r["recall"] = r["recall"][filt]
r["f1score"] = r["f1score"][filt]
scatter_plot(ax_AK_ratio[0, i], r["R"], r["precision"]/np.array(r_model["precision"]), label="m={} GeV AK".format(round(mMed)), color=oranges(mmed), pattern=".--")
scatter_plot(ax_AK_ratio[1, i], r["R"], r["recall"]/np.array(r_model["recall"]), label="m={} GeV AK".format(round(mMed)), color=reds(mmed), pattern=".--")
scatter_plot(ax_AK_ratio[2, i], r["R"], r["f1score"]/np.array(r_model["f1score"]), label="m={} GeV AK".format(round(mMed)), color=purples(mmed), pattern=".--")
for ax1 in [ax, ax_AK, ax_AK_ratio]:
ax1[0, i].set_title(f"Precision r_inv = {rinv}")
ax1[1, i].set_title(f"Recall r_inv = {rinv}")
ax1[2, i].set_title(f"F1 score r_inv = {rinv}")
ax1[2, i].legend()
ax1[1, i].legend()
ax1[0, i].legend()
ax1[0, i].grid()
ax1[1, i].grid()
ax1[2, i].grid()
ax1[0, i].set_xlabel("GT R")
ax1[1, i].set_xlabel("GT R")
ax1[2, i].set_xlabel("GT R")
ax_AK_ratio[0, i].set_ylabel("Precision (model=1)")
ax_AK_ratio[1, i].set_ylabel("Recall (model=1)")
ax_AK_ratio[2, i].set_ylabel("F1 score (model=1)")
fig.tight_layout()
fig_AK.tight_layout()
fig.savefig(os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots.pdf"))
fig_AK.savefig(os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_AK.pdf"))
fig_AK_ratio.tight_layout()
fig_AK_ratio.savefig(os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_AK_ratio.pdf"))
print("Saved to", os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_AK.pdf"))
print("Saved to", os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_AK_ratio.pdf"))