import argparse
import functools
import os
import pickle
from collections import OrderedDict
from copy import copy
from itertools import chain, combinations

import torch
from tqdm import tqdm

# NOTE: the project plotting helpers are imported before matplotlib.pyplot, as in
# the original ordering, in case they configure the matplotlib backend.
from src.plotting.eval_matrix import matrix_plot, scatter_plot, multiple_matrix_plot, ax_tiny_histogram
from src.utils.paths import get_path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib_venn import venn3
import wandb


def _load_pickle(*path_parts):
    """Load and return the pickled object stored at ``os.path.join(*path_parts)``.

    The file is opened in a ``with`` block so the handle is closed right away.
    NOTE: pickle.load can execute arbitrary code - only use on trusted result files.
    """
    with open(os.path.join(*path_parts), "rb") as f:
        return pickle.load(f)


#### Plotting functions

def plot_venn3_from_index_dict(ax, data_dict, set_labels=('Set 0', 'Set 1', 'Set 2'),
                               set_colors=("orange", "purple", "gray"), remove_max=True):
    """
    Draw a 3-set Venn diagram on ``ax`` from a dictionary of set-membership counts.

    Keys are strings built from the characters '0', '1' and '2', naming the sets an
    element belongs to (e.g. '0', '01', '012'); values are counts. The key "" (no
    set) is ignored. The triple-intersection count is never drawn from the data:
    it is dropped before layout and its label is overwritten with the largest count
    found among the non-empty keys.

    Parameters:
    - ax: matplotlib Axes to draw on.
    - data_dict (dict): Dictionary with keys like '0', '01', '012', etc.
    - set_labels (tuple): Labels for the three sets.
    - set_colors (tuple): Fill colours for the three sets.
    - remove_max (bool): if True, every entry whose count equals that largest
      count is removed before the region counts are accumulated.
    """
    region_codes = ('100', '010', '001', '110', '101', '011', '111')
    venn_counts = {code: 0 for code in region_codes}

    # Largest count over the non-empty keys (0 if there is none or all are <= 0).
    max_value = max([0] + [v for k, v in data_dict.items() if k != ""])
    print("Max val", max_value)

    # Work on a new plain dict so the caller's dictionary is never modified.
    if remove_max:
        data_dict = {k: v for k, v in data_dict.items() if v != max_value}
    else:
        data_dict = dict(data_dict)
    print("data dict", data_dict)

    # Translate e.g. '02' -> '101' (membership bit per set index) and accumulate.
    for key, count in data_dict.items():
        code = ''.join('1' if str(i) in key else '0' for i in range(3))
        if code in venn_counts:
            venn_counts[code] += count

    # The triple intersection is shown via an explicit label instead of its count;
    # matplotlib_venn treats a missing region code as 0.
    del venn_counts['111']
    venn = venn3(subsets=venn_counts, set_labels=set_labels, set_colors=set_colors, alpha=0.5, ax=ax)
    label = venn.get_label_by_id("111")
    # matplotlib_venn returns None for regions that get no label; guard against it.
    if label is not None:
        label.set_text(max_value)


### Change this to make custom plots highlighting differences between different models
### (the histograms of pt_pred/pt_true, eta_pred-eta_true, and phi_pred-phi_true)
histograms_dict = {
    "": [{"base_LGATr": 50000, "base_Tr": 50000, "base_GATr": 50000, "AK8": 50000},
         {"base_LGATr": "orange", "base_Tr": "blue", "base_GATr": "green", "AK8": "gray"}],
    "LGATr_comparison": [{"base_LGATr": 50000, "LGATr_GP_IRC_S_50k": 9960, "LGATr_GP_50k": 9960, "AK8": 50000, "LGATr_GP_IRC_SN_50k": 24000},
                         {"base_LGATr": "orange", "LGATr_GP_IRC_S_50k": "red", "LGATr_GP_50k": "purple", "LGATr_GP_IRC_SN_50k": "pink", "AK8": "gray"}],
    # NOTE: key spelled "comparsion" (sic); kept as-is because other code may look it up.
    "LGATr_comparsion_DifferentTrainingDS": [{"base_LGATr": 50000, "LGATr_700_07": 50000, "LGATr_QCD": 50000, "LGATr_700_07+900_03": 50000, "LGATr_700_07+900_03+QCD": 50000, "AK8": 50000},
                                             {"base_LGATr": "orange", "LGATr_700_07": "red", "LGATr_QCD": "purple", "LGATr_700_07+900_03": "blue", "LGATr_700_07+900_03+QCD": "green", "AK8": "gray"}]
}

# Models and their colours for plotting the F1 scores etc. of the models.
# In each list, the 2nd dict is a rename dict (internal name -> display name).
results_dict = {
    "LGATr_comparison_DifferentTrainingDS": [
        {"base_LGATr": "orange", "LGATr_700_07": "red", "LGATr_QCD": "purple", "LGATr_700_07+900_03": "blue", "LGATr_700_07+900_03+QCD": "green", "AK8": "gray"},
        {"base_LGATr": "LGATr_900_03"}
    ],
    "LGATr_comparison": [
        {"base_LGATr": "orange", "LGATr_GP_IRC_S_50k": "red", "LGATr_GP_50k": "purple", "LGATr_GP_IRC_SN_50k": "pink", "AK8": "gray"},
        {"base_LGATr": "LGATr", "LGATr_GP_IRC_S_50k": "LGATr_GP_IRC_S", "LGATr_GP_50k": "LGATr_GP", "LGATr_GP_IRC_SN_50k": "LGATr_GP_IRC_SN"}
    ],
    "LGATr_comparison_QCDtrain": [
        {"LGATr_QCD": "orange", "LGATr_GP_IRC_S_QCD": "red", "LGATr_GP_QCD": "purple", "LGATr_GP_IRC_SN_QCD": "pink", "AK8": "gray"},
        {"LGATr_QCD": "LGATr", "LGATr_GP_IRC_S_QCD": "LGATr_GP_IRC_S", "LGATr_GP_QCD": "LGATr_GP", "LGATr_GP_IRC_SN_QCD": "LGATr_GP_IRC_SN"}
    ],
    "LGATr_comparison_GP_training": [
        {"LGATr_GP_QCD": "purple", "LGATr_GP_700_07": "red", "LGATr_GP_700_07+900_03": "blue", "LGATr_GP_700_07+900_03+QCD": "green", "LGATr_GP_50k": "orange", "AK8": "gray"},
        {"LGATr_GP_QCD": "QCD", "LGATr_GP_700_07": "700_07", "LGATr_GP_700_07+900_03": "700_07+900_03", "LGATr_GP_50k": "900_03", "LGATr_GP_700_07+900_03+QCD": "700_07+900_03+QCD"}
    ],
    "LGATr_comparison_GP_IRC_S_training": [
        {"LGATr_GP_IRC_S_QCD": "purple", "LGATr_GP_IRC_S_700_07": "red", "LGATr_GP_IRC_S_700_07+900_03": "blue", "LGATr_GP_IRC_S_700_07+900_03+QCD": "green", "LGATr_GP_IRC_S_50k": "orange", "AK8": "gray"},
        {"LGATr_GP_IRC_S_QCD": "QCD", "LGATr_GP_IRC_S_700_07": "700_07", "LGATr_GP_IRC_S_700_07+900_03": "700_07+900_03", "LGATr_GP_IRC_S_50k": "900_03", "LGATr_GP_IRC_S_700_07+900_03+QCD": "700_07+900_03+QCD"}
    ],
    "LGATr_comparison_GP_IRC_SN_training": [
        {"LGATr_GP_IRC_SN_QCD": "purple", "LGATr_GP_IRC_SN_700_07": "red", "LGATr_GP_IRC_SN_700_07+900_03": "blue", "LGATr_GP_IRC_SN_700_07+900_03+QCD": "green", "LGATr_GP_IRC_SN_50k": "orange", "AK8": "gray"},
        {"LGATr_GP_IRC_SN_QCD": "QCD", "LGATr_GP_IRC_SN_700_07": "700_07", "LGATr_GP_IRC_SN_700_07+900_03": "700_07+900_03", "LGATr_GP_IRC_SN_50k": "900_03", "LGATr_GP_IRC_SN_700_07+900_03+QCD": "700_07+900_03+QCD"}
    ]
}

parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, required=False, default="scouting_PFNano_signals2/SVJ_hadronic_std/batch_eval/objectness_score")
parser.add_argument("--threshold-obj-score", "-os-threshold", type=float, default=-1)
thresholds = np.linspace(0.1, 1, 20)
# also add 100 points between 0 and 0.1 at the beginning
thresholds = np.concatenate([np.linspace(0, 0.1, 100), thresholds])
args = parser.parse_args()
path = get_path(args.input, "results")
# Every sub-directory of the results dir is an evaluated model, except the AKX baselines.
models = sorted([x for x in os.listdir(path) if not os.path.isfile(os.path.join(path, x))])
models = [x for x in models if "AKX" not in x]
figures_all = {}  # title to the f1 score figure to plot
figures_all_sorted = {}  # model used -> step -> level -> f1 figure
print("Models:", models)

comments = {
    "Eval_params_study_2025_02_17_13_30_50": ", tr. on 07_700",
    "Eval_objectness_score_2025_02_12_15_34_33": ", tr. on 03_900, GT=all",
    "Eval_objectness_score_2025_02_18_08_48_13": ", tr. on 03_900, GT=closest",
    "Eval_objectness_score_2025_02_14_11_10_14": ", tr. on 03_900, GT=closest",
    "Eval_objectness_score_2025_02_21_14_51_07": ", tr. on 07_700",
    "Eval_objectness_score_2025_02_10_14_59_49": ", tr. on 03_900, GT=all",
    "Eval_objectness_score_2025_02_23_19_26_25": ", tr. on all, GT=closest",
    "Eval_objectness_score_2025_02_23_21_04_33": ", tr. on 03_900, GT=closest"
}

out_file_PR = os.path.join(path, "precision_recall.pdf")
out_file_PRf1 = os.path.join(path, "f1_score_sorted.pdf")
out_file_PG = os.path.join(path, "PLoverGL.pdf")
if args.threshold_obj_score != -1:
    out_file_PR_OS = os.path.join(path, "precision_recall_with_obj_score.pdf")
out_file_avg_number_matched_quarks = os.path.join(path, "avg_number_matched_quarks.pdf")


def get_plots_for_params(mMed, mDark, rInv, result_PR_thresholds):
    """Return ``(precisions, recalls, f1_scores)``, one value per entry of ``thresholds``.

    ``result_PR_thresholds[mMed][mDark][rInv][i]`` is read as a sequence whose
    first three items are (n_matched, n_predicted, n_true) for threshold ``i``:
    precision = n_matched / n_predicted, recall = n_matched / n_true. Any ratio
    with a zero denominator is reported as 0.
    """
    rows = result_PR_thresholds[mMed][mDark][rInv]
    precisions, recalls, f1_scores = [], [], []
    for i in range(len(thresholds)):
        n_matched, n_pred, n_true = rows[i][0], rows[i][1], rows[i][2]
        precision = 0 if n_pred == 0 else n_matched / n_pred
        recall = 0 if n_true == 0 else n_matched / n_true
        precisions.append(precision)
        recalls.append(recall)
        if precision + recall == 0:
            f1_scores.append(0)
        else:
            f1_scores.append(2 * precision * recall / (precision + recall))
    return precisions, recalls, f1_scores


sz = 5
nplots = 9
# (A disabled precision/recall/F1-vs-threshold plotting section was removed here.)

# Module-level W&B API handle used for run lookups below.
api = wandb.Api()

# Training run name -> short human-readable training-dataset label.
TRAINING_DATASETS = {
    "LGATr_training_NoPID_10_16_64_0.8_AllData_2025_02_28_13_42_59": "all",
    "LGATr_training_NoPID_10_16_64_0.8_2025_02_28_12_42_59": "900_03",
    "LGATr_training_NoPID_10_16_64_2.0_2025_02_28_12_48_58": "900_03",
    "LGATr_training_NoPID_10_16_64_0.8_700_07_2025_02_28_13_01_59": "700_07",
    "LGATr_training_NoPIDGL_10_16_64_0.8_2025_03_17_20_05_04": "900_03_GenLevel",
    "LGATr_training_NoPIDGL_10_16_64_2.0_2025_03_17_20_05_04": "900_03_GenLevel",
    "Transformer_training_NoPID_10_16_64_2.0_2025_03_03_17_00_38": "900_03_T",
    "Transformer_training_NoPID_10_16_64_0.8_2025_03_03_15_55_50": "900_03_T",
    "LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_2025_03_27_12_46_12_740": "900_03+SoftAug",
    "LGATr_training_NoPID_10_16_64_2.0_Aug_Finetune_vanishing_momentum_2025_03_28_10_43_36_81": "900_03+SoftAugVM",
    "LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_vanishing_momentum_2025_03_28_10_43_37_44": "900_03+SoftAugVM",
    "LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_vanishing_momentum_QCap05_2025_03_28_17_12_25_820": "900_03+qcap05",
    "LGATr_training_NoPID_10_16_64_2.0_Aug_Finetune_vanishing_momentum_QCap05_2025_03_28_17_12_26_400": "900_03+qcap05",
    "LGATr_training_NoPID_10_16_64_2.0_Aug_Finetune_vanishing_momentum_QCap05_1e-2_2025_03_29_14_58_38_650": "pt 1e-2",
    "LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_vanishing_momentum_QCap05_1e-2_2025_03_29_14_58_36_446": "pt 1e-2",
    "LGATr_pt_1e-2_500part_2025_04_01_16_49_08_406": "500_pt_1e-2_PLFT",
    "LGATr_pt_1e-2_500part_2025_04_01_21_14_07_350": "500_pt_1e-2_PLFT",
    "LGATr_pt_1e-2_500part_NoQMin_2025_04_03_23_15_17_745": "500_1e-2_scFT",
    "LGATr_pt_1e-2_500part_NoQMin_2025_04_03_23_15_35_810": "500_1e-2_scFT",
    "LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_2025_04_04_12_57_51_536": "10_1000_1e-2_scFT",
    "LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_2025_04_04_12_57_47_788": "10_1000_1e-2_scFT",
    "LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_CW0_2025_04_04_15_30_16_839": "10_1000_1e-2_CW0",
    "LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_CW0_2025_04_04_15_30_20_113": "10_1000_1e-2_CW0",
    "debug_IRC_loss_weighted100_plus_ghosts_2025_04_08_22_40_33_972": "IRC_short_debug",
    "debug_IRC_loss_weighted100_plus_ghosts_2025_04_09_13_48_55_569": "IRC",
    "debug_IRC_loss_weighted100_plus_ghosts_Qmin05_2025_04_09_14_45_51_381": "IRC_qmin05",
    "LGATr_500part_NOQMin_2025_04_09_21_53_37_210": "500part_NOQMin_reprod",
    "IRC_loss_Split_and_Noise_alternate_Aug_2025_04_14_11_10_21_788": "IRC_Aug_S+N",
    "IRC_loss_Split_and_Noise_alternate_NoAug_2025_04_11_16_15_48_955": "IRC_S+N",
    "LGATr_training_NoPID_Delphes_10_16_64_0.8_2025_04_17_18_07_38_405": "DelphesTrain",
    "Delphes_IRC_aug_2025_04_19_11_16_17_130": "DelphesTrain+IRC",
    "LGATr_500part_NOQMin_Delphes_2025_04_19_11_15_24_417": "DelphesTrain+ghosts",
    "Delphes_IRC_aug_SplitOnly_2025_04_20_15_50_33_553": "DelphesTrain+IRC_SplitOnly",
    "Delphes_IRC_NOAug_SplitOnly_2025_04_21_12_58_36_99": "Delphes_IRC_NoAug_SplitOnly",
    "Delphes_IRC_NOAug_SplitAndNoise_2025_04_21_19_32_08_865": "Delphes_IRC_NoAug_S+N",
    "CONT_Delphes_IRC_aug_SplitOnly_2025_04_21_12_53_27_730": "IRC_aug_SplitOnly_ContFrom14k",
    "Transformer_training_NoPID_Delphes_PU_10_16_64_0.8_2025_05_03_18_37_01_188": "base_Tr_Old",
    "LGATr_training_NoPID_Delphes_PU_PFfix_10_16_64_0.8_2025_05_03_18_35_53_134": "base_LGATr",
    "GATr_training_NoPID_Delphes_PU_10_16_64_0.8_2025_05_03_18_35_48_163": "base_GATr_Old",
    "Transformer_training_NoPID_Delphes_PU_CoordFix_10_16_64_0.8_2025_05_05_13_05_20_755": "base_Tr",
    "GATr_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_05_16_24_13_579": "base_GATr_SD",
    "GATr_training_NoPID_Delphes_PU_CoordFix_10_16_64_0.8_2025_05_05_13_06_27_898": "base_GATr",
    "LGATr_Aug_2025_05_06_10_08_05_956": "LGATr_GP",
    "Delphes_Aug_IRCSplit_CONT_2025_05_07_11_00_18_422": "LGATr_GP_IRC_S",
    "Delphes_Aug_IRC_Split_and_Noise_2025_05_07_14_43_13_968": "LGATr_GP_IRC_SN",
    "Transformer_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_05_16_24_19_936": "base_Tr_SD",
    "LGATr_training_NoPID_Delphes_PU_PFfix_SmallDS_10_16_64_0.8_2025_05_05_16_24_16_127": "base_LGATr_SD",
    "Delphes_Aug_IRCSplit_2025_05_06_10_09_00_567": "LGATr_GP_IRC_S",
    "GATr_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_09_15_34_13_531": "base_GATr_SD",
    "Transformer_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_09_15_56_50_216": "base_Tr_SD",
    "LGATr_training_NoPID_Delphes_PU_PFfix_SmallDS_10_16_64_0.8_2025_05_09_15_56_50_875": "base_LGATr_SD",
    "Delphes_Aug_IRCSplit_50k_from10k_2025_05_11_14_08_49_675": "LGATr_GP_IRC_S_50k",
    "LGATr_Aug_50k_2025_05_09_15_25_32_34": "LGATr_GP_50k",
    "Delphes_Aug_IRCSplit_50k_2025_05_09_15_22_38_956": "LGATr_GP_IRC_S_50k",
    "LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_16_21_04_26_937": "LGATr_700_07+900_03+QCD",
    "LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_16_21_04_26_991": "LGATr_700_07+900_03",
    "LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_16_19_46_57_48": "LGATr_QCD",
    "LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_16_19_44_46_795": "LGATr_700_07",
    "Delphes_Aug_IRCSplit_50k_SN_from3kFT_2025_05_16_14_07_29_474": "LGATr_GP_IRC_SN_50k",
    "GP_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_19_21_29_06_946": "LGATr_GP_QCD",
    "GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_19_21_38_20_376": "LGATr_GP_700_07",
    "GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_20_13_12_54_359": "LGATr_GP_700_07+900_03+QCD",
    "GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_20_13_13_00_503": "LGATr_GP_700_07+900_03",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_20_15_29_30_29": "LGATr_GP_IRC_S_700_07+900_03",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_20_15_29_28_959": "LGATr_GP_IRC_S_700_07+900_03+QCD",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_20_15_11_35_476": "LGATr_GP_IRC_S_700_07",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_20_15_11_20_735": "LGATr_GP_IRC_S_QCD",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_24_23_00_54_948": "LGATr_GP_IRC_SN_QCD",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_24_23_00_56_910": "LGATr_GP_IRC_SN_700_07+900_03+QCD",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_24_23_01_01_212": "LGATr_GP_IRC_SN_700_07+900_03",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_24_23_01_07_703": "LGATr_GP_IRC_SN_700_07",
}


@functools.lru_cache(maxsize=None)
def get_run_by_name(name):
    """Look up a single W&B run in ``fcc_ml/svj_clustering`` by display name.

    A trailing clustering-variant tag is stripped from ``name`` before the lookup
    and returned as the clustering suffix: "FT", then "FT1" (min-samples 1,
    min-cluster-size 2, epsilon 0.3), then "10_5" are checked in that order.
    Results are cached because every evaluated model is looked up more than once.

    Returns:
        ``(run, clust_suffix)`` when exactly one run matches, otherwise
        ``(None, clust_suffix)`` so callers can always unpack two values.
    """
    clust_suffix = ""
    if name.endswith("FT"):
        name = name[:-2]
        clust_suffix = "FT"
    if name.endswith("FT1"):
        # min-samples 1 min-cluster-size 2 epsilon 0.3
        name = name[:-3]
        clust_suffix = "FT1"
    if name.endswith("10_5"):
        name = name[:-4]
        clust_suffix = "10_5"
    runs = api.runs(
        path="fcc_ml/svj_clustering",
        filters={"display_name": {"$eq": name.strip()}}
    )
    if runs.length != 1:
        return None, clust_suffix
    return runs[0], clust_suffix


def get_run_config(run_name):
    """Summarize the W&B config of evaluation run ``run_name``.

    Returns:
        ``(title, info)`` where ``title`` is a plot label and ``info`` has keys
        "level" / "level_idx" (PL=0, scouting=1, GL=2), optional "ghosts",
        "GT_R", "training_dataset", "training_dataset_nostep" and "ckpt_step";
        ``(None, None)`` if the run cannot be found.
    """
    r, clust_suffix = get_run_by_name(run_name)
    if r is None:
        print("Getting info from run", run_name, "failed")
        return None, None
    config = r.config
    result = {}
    if config["parton_level"]:
        prefix = "PL"
        result["level"] = "PL"
        result["level_idx"] = 0
    elif config["gen_level"]:
        prefix = "GL"
        result["level"] = "GL"
        result["level_idx"] = 2
    else:
        prefix = "sc."
        result["level"] = "scouting"
        result["level_idx"] = 1
    # Older runs may not have this key; treat a missing key as "no augmentation".
    if config.get("augment_soft_particles", False):
        result["ghosts"] = True
        prefix += " (aug)"
    gt_r = config["gt_radius"]
    train_name = config["load_from_run"]
    ckpt_step = config["ckpt_step"]
    print("train name", train_name)
    if train_name not in TRAINING_DATASETS:
        print("!! unknown run", train_name)
    dataset_label = TRAINING_DATASETS.get(train_name, train_name)
    training_dataset = dataset_label + "_s" + str(ckpt_step) + clust_suffix
    run_name_lower = run_name.lower()
    if "plptfilt01" in run_name_lower:
        training_dataset += "_PLPtFiltMinPt01"  # min pt 0.1
    elif "noplfilter" in run_name_lower:
        training_dataset += "_noPLFilter"
    elif "noplptfilter" in run_name_lower:
        # actually there was a 0.5 pt cut in the ntuplizer, removed by plptfilt01
        training_dataset += "_noPLPtFilter"
    elif "nopletafilter" in run_name_lower:
        training_dataset += "_noPLEtaFilter"
    result["GT_R"] = gt_r
    result["training_dataset"] = training_dataset
    result["training_dataset_nostep"] = dataset_label + clust_suffix
    result["ckpt_step"] = ckpt_step
    return f"GT_R={gt_r} {training_dataset}, {prefix}", result


def flatten_list(lst):
    """Flatten one level of nesting, e.g. ``[[0, 0], [1, 1]] -> [0, 0, 1, 1]``."""
    return list(chain.from_iterable(lst))


# Anti-kt baseline results (scouting level).
ak_path = os.path.join(path, "AKX", "count_matched_quarks")
result_PR_AKX = _load_pickle(ak_path, "result_PR_AKX.pkl")
result_jet_props_akx = _load_pickle(ak_path, "result_jet_properties_AKX.pkl")
result_qj_akx = _load_pickle(ak_path, "result_quark_to_jet.pkl")
result_dq_pt_akx = _load_pickle(ak_path, "result_pt_dq.pkl")
result_dq_mc_pt_akx = _load_pickle(ak_path, "result_pt_mc_gt.pkl")
result_dq_props_akx = _load_pickle(ak_path, "result_props_dq.pkl")

# Parton-level anti-kt results; if missing, every PL quantity falls back to the
# scouting-level one (previously only result_PR_AKX_PL did, leaving the rest undefined).
ak_path_PL = os.path.join(path, "AKX_PL", "count_matched_quarks")
try:
    result_PR_AKX_PL = _load_pickle(ak_path_PL, "result_PR_AKX.pkl")
    result_qj_akx_PL = _load_pickle(ak_path_PL, "result_quark_to_jet.pkl")
    result_dq_mc_pt_akx_PL = _load_pickle(ak_path_PL, "result_pt_mc_gt.pkl")
    result_dq_pt_akx_PL = _load_pickle(ak_path_PL, "result_pt_dq.pkl")
    result_dq_props_akx_PL = _load_pickle(ak_path_PL, "result_props_dq.pkl")
except FileNotFoundError as err:
    print("FileNotFoundError", err)
    result_PR_AKX_PL = result_PR_AKX
    result_qj_akx_PL = result_qj_akx
    result_dq_mc_pt_akx_PL = result_dq_mc_pt_akx
    result_dq_pt_akx_PL = result_dq_pt_akx
    result_dq_props_akx_PL = result_dq_props_akx

# Gen-level anti-kt results, with the same fallback policy as above.
ak_path_GL = os.path.join(path, "AKX_GL", "count_matched_quarks")
try:
    result_PR_AKX_GL = _load_pickle(ak_path_GL, "result_PR_AKX.pkl")
    result_qj_akx_GL = _load_pickle(ak_path_GL, "result_quark_to_jet.pkl")
    result_dq_mc_pt_akx_GL = _load_pickle(ak_path_GL, "result_pt_mc_gt.pkl")
    result_dq_pt_akx_GL = _load_pickle(ak_path_GL, "result_pt_dq.pkl")
    result_dq_props_akx_GL = _load_pickle(ak_path_GL, "result_props_dq.pkl")
except FileNotFoundError as err:
    print("FileNotFoundError", err)
    result_PR_AKX_GL = result_PR_AKX
    result_qj_akx_GL = result_qj_akx
    result_dq_mc_pt_akx_GL = result_dq_mc_pt_akx
    result_dq_pt_akx_GL = result_dq_pt_akx
    result_dq_props_akx_GL = result_dq_props_akx

# Restrict the sorted F1 plot to these training datasets (empty list = plot all).
#plot_only = ["LGATr_GP", "LGATr_GP_IRC_S", "LGATr_GP_IRC_SN", "LGATr_GP_50k", "LGATr_GP_IRC_S_50k"]
plot_only = []
radius = [0.8]


def select_radius(d, radius, depth=3):
    """Descend ``depth`` levels of nested dicts and pick key ``radius`` at the bottom.

    The outer ``depth`` levels of ``d`` are rebuilt as new dicts with the same keys;
    at depth 0 the entry ``d[radius]`` is returned.
    """
    if depth == 0:
        return d[radius]
    return {key: select_radius(value, radius, depth - 1) for key, value in d.items()}


if len(models):
    for i, model in tqdm(enumerate(sorted(models))):
        output_path = os.path.join(path, model, "count_matched_quarks")
        if not os.path.exists(os.path.join(output_path, "result.pkl")):
            print("Result not exists for model", model)
            continue
        result = _load_pickle(output_path, "result.pkl")
        result_bc = _load_pickle(output_path, "result_bc.pkl")
        result_PR = _load_pickle(output_path, "result_PR.pkl")
        print("Getting run config for model", model)
        run_config_title, run_config = get_run_config(model)
        print("RC title", run_config_title)
        if run_config is None:
            print("Skipping", model)
            continue
        li = run_config["level_idx"]
        figures_all[run_config_title] = result_PR
        print(model, run_config_title)
        td = run_config["training_dataset"]
        gtr = run_config["GT_R"]
        level = run_config["level_idx"]
        tdns = run_config["training_dataset_nostep"]
        if tdns in plot_only or not plot_only:
            td = "R=" + str(gtr) + " " + td
            figures_all_sorted.setdefault(td, {})[level] = figures_all[run_config_title]
    result_AKX_current = select_radius(result_PR_AKX, 0.8)
    result_AKX_PL = select_radius(result_PR_AKX_PL, 0.8)
    result_AKX_GL = select_radius(result_PR_AKX_GL, 0.8)
    # Previously written as `figures_all_sorted["AK8"]: {...}`, which is a variable
    # annotation and stored nothing; make it a real assignment (level idx -> results).
    figures_all_sorted["AK8"] = {
        0: result_AKX_PL,
        1: result_AKX_current,
        2: result_AKX_GL
    }
    # Register the AK baselines in figures_all: all "AK" radii, then "AK PL", then "AK GL".
    for ak_label, ak_results in (("AK", result_PR_AKX), ("AK PL", result_PR_AKX_PL), ("AK GL", result_PR_AKX_GL)):
        for i, R in enumerate(radius):
            result_PR_AKX_current = select_radius(ak_results, R)
            figures_all[f"{ak_label}, R={R}"] = result_PR_AKX_current
def _f1_score(precision, recall):
    """Return the F1 score, i.e. the harmonic mean of precision and recall.

    F1 is undefined when precision + recall == 0; return 0.0 in that case instead of
    dividing by zero (which raises ZeroDivisionError for plain Python floats).
    """
    denominator = precision + recall
    if denominator == 0:
        return 0.0
    return 2 * precision * recall / denominator


def _load_pickle(file_path):
    """Unpickle and return the object stored at ``file_path``, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code; only use it on trusted eval outputs.
    with open(file_path, "rb") as fh:
        return pickle.load(fh)


def _nested_dict(container, *keys):
    """Return ``container[k1][k2]...``, creating an empty dict for every missing key."""
    for key in keys:
        container = container.setdefault(key, {})
    return container


# Persist the per-run precision/recall tables gathered above next to the PR figure.
with open(out_file_PR.replace(".pdf", ".pkl"), "wb") as fh:
    pickle.dump(figures_all, fh)

# AK8 (R=0.8) reference row of the F1 grid, keyed by level index (0 = PL, 1 = PFCands, 2 = GL).
figures_all_sorted["AK8"] = {
    0: select_radius(result_PR_AKX_PL, 0.8),
    1: select_radius(result_PR_AKX, 0.8),
    2: select_radius(result_PR_AKX_GL, 0.8),
}
text_level = ["PL", "PFCands", "GL"]
is_qcd = "qcd" in path.lower()

# Row titles for the F1 grid; keys not listed here are shown unchanged.
renames = {
    "R=0.8 base_LGATr_s50000": "LGATr",
    "R=0.8 LGATr_GP_50k_s25020": "LGATr_GP",
    "R=0.8 LGATr_GP_IRC_S_50k_s12900": "LGATr_GP_IRC_S",
    "AK8": "AK8",
    "R=0.8 LGATr_GP_IRC_SN_50k_s22020": "LGATr_GP_IRC_SN",
}
# squeeze=False always yields a 2-D axes array, also when there is a single row.
fig_f1, ax_f1 = plt.subplots(
    len(figures_all_sorted), 3,
    figsize=(sz * 2.5, sz * len(figures_all_sorted)),
    squeeze=False,
)
for i, (model, results_by_level) in enumerate(figures_all_sorted.items()):
    for j in range(3):
        if j not in results_by_level:
            continue
        matrix_plot(results_by_level[j], "Purples", r"$F_1$ score",
                    metric_comp_func=lambda r: _f1_score(r[0], r[1]),
                    ax=ax_f1[i, j], is_qcd=is_qcd)
        ax_f1[i, j].set_title(renames.get(model, model) + " " + text_level[j])
        ax_f1[i, j].set_xlabel("$m_{Z'}$")
        ax_f1[i, j].set_ylabel("$r_{inv.}$")
fig_f1.tight_layout()
fig_f1.savefig(out_file_PRf1)

import pandas as pd


def get_qcd_results(i):
    """Tabulate one QCD metric per model (rows) and level (columns).

    i selects the metric: 0 = precision, 1 = recall, 2 = F1 score (appended here).
    Reads the [0][0][0] entry of each result table, i.e. the (mMed=0, mDark=0, rInv=0)
    point used for the QCD sample.
    """
    qcd_results = {}
    for model, results_by_level in figures_all_sorted.items():
        qcd_results[model] = {}
        for level, result in results_by_level.items():
            metrics = [float(x) for x in result[0][0][0]]
            metrics.append(_f1_score(metrics[0], metrics[1]))
            qcd_results[model][text_level[level]] = metrics[i]
    return pd.DataFrame(qcd_results).T


if is_qcd:
    print("Precision:")
    print(get_qcd_results(0))
    print("----------------")
    print("Recall:")
    print(get_qcd_results(1))
    print("----------------")
    print("F1 score:")
    print(get_qcd_results(2))

## Now do the GT R vs metrics plots
oranges = plt.get_cmap("Oranges")
reds = plt.get_cmap("Reds")
purples = plt.get_cmap("Purples")
mDark = 20
if is_qcd:
    print("QCD events")
    mDark = 0
to_plot = {}  # training dataset -> {} (only the top level is populated in this section)
to_plot_steps = {}  # training dataset (no step) -> mMed -> rInv -> level -> ckpt step -> F1 score
to_plot_v2 = {}  # level index -> mMed -> rInv -> model -> [precision, recall]
quark_to_jet = {}  # level index -> mMed -> rInv -> model -> quark-to-jet assignment
mc_gt_pt_of_dq = {}  # level index -> mMed -> rInv -> model -> flattened list
pt_of_dq = {}  # level index -> mMed -> rInv -> model -> flattened list
props_of_dq = {"eta": {}, "phi": {}}  # Properties of dark quarks: eta and phi, each level index -> mMed -> rInv -> model
results_all = {}  # training dataset -> mMed -> mDark -> rInv -> level -> GT R -> F1 score
results_all_ak = {}
jet_properties = {}  # training dataset (no step) -> mMed -> rInv -> level -> ckpt step -> jet property dict
jet_properties_ak = {}  # not used in this section
plotting_hypotheses = [[700, 0.7], [700, 0.5], [700, 0.3], [900, 0.3], [900, 0.7]]
if is_qcd:
    plotting_hypotheses = [[0, 0]]
sz_small = 5

# Eval runs excluded from the comparison plots.
excluded_models = {
    "Eval_eval_19March2025_pt1e-2_500particles_FT_PL_2025_04_02_14_28_33_421FT",
    "Eval_eval_19March2025_pt1e-2_500particles_FT_PL_2025_04_02_14_47_23_671FT",
    "Eval_eval_19March2025_small_aug_vanishing_momentum_2025_03_28_11_45_16_582",
    "Eval_eval_19March2025_small_aug_vanishing_momentum_2025_03_28_11_46_26_326",
}
for j, model in enumerate(models):
    _, rc = get_run_config(model)
    if rc is None or model in excluded_models:
        print("Skipping", model)
        continue
    td = rc["training_dataset"]
    td_raw = rc["training_dataset_nostep"]
    level = rc["level"]
    r = rc["GT_R"]
    output_path = os.path.join(path, model, "count_matched_quarks")
    if not os.path.exists(os.path.join(output_path, "result_PR.pkl")):
        continue
    result_PR = _load_pickle(os.path.join(output_path, "result_PR.pkl"))
    result_QJ = _load_pickle(os.path.join(output_path, "result_quark_to_jet.pkl"))
    result_jet_props = _load_pickle(os.path.join(output_path, "result_jet_properties.pkl"))
    result_MC_PT = _load_pickle(os.path.join(output_path, "result_pt_mc_gt.pkl"))
    result_PT_DQ = _load_pickle(os.path.join(output_path, "result_pt_dq.pkl"))
    result_DQ_props = _load_pickle(os.path.join(output_path, "result_props_dq.pkl"))
    print(level)
    to_plot.setdefault(td, {})
    results_all.setdefault(td, {})
    to_plot_steps.setdefault(td_raw, {})
    jet_properties.setdefault(td_raw, {})
    level_idx = ["PL", "scouting", "GL"].index(level)
    # Containers keyed level index -> mMed -> rInv -> model (training dataset without step).
    per_level_containers = [to_plot_v2, quark_to_jet, pt_of_dq, mc_gt_pt_of_dq, *props_of_dq.values()]
    for container in per_level_containers:
        container.setdefault(level_idx, {})
    for mMed_h in result_PR:
        for container in per_level_containers:
            _nested_dict(container, level_idx, mMed_h)
        _nested_dict(to_plot_steps, td_raw, mMed_h)
        _nested_dict(jet_properties, td_raw, mMed_h)
        _nested_dict(results_all, td, mMed_h, mDark)
        for rInv_h in result_PR[mMed_h][mDark]:
            for container in per_level_containers:
                _nested_dict(container, level_idx, mMed_h, rInv_h)
            steps_slot = _nested_dict(to_plot_steps, td_raw, mMed_h, rInv_h, level)
            jet_props_slot = _nested_dict(jet_properties, td_raw, mMed_h, rInv_h, level)
            f1_by_R = _nested_dict(results_all, td, mMed_h, mDark, rInv_h, level)
            precision = result_PR[mMed_h][mDark][rInv_h][0]
            recall = result_PR[mMed_h][mDark][rInv_h][1]
            f1score = _f1_score(precision, recall)
            to_plot_v2[level_idx][mMed_h][rInv_h][td_raw] = [precision, recall]
            quark_to_jet[level_idx][mMed_h][rInv_h][td_raw] = result_QJ[mMed_h][mDark][rInv_h]
            pt_of_dq[level_idx][mMed_h][rInv_h][td_raw] = flatten_list(result_PT_DQ[mMed_h][mDark][rInv_h])
            mc_gt_pt_of_dq[level_idx][mMed_h][rInv_h][td_raw] = flatten_list(result_MC_PT[mMed_h][mDark][rInv_h])
            for prop in props_of_dq:
                props_of_dq[prop][level_idx][mMed_h][rInv_h][td_raw] = flatten_list(
                    result_DQ_props[prop][mMed_h][mDark][rInv_h])
            # Keep the first F1 score seen for this GT radius.
            f1_by_R.setdefault(r, f1score)
            ckpt_step = rc["ckpt_step"]
            steps_slot[ckpt_step] = f1score
            jet_props_slot[ckpt_step] = result_jet_props[mMed_h][mDark][rInv_h]

# Union of all mediator masses / invisible fractions seen by any model, sorted.
m_Meds = sorted({mMed for per_td in to_plot_steps.values() for mMed in per_td})
r_invs = sorted({rInv for per_td in to_plot_steps.values() for per_mMed in per_td.values() for rInv in per_mMed})

# AK8 baseline (R=0.8) for every stored quantity and level.
result_AKX_current = select_radius(result_PR_AKX, 0.8)
result_AKX_PL = select_radius(result_PR_AKX_PL, 0.8)
result_AKX_GL = select_radius(result_PR_AKX_GL, 0.8)
result_AKX_jet_properties = select_radius(result_jet_props_akx, 0.8)
jet_properties["AK8"] = {}
result_AKX_current_QJ = select_radius(result_qj_akx, 0.8)
result_AKX_PL_QJ = select_radius(result_qj_akx_PL, 0.8)
result_AKX_GL_QJ = select_radius(result_qj_akx_GL, 0.8)
result_AKX_current_pt_dq = select_radius(result_dq_pt_akx, 0.8)
result_AKX_PL_pt_dq = select_radius(result_dq_pt_akx_PL, 0.8)
result_AKX_GL_pt_dq = select_radius(result_dq_pt_akx_GL, 0.8)
result_AKX_current_MCpt_dq = select_radius(result_dq_mc_pt_akx, 0.8)
result_AKX_PL_MCpt_dq = select_radius(result_dq_mc_pt_akx_PL, 0.8)
result_AKX_GL_MCpt_dq = select_radius(result_dq_mc_pt_akx_GL, 0.8)
result_AKX_current_props_dq = select_radius(result_dq_props_akx, 0.8, depth=4)
result_AKX_PL_props_dq = select_radius(result_dq_props_akx_PL, 0.8, depth=4)
result_AKX_GL_props_dq = select_radius(result_dq_props_akx_GL, 0.8, depth=4)

# Per level index: (precision/recall, quark-to-jet, DQ pt, MC GT pt, DQ eta/phi).
ak8_by_level = {
    0: (result_AKX_PL, result_AKX_PL_QJ, result_AKX_PL_pt_dq, result_AKX_PL_MCpt_dq, result_AKX_PL_props_dq),
    1: (result_AKX_current, result_AKX_current_QJ, result_AKX_current_pt_dq, result_AKX_current_MCpt_dq,
        result_AKX_current_props_dq),
    2: (result_AKX_GL, result_AKX_GL_QJ, result_AKX_GL_pt_dq, result_AKX_GL_MCpt_dq, result_AKX_GL_props_dq),
}
for mMed_h in result_AKX_jet_properties:
    for rInv_h in result_AKX_jet_properties[mMed_h][mDark]:
        for ak_level_idx, (ak_pr, ak_qj, ak_pt_dq, ak_mc_pt_dq, ak_props_dq) in ak8_by_level.items():
            # Only add AK8 at levels for which at least one model was evaluated.
            if ak_level_idx not in to_plot_v2:
                continue
            _nested_dict(to_plot_v2, ak_level_idx, mMed_h, rInv_h)["AK8"] = ak_pr[mMed_h][mDark][rInv_h]
            _nested_dict(quark_to_jet, ak_level_idx, mMed_h, rInv_h)["AK8"] = ak_qj[mMed_h][mDark][rInv_h]
            _nested_dict(pt_of_dq, ak_level_idx, mMed_h, rInv_h)["AK8"] = flatten_list(ak_pt_dq[mMed_h][mDark][rInv_h])
            _nested_dict(mc_gt_pt_of_dq, ak_level_idx, mMed_h, rInv_h)["AK8"] = flatten_list(
                ak_mc_pt_dq[mMed_h][mDark][rInv_h])
            for prop in props_of_dq:
                _nested_dict(props_of_dq[prop], ak_level_idx, mMed_h, rInv_h)["AK8"] = flatten_list(
                    ak_props_dq[prop][mMed_h][mDark][rInv_h])
        # AK8 has no training steps; store it under step 50000, the step histograms_dict requests for "AK8".
        _nested_dict(jet_properties, "AK8", mMed_h)[rInv_h] = {
            "scouting": {50000: result_AKX_jet_properties[mMed_h][mDark][rInv_h]}
        }
# Short names for the results_dict groups compared in the per-level bar plots.
rename_results_dict = {
    "LGATr_comparison_DifferentTrainingDS": "base",
    "LGATr_comparison_GP_training": "GP",
    "LGATr_comparison_GP_IRC_S_training": "GP_IRC_S",
    "LGATr_comparison_GP_IRC_SN_training": "GP_IRC_SN",
}
# (mMed, rInv) hypotheses for the Venn diagrams; (0, 0) is titled as the QCD sample.
hypotheses_to_plot = [[0, 0], [700, 0.7], [700, 0.5], [700, 0.3]]


def powerset(iterable):
    "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))


def get_label_from_superset(lbl, labels_rename, labels):
    """Return the legend label for a Venn region.

    lbl is a string of indices into ``labels`` (e.g. "02") naming the labels that found
    a quark; "" means none did. Names are shortened through ``labels_rename``.
    """
    if lbl == '':
        return "Missed by all"
    names = [labels_rename.get(labels[int(idx)], labels[int(idx)]) for idx in lbl]
    if len(names) == 2 and "QCD" in names and "900_03" in names:
        return "Found by both models but not AK"
    if len(names) == 3:
        return "Found by all"
    return ", ".join(names)


results_dir = get_path(args.input, "results")

# Labels compared in the Venn diagrams; a label's position is the digit used in region keys.
venn_labels = ["LGATr_GP_IRC_S_QCD", "AK8", "LGATr_GP_IRC_S_50k"]
venn_labels_rename = {"LGATr_GP_IRC_S_QCD": "QCD", "LGATr_GP_IRC_S_50k": "900_03"}
# Dark-quark quantities histogrammed below the Venn diagrams (rows 1-5).
venn_stat_keys = ["pt_dq", "pt_mc_t", "pt_mc_t_dq_ratio", "eta", "phi"]
venn_bins = {
    "pt_dq": np.linspace(90, 250, 50),
    "pt_mc_t": np.linspace(0, 200, 50),
    "pt_mc_t_dq_ratio": np.linspace(0, 1.3, 30),
    "eta": np.linspace(-4, 4, 20),
    "phi": np.linspace(-np.pi, np.pi, 20),
}
# One colour per Venn region (8 regions for 3 labels), indexed by sorted region key.
venn_colors = ["green", "red", "orange", "pink", "blue", "purple", "cyan", "magenta"]
key_rename_dict = {
    "pt_dq": "$p_T$ of quark",
    "pt_mc_t": "$p_T$ of particles within radius of R=0.8 of quark",
    "pt_mc_t_dq_ratio": "$p_T$ (part. within R=0.8 of quark) / $p_T$ (quark) ",
    "eta": r"$\eta$ of quark",
    "phi": r"$\phi$ of quark",
}
# Regions whose kinematic distributions are drawn.
shown_regions = ["Missed by all", "Found by both models but not AK", "AK8", "Found by all"]
sz_small1 = 2.5

for hyp_m, hyp_rinv in hypotheses_to_plot:
    if 0 not in to_plot_v2:
        continue  # Not for the lower-pt thresholds, where only GL and PL are available
    if hyp_m not in to_plot_v2[0] or hyp_rinv not in to_plot_v2[0][hyp_m]:
        continue
    # Row 0: Venn diagrams per level; rows 1-5: pt of the DQ, pt of the MC GT,
    # pt of MC GT / pt of DQ, eta and phi distributions.
    fig_venn, ax_venn = plt.subplots(6, 3, figsize=(5 * 3, 5 * 6))
    # Only the PFCands level, with full histogram on the left and density on the right.
    fig_venn1, ax_venn1 = plt.subplots(6, 2, figsize=(5 * 2, 5 * 6))
    for level in range(3):
        # Region keys: every subset of label indices as a sorted digit string, e.g. "", "0", "01", "012".
        powerset_str = ["".join(str(x) for x in sorted(a)) for a in powerset(range(len(venn_labels)))]
        set_to_count = {key: 0 for key in powerset_str}
        set_to_stats = {key: {stat: [] for stat in venn_stat_keys} for key in powerset_str}
        quark_lists = [flatten_list(quark_to_jet[level][hyp_m][hyp_rinv][label]) for label in venn_labels]
        # Labels can differ in length; compare only the common prefix.
        n_dq = min(len(q) for q in quark_lists)
        # Whether quark no. X is caught (assigned to a jet, i.e. index != -1) or not.
        label_to_result = {j: (torch.tensor(q) != -1).tolist()[:n_dq] for j, q in enumerate(quark_lists)}
        # Record each quark's kinematics once, taken from a reference label.
        # NOTE(review): assumes index c is the same quark for every label, as the Venn counting relies on.
        ref_label = venn_labels[0]
        dq_pt = pt_of_dq[level][hyp_m][hyp_rinv][ref_label]
        mc_gt_pt = mc_gt_pt_of_dq[level][hyp_m][hyp_rinv][ref_label]
        dq_eta = props_of_dq["eta"][level][hyp_m][hyp_rinv][ref_label]
        dq_phi = props_of_dq["phi"][level][hyp_m][hyp_rinv][ref_label]
        for c in tqdm(range(n_dq)):
            region = "".join(str(j) for j in range(len(venn_labels)) if label_to_result[j][c] == 1)
            set_to_count[region] += 1
            stats = set_to_stats[region]
            stats["pt_dq"].append(dq_pt[c])
            stats["pt_mc_t"].append(mc_gt_pt[c])
            stats["pt_mc_t_dq_ratio"].append(mc_gt_pt[c] / dq_pt[c])
            stats["eta"].append(dq_eta[c])
            stats["phi"].append(dq_phi[c])
        if hyp_m == 0 and hyp_rinv == 0:
            title = f"QCD, {text_level[level]} (missed by all: {set_to_count['']})"
        else:
            title = (f"$m_{{Z'}}={hyp_m}$ GeV, $r_{{inv.}}={hyp_rinv}$, {text_level[level]} "
                     f"(missed by all: {set_to_count['']}) ")
        ax_venn[0, level].set_title(title)
        venn_set_labels = [venn_labels_rename.get(l, l) for l in venn_labels]
        plot_venn3_from_index_dict(ax_venn[0, level], set_to_count, set_labels=venn_set_labels,
                                   set_colors=["orange", "gray", "red"])
        if level == 1:  # reco level
            plot_venn3_from_index_dict(ax_venn1[0, 1], set_to_count, set_labels=venn_set_labels,
                                       set_colors=["orange", "gray", "red"])
        for k, key in enumerate(venn_stat_keys):
            for s_idx, s in enumerate(sorted(set_to_stats)):
                values = set_to_stats[s][key]
                if len(values) == 0:
                    continue
                region_label = get_label_from_superset(s, venn_labels_rename, venn_labels)
                if region_label not in shown_regions:
                    continue
                if level == 1:
                    ax_venn1[k + 1, 1].hist(values, bins=venn_bins[key], histtype="step", label=region_label,
                                            color=venn_colors[s_idx], density=True)
                    ax_venn1[k + 1, 0].set_title(key_rename_dict[key])
                    ax_venn1[k + 1, 1].set_title(key_rename_dict[key])
                    ax_venn1[k + 1, 1].set_ylabel("Density")
                if s not in ["", "012"]:  # We are only interested in the differences...
                    ax_venn[k + 1, level].hist(values, bins=venn_bins[key], histtype="step", label=region_label,
                                               color=venn_colors[s_idx])
                    ax_venn[k + 1, level].set_title(key_rename_dict[key])
                    if level == 1:
                        ax_venn1[k + 1, 0].hist(values, bins=venn_bins[key], histtype="step",
                                                label=region_label, color=venn_colors[s_idx])
        for k in range(5):
            ax_venn[k + 1, level].legend()
    fig_venn.tight_layout()
    for k in range(5):
        ax_venn1[k + 1, 0].legend()
        ax_venn1[k + 1, 1].legend()
    fig_venn1.tight_layout()
    fig_venn.savefig(os.path.join(results_dir, f"venn_diagram_{hyp_m}_{hyp_rinv}.pdf"))
    fig_venn1.savefig(os.path.join(results_dir, f"venn_diagram_{hyp_m}_{hyp_rinv}_reco_level_only.pdf"))
    plt.close(fig_venn)
    plt.close(fig_venn1)

    # Bar plots of one metric per results_dict group (rows) and level (columns).
    for i, lbl in enumerate(["precision", "recall", "F1"]):  # 0=precision, 1=recall, 2=F1
        fig, ax = plt.subplots(len(rename_results_dict), 3,
                               figsize=(sz_small1 * 3, sz_small1 * len(rename_results_dict)), squeeze=False)
        for i1, key in enumerate(rename_results_dict):
            for level in range(3):
                level_text = text_level[level]
                labels = list(results_dict[key][0].keys())
                colors = [results_dict[key][0][l] for l in labels]
                res_precision = np.array([to_plot_v2[level][hyp_m][hyp_rinv][l][0] for l in labels])
                res_recall = np.array([to_plot_v2[level][hyp_m][hyp_rinv][l][1] for l in labels])
                res_f1 = 2 * res_precision * res_recall / (res_precision + res_recall)
                values = (res_precision, res_recall, res_f1)[i]
                rename_dict = results_dict[key][1]
                labels_renamed = [rename_dict.get(l, l) for l in labels]
                print(i1, level)
                ax_tiny_histogram(ax[i1, level], labels_renamed, colors, values)
                ax[i1, level].set_title(f"{rename_results_dict[key]} {level_text}")
        fig.tight_layout()
        # NOTE(review): the file name carries `key`, i.e. only the last rename_results_dict group,
        # although the figure holds every group - confirm the intended name.
        fig.savefig(os.path.join(results_dir, f"{lbl}_results_by_level_{hyp_m}_{hyp_rinv}_{key}.pdf"))
        plt.close(fig)

# Agreement between levels: per label, which of PL / PFCands / GL find each quark.
for hyp_m, hyp_rinv in hypotheses_to_plot:
    if 0 not in to_plot_v2:
        continue  # Not for the lower-pt thresholds, where only GL and PL are available
    if hyp_m not in to_plot_v2[0] or hyp_rinv not in to_plot_v2[0][hyp_m]:
        continue
    fig_venn2, ax_venn2 = plt.subplots(1, len(venn_labels), figsize=(4 * len(venn_labels), 4))
    level_keys = ["".join(str(x) for x in sorted(a)) for a in powerset(range(3))]
    for j, label in enumerate(venn_labels):
        set_to_count = {key: 0 for key in level_keys}
        quark_lists = [flatten_list(quark_to_jet[lvl][hyp_m][hyp_rinv][label]) for lvl in range(3)]
        # Sometimes, the last batch gets cut off etc., so compare only the common prefix.
        n_dq = min(len(q) for q in quark_lists)
        found_by_level = [(torch.tensor(q) != -1).tolist()[:n_dq] for q in quark_lists]
        for c in tqdm(range(n_dq)):
            region = "".join(str(lvl) for lvl in range(3) if found_by_level[lvl][c] == 1)
            set_to_count[region] += 1
        if hyp_m == 0 and hyp_rinv == 0:
            title = f"QCD, {label} (missed by all: {set_to_count['']}) "
        else:
            title = f"$m_{{Z'}}={hyp_m}$ GeV, $r_{{inv.}}={hyp_rinv}$, {label} (miss: {set_to_count['']}) "
        ax_venn2[j].set_title(title)
        plot_venn3_from_index_dict(ax_venn2[j], set_to_count, set_labels=text_level,
                                   set_colors=["orange", "gray", "red"], remove_max=1)
    fig_venn2.tight_layout()
    fig_venn2.savefig(os.path.join(results_dir, f"venn_diagram_{hyp_m}_{hyp_rinv}_Agreement_between_levels.pdf"))
    plt.close(fig_venn2)

# F1 grid stacks per results_dict group and level.
for key in results_dict:
    for level in range(3):
        if level not in to_plot_v2:
            continue
        level_text = text_level[level]
        labels = list(results_dict[key][0].keys())
        f, a = multiple_matrix_plot(to_plot_v2[level], labels=labels,
                                    colors=[results_dict[key][0][l] for l in labels],
                                    rename_dict=results_dict[key][1])
        if f is None:
            print("No figure for", key, level)
            continue
        out_file = os.path.join(results_dir, f"grid_stack_F1_{level_text}_{key}.pdf")
        f.savefig(out_file)
        print("Saved to", out_file)
        plt.close(f)

from matplotlib.lines import Line2D

# Legend handles for the score-vs-step plot: colour = model family, line style = level.
custom_lines = [
    Line2D([0], [0], color='orange', linestyle='-', label='LGATr'),
    Line2D([0], [0], color='green', linestyle='-', label='GATr'),
    Line2D([0], [0], color='blue', linestyle='-', label='Transformer'),
    Line2D([0], [0], color='gray', linestyle='-', label='AK8'),
    Line2D([0], [0], color='black', linestyle='-', label='reco'),
    Line2D([0], [0], color='black', linestyle=':', label='gen'),
    Line2D([0], [0], color='black', linestyle='--', label='parton'),
]

if len(models):
    grid_shape = (len(m_Meds), len(r_invs))
    grid_figsize = (sz_small * len(r_invs), sz_small * len(m_Meds))
    # squeeze=False keeps 2-D axes arrays even when m_Meds or r_invs has a single entry.
    fig_steps, ax_steps = plt.subplots(*grid_shape, figsize=grid_figsize, squeeze=False)
    histograms = {}  # subset -> "pt" / "eta" / "phi" -> (figure, 2-D axes array)
    for key in histograms_dict:
        histograms[key] = {}
        for quantity_name in ["pt", "eta", "phi"]:
            histograms[key][quantity_name] = plt.subplots(*grid_shape, figsize=grid_figsize, squeeze=False)
    colors = {"base_LGATr": "orange", "base_Tr": "blue", "base_GATr": "green", "AK8": "gray"}  # THE COLORS FOR THE STEP VS. F1 SCORE
    level_styles = {"scouting": "solid", "PL": "dashed", "GL": "dotted"}
    level_to_plot_histograms = "scouting"  # phi, eta, pt histograms are drawn for this level only
    # Display names in the histogram legends.
    rename = {"base_LGATr": "LGATr", "LGATr_GP_IRC_S_50k": "LGATr_GP_IRC_S", "AK8": "AK8",
              "LGATr_GP_50k": "LGATr_GP"}
    for i, mMed_h in enumerate(m_Meds):
        for j, rInv_h in enumerate(r_invs):
            ax_steps[i, j].set_title("$m_{{Z'}} = {}$ GeV, $r_{{inv.}} = {}$".format(mMed_h, rInv_h))
            ax_steps[i, j].set_xlabel("Training step")
            ax_steps[i, j].set_ylabel("Test $F_1$ score")
            if i == len(m_Meds) - 1:
                ax_steps[i, j].set_xlabel("$r_{{inv.}} = {}$".format(rInv_h))
                for subset in histograms:
                    for key in histograms[subset]:
                        histograms[subset][key][1][i, j].set_xlabel("$r_{{inv.}} = {}$".format(rInv_h))
            # Residual / response histograms of predicted vs. gen-particle jets.
            for model in jet_properties:
                props_at_point = jet_properties[model][mMed_h][rInv_h]
                if level_to_plot_histograms not in props_at_point:
                    print("Skipping", model, level_to_plot_histograms, " - levels:", props_at_point.keys())
                    continue
                props_by_step = props_at_point[level_to_plot_histograms]
                for subset in histograms:
                    for key in histograms[subset]:
                        if model not in histograms_dict[subset][1]:
                            continue
                        step_to_plot_histograms = histograms_dict[subset][0][model]
                        if step_to_plot_histograms not in props_by_step:
                            print("Swapping the step to plot histograms", props_by_step.keys())
                            step_to_plot_histograms = sorted(props_by_step)[0]
                        pred = np.array(props_by_step[step_to_plot_histograms][key + "_pred"])
                        truth = np.array(props_by_step[step_to_plot_histograms][key + "_gen_particle"])
                        if key.startswith("pt"):
                            q = pred / truth  # ratio instead of difference for pt
                            quantity = "p_{T,pred}/p_{T,true}"
                            bins = np.linspace(0, 2.5, 100)
                        elif key.startswith("eta"):
                            q = pred - truth
                            quantity = r"\eta_{pred}-\eta_{true}"
                            bins = np.linspace(-0.8, 0.8, 50)
                        elif key.startswith("phi"):
                            q = pred - truth
                            quantity = r"\phi_{pred}-\phi_{true}"
                            # Wrap the angular difference back into [-pi, pi].
                            q[q > np.pi] -= 2 * np.pi
                            q[q < -np.pi] += 2 * np.pi
                            bins = np.linspace(-0.8, 0.8, 50)
                        print("Max", np.max(q), "Min", np.min(q))
                        hist_ax = histograms[subset][key][1][i, j]
                        hist_ax.hist(q, histtype="step", color=histograms_dict[subset][1][model],
                                     label=rename.get(model, model), bins=bins, density=True)
                        if mMed_h > 0:
                            hist_ax.set_title(f"${quantity}$ $m_{{Z'}}={mMed_h}$ GeV, $r_{{inv.}}={rInv_h}$")
                        else:
                            hist_ax.set_title(f"${quantity}$")
                        hist_ax.legend()
                        hist_ax.grid(True)
            # F1 score vs. training step, one line style per level.
            for model in to_plot_steps:
                for lvl in to_plot_steps[model][mMed_h][rInv_h]:
                    if model not in colors:
                        print("Skipping", model)
                        continue
                    print(model)
                    ls = level_styles[lvl]
                    plt_dict = to_plot_steps[model][mMed_h][rInv_h][lvl]
                    x_pts = sorted(plt_dict)
                    y_pts = [plt_dict[step] for step in x_pts]
                    # Only the solid (reco-level) curve gets a model legend entry.
                    label_kwargs = {"label": model} if ls == "solid" else {}
                    ax_steps[i, j].plot(x_pts, y_pts, marker=".", linestyle=ls, color=colors[model], **label_kwargs)
                    ax_steps[i, j].legend(handles=custom_lines)
                    # now plot a horizontal line for the AKX same level
                    if lvl == "scouting":
                        ak_result = result_AKX_current
                    elif lvl == "PL":
                        ak_result = result_AKX_PL
                    elif lvl == "GL":
                        ak_result = result_AKX_GL
                    else:
                        raise ValueError(f"Unknown level {lvl!r}")
                    pr = ak_result[mMed_h][mDark][rInv_h][0]
                    rec = ak_result[mMed_h][mDark][rInv_h][1]
                    f1ak = 2 * pr * rec / (pr + rec) if pr + rec else 0.0
                    ax_steps[i, j].axhline(f1ak, color="gray", linestyle=ls, alpha=0.5)
            ax_steps[i, j].grid(1)
    path_steps_fig = os.path.join(results_dir, "score_vs_step_plots.pdf")
    fig_steps.tight_layout()
    fig_steps.savefig(path_steps_fig)
    plt.close(fig_steps)
    for subset in histograms:
        for key in histograms[subset]:
            fig = histograms[subset][key][0]
            fig.tight_layout()
            fig.savefig(os.path.join(results_dir, "histogram_{}_{}.pdf".format(key, subset)))
            plt.close(fig)
    print("Saved to", path_steps_fig)

to_plot_ak = {}  # level ("scouting"/"GL"/"PL") -> rInv -> mMed -> {"precision", "recall", "f1score", "R"}
for j, model in enumerate(["AKX", "AKX_PL", "AKX_GL"]):
    print(model)
    ak_pkl = os.path.join(path, model, "count_matched_quarks", "result_PR_AKX.pkl")
    if not os.path.exists(ak_pkl):
        print("Skipping", model)
        continue
    # NOTE(review): this rebinds the module-level result_PR_AKX used by earlier select_radius calls.
    # pickle.load runs arbitrary code - only load trusted eval outputs.
    with open(ak_pkl, "rb") as fh:
        result_PR_AKX = pickle.load(fh)
    if "PL" in model:
        level = "PL"
    elif "GL" in model:
        level = "GL"
    else:
        level = "scouting"
    to_plot_ak.setdefault(level, {})
    # F1 score per AK radius: results_all_ak[mMed][mDark][rInv][level][R].
    for mMed_h in result_PR_AKX:
        results_all_ak.setdefault(mMed_h, {mDark: {}})
        for rInv_h in result_PR_AKX[mMed_h][mDark]:
            per_radius = result_PR_AKX[mMed_h][mDark][rInv_h]
            f1_by_radius = results_all_ak[mMed_h][mDark].setdefault(rInv_h, {}).setdefault(level, {})
            for R in per_radius:
                if R not in f1_by_radius:
                    precision = per_radius[R][0]
                    recall = per_radius[R][1]
                    f1_by_radius[R] = 2 * precision * recall / (precision + recall)
    for mMed_h, rInv_h in plotting_hypotheses:
        point = to_plot_ak[level].setdefault(rInv_h, {}).setdefault(
            mMed_h, {"precision": [], "recall": [], "f1score": [], "R": []})
        print("Model", model)
        per_radius = result_PR_AKX[mMed_h][mDark][rInv_h]
        rs = sorted(per_radius.keys())
        precision = np.array([per_radius[radius_value][0] for radius_value in rs])
        recall = np.array([per_radius[radius_value][1] for radius_value in rs])
        f1score = 2 * precision * recall / (precision + recall)
        point["precision"] = precision
        point["recall"] = recall
        point["f1score"] = f1score
        point["R"] = rs
print("AK:", to_plot_ak)

# One panel per (training dataset, radius), then one per AK radius.
n_panels = (len(results_all) + 1) * len(radius)
fig, ax = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5), squeeze=False)
ax = ax[0]
for i, model in enumerate(results_all):
    for j, R in enumerate(radius):
        index = len(radius) * i + j
        ax[index].set_title(model + " R={}".format(R))
        # NOTE(review): results_all levels are keyed by rc["level"] ("PL"/"scouting"/"GL"), so these
        # "+ghosts" keys look stale; also the label says PL/GL while the ratio is PL/scouting - confirm.
        matrix_plot(results_all[model], "Greens", r"PL/GL F1 score", ax=ax[index],
                    metric_comp_func=lambda r, R=R: r["PL+ghosts"][R] / r["scouting+ghosts"][R])
for i, R in enumerate(radius):
    index = len(radius) * len(results_all) + i
    ax[index].set_title("AK R={}".format(R))
    matrix_plot(results_all_ak, "Greens", r"PL/GL F1 score", ax=ax[index],
                metric_comp_func=lambda r, R=R: r["PL"][R] / r["GL"][R])
fig.tight_layout()
fig.savefig(out_file_PG)
print("Saved to", out_file_PG)
1/0  # NOTE(review): deliberate hard stop - nothing after this line runs.
#print("Saved to", os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_AK_ratio.pdf"))

# NOTE(review): if the `1/0` debug stop earlier in this script is still present
# and reached, everything below is unreachable - confirm whether it is intended.


def _load_pickle(file_path):
    """Load and return the single pickled object stored at ``file_path``.

    The file handle is closed even if unpickling fails.
    NOTE(review): unpickling can execute arbitrary code - only use this on
    result files produced by this project, never on untrusted input.
    """
    with open(file_path, "rb") as handle:
        return pickle.load(handle)


def _f1_score(precision, recall):
    """Return the F1 score ``2 * P * R / (P + R)``, defined as 0 where ``P + R == 0``.

    Accepts scalars or array-likes of matching/broadcastable shape. A scalar
    input yields a Python float; an array input yields a float ndarray.
    The plain formula raised ZeroDivisionError (scalars) or produced NaN
    (arrays) for configurations with no correct matches at all.
    """
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    denominator = precision + recall
    f1 = np.divide(2 * precision * recall, denominator,
                   out=np.zeros_like(denominator), where=denominator != 0)
    return float(f1) if f1.ndim == 0 else f1


def _mass_to_unit_interval(m_med, low=500, high=3000):
    """Map a mediator mass ``m_med`` from [low, high] onto [0, 1] (used as a colormap position)."""
    return (m_med - low) / (high - low)


########### Now save the above plot with objectness score applied
if args.threshold_obj_score != -1:

    def wrap(r):
        # Compute [precision, recall] from [n_relevant_retrieved, all_retrieved, all_relevant];
        # both are reported as 0 when either denominator is 0.
        if r[1] == 0 or r[2] == 0:
            return [0, 0]
        return [r[0] / r[1], r[0] / r[2]]

    # squeeze=False keeps `ax` two-dimensional (metric row x model column) even
    # when there is a single model, so ax[row, i] indexing always works.
    fig, ax = plt.subplots(3, len(models), figsize=(sz * len(models), sz * 3), squeeze=False)
    for i, model in tqdm(enumerate(models)):
        output_path = os.path.join(path, model, "count_matched_quarks")
        if not os.path.exists(os.path.join(output_path, "result.pkl")):
            continue
        result = _load_pickle(os.path.join(output_path, "result.pkl"))
        #result_unmatched = _load_pickle(os.path.join(output_path, "result_unmatched.pkl"))
        result_fakes = _load_pickle(os.path.join(output_path, "result_fakes.pkl"))
        result_bc = _load_pickle(os.path.join(output_path, "result_bc.pkl"))
        result_PR = _load_pickle(os.path.join(output_path, "result_PR.pkl"))
        result_PR_thresholds = _load_pickle(os.path.join(output_path, "result_PR_thresholds.pkl"))
        # NOTE(review): `thresholds` is not computed in this section (its
        # computation below is commented out), so it must already be defined
        # earlier in the script as a numpy array of objectness thresholds -
        # confirm, otherwise the next statement raises NameError.
        #thresholds = np.array(sorted(list(result_PR_thresholds[900][20][0.3].keys())))
        # Use the threshold closest to the requested objectness score (nearest
        # neighbour, no interpolation).
        j = np.argmin(np.abs(thresholds - args.threshold_obj_score))
        print("Thresholds", thresholds)
        print("Chose j=", j, "for threshold", args.threshold_obj_score, "(effectively it's", thresholds[j], ")")
        matrix_plot(result_PR_thresholds, "Oranges", "Precision (N matched dark quarks / N predicted jets)",
                    metric_comp_func=lambda r: wrap(r[j])[0], ax=ax[0, i])
        matrix_plot(result_PR_thresholds, "Reds", "Recall (N matched dark quarks / N dark quarks)",
                    metric_comp_func=lambda r: wrap(r[j])[1], ax=ax[1, i])
        matrix_plot(result_PR_thresholds, "Purples", r"$F_1$ score",
                    metric_comp_func=lambda r: _f1_score(*wrap(r[j])), ax=ax[2, i])
        for row in range(3):
            ax[row, i].set_title(model)
    fig.tight_layout()
    fig.savefig(out_file_PR_OS)
    print("Saved to", out_file_PR_OS)

################
# UNUSED PLOTS #
################
'''fig, ax = plt.subplots(2, len(models), figsize=(sz * len(models), sz * 2))
for i, model in tqdm(enumerate(models)):
    output_path = os.path.join(path, model, "count_matched_quarks")
    if not os.path.exists(os.path.join(output_path, "result.pkl")):
        continue
    result = pickle.load(open(os.path.join(output_path, "result.pkl"), "rb"))
    #result_unmatched = pickle.load(open(os.path.join(output_path, "result_unmatched.pkl"), "rb"))
    result_fakes = pickle.load(open(os.path.join(output_path, "result_fakes.pkl"), "rb"))
    result_bc = pickle.load(open(os.path.join(output_path, "result_bc.pkl"), "rb"))
    result_PR = pickle.load(open(os.path.join(output_path, "result_PR.pkl"), "rb"))
    matrix_plot(result, "Blues", "Avg. matched dark quarks / event", ax=ax[0, i])
    matrix_plot(result_fakes, "Greens", "Avg. unmatched jets / event", ax=ax[1, i])
    ax[0, i].set_title(model)
    ax[1, i].set_title(model)
fig.tight_layout()
fig.savefig(out_file_avg_number_matched_quarks)
print("Saved to", out_file_avg_number_matched_quarks)'''

### Plotting the score vs GT R plots: for each r_inv, precision / recall / F1
### versus the GT radius R, one curve per mediator mass.
rinvs = [0.3, 0.5, 0.7]
sz = 4
# Rows are the three metrics, columns are the r_inv values (indexed as [metric, rinv]).
fig, ax = plt.subplots(3, len(rinvs), figsize=(sz * len(rinvs), 3 * sz), squeeze=False)
fig_AK, ax_AK = plt.subplots(3, len(rinvs), figsize=(sz * len(rinvs), 3 * sz), squeeze=False)
fig_AK_ratio, ax_AK_ratio = plt.subplots(3, len(rinvs), figsize=(sz * len(rinvs), 3 * sz), squeeze=False)
to_plot = {}  # r_inv -> m_med -> {"precision", "recall", "f1score", "R"}, one point per model
to_plot_ak = {}  # same layout for the AK baseline, one point per AK radius
oranges = plt.get_cmap("Oranges")
reds = plt.get_cmap("Reds")
purples = plt.get_cmap("Purples")
# Row index -> (metric key, colormap) shared by all three figures.
metric_rows = (("precision", oranges), ("recall", reds), ("f1score", purples))
mDark = 20
score_plots_dir = get_path(args.input, "results")

# The AK baseline file does not depend on r_inv, so load it once up front.
ak_pr_file = os.path.join(ak_path, "result_PR_AKX.pkl")
have_ak = os.path.exists(ak_pr_file)
if have_ak:
    result_PR_AKX = _load_pickle(ak_pr_file)
    result_jet_props_akx = _load_pickle(os.path.join(ak_path, "result_jet_properties_AKX.pkl"))

for i, rinv in enumerate(rinvs):
    to_plot.setdefault(rinv, {})
    to_plot_ak.setdefault(rinv, {})

    # One (R, precision, recall, F1) point per trained model, for every mediator mass.
    for model in models:
        print("Model", model)
        if model not in radius:
            continue
        model_R = radius[model]
        output_path = os.path.join(path, model, "count_matched_quarks")
        model_pr_file = os.path.join(output_path, "result_PR.pkl")
        if not os.path.exists(model_pr_file):
            continue
        result_PR = _load_pickle(model_pr_file)
        for mMed in sorted(result_PR):
            precision = result_PR[mMed][mDark][rinv][0]
            recall = result_PR[mMed][mDark][rinv][1]
            entry = to_plot[rinv].setdefault(mMed, {"precision": [], "recall": [], "f1score": [], "R": []})
            entry["precision"].append(precision)
            entry["recall"].append(recall)
            entry["f1score"].append(_f1_score(precision, recall))
            entry["R"].append(model_R)

    for mMed in sorted(to_plot[rinv]):
        color_pos = _mass_to_unit_interval(mMed)  # m_med in [500, 3000] GeV -> [0, 1]
        series = to_plot[rinv][mMed]
        print("Model R", series["R"])
        for row, (metric, cmap) in enumerate(metric_rows):
            scatter_plot(ax[row, i], series["R"], series[metric],
                         label="m={} GeV".format(round(mMed)), color=cmap(color_pos))

    if have_ak:
        for mMed in sorted(result_PR_AKX):
            ak_results = result_PR_AKX[mMed][mDark][rinv]
            rs = sorted(ak_results)
            precision = np.array([ak_results[R][0] for R in rs])
            recall = np.array([ak_results[R][1] for R in rs])
            entry = to_plot_ak[rinv].setdefault(mMed, {"precision": [], "recall": [], "f1score": [], "R": []})
            entry["precision"] += list(precision)
            entry["recall"] += list(recall)
            entry["f1score"] += list(_f1_score(precision, recall))
            entry["R"] += rs

        for mMed in sorted(to_plot_ak[rinv]):
            color_pos = _mass_to_unit_interval(mMed)  # m_med in [500, 3000] GeV -> [0, 1]
            ak_series = to_plot_ak[rinv][mMed]
            print("AK R", ak_series["R"])
            for row, (metric, cmap) in enumerate(metric_rows):
                scatter_plot(ax_AK[row, i], ak_series["R"], ak_series[metric],
                             label="m={} GeV AK".format(round(mMed)), color=cmap(color_pos), pattern=".--")

            model_series = to_plot[rinv].get(mMed)
            if model_series is None:
                continue  # no model was evaluated at this mass, so there is nothing to normalise to
            # The AK scan covers more radii than the models. Pair each model point
            # with the AK point at the same R by value, not by position, so the
            # ratio stays correct when `models` is not ordered by radius or two
            # models share a radius. Model points without an AK counterpart are dropped.
            ak_index_by_R = {R: idx for idx, R in enumerate(ak_series["R"])}
            matched = sorted((R, idx) for idx, R in enumerate(model_series["R"]) if R in ak_index_by_R)
            ratio_R = [R for R, _ in matched]
            for row, (metric, cmap) in enumerate(metric_rows):
                ak_values = np.array([ak_series[metric][ak_index_by_R[R]] for R, _ in matched], dtype=float)
                model_values = np.array([model_series[metric][idx] for _, idx in matched], dtype=float)
                scatter_plot(ax_AK_ratio[row, i], ratio_R, ak_values / model_values,
                             label="m={} GeV AK".format(round(mMed)), color=cmap(color_pos), pattern=".--")

    for axes in (ax, ax_AK, ax_AK_ratio):
        for row, title in enumerate(("Precision", "Recall", "F1 score")):
            axes[row, i].set_title(f"{title} r_inv = {rinv}")
            axes[row, i].legend()
            axes[row, i].grid()
            axes[row, i].set_xlabel("GT R")
    for row, ylabel in enumerate(("Precision (model=1)", "Recall (model=1)", "F1 score (model=1)")):
        ax_AK_ratio[row, i].set_ylabel(ylabel)

for figure, file_name in ((fig, "score_vs_GT_R_plots.pdf"),
                          (fig_AK, "score_vs_GT_R_plots_AK.pdf"),
                          (fig_AK_ratio, "score_vs_GT_R_plots_AK_ratio.pdf")):
    figure.tight_layout()
    figure.savefig(os.path.join(score_plots_dir, file_name))
    print("Saved to", os.path.join(score_plots_dir, file_name))