Spaces:
Sleeping
Sleeping
import torch | |
from itertools import chain, combinations | |
import os | |
from tqdm import tqdm | |
import argparse | |
import pickle | |
from src.plotting.eval_matrix import matrix_plot, scatter_plot, multiple_matrix_plot, ax_tiny_histogram | |
from src.utils.paths import get_path | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from collections import OrderedDict | |
#### Plotting functions | |
from matplotlib_venn import venn3 | |
import matplotlib.pyplot as plt | |
from copy import copy | |
def _venn3_region_counts(data_dict, remove_max=True):
    """Aggregate membership-string counts into ``venn3`` region counts.

    Parameters:
    - data_dict (dict): keys are strings naming the set indices an item
      belongs to (e.g. '0', '01', '012'; '' means "in no set"), values are
      counts.
    - remove_max (bool): if True, every entry whose count equals the largest
      count (searched over the non-empty keys) is dropped before aggregating.

    Returns:
    - (venn_counts, max_value): ``venn_counts`` maps the six outer region codes
      ('100', '010', '001', '110', '101', '011') to summed counts; the central
      '111' region is always left out. ``max_value`` is the largest count over
      the non-empty keys, floored at 0.
    """
    # Largest count over keys that belong to at least one set ("" is ignored).
    max_value = max([0] + [value for key, value in data_dict.items() if key != ""])
    if remove_max:
        data_dict = {key: value for key, value in data_dict.items() if value != max_value}
    venn_counts = dict.fromkeys(("100", "010", "001", "110", "101", "011", "111"), 0)
    for key, value in data_dict.items():
        # Membership flags for sets 0, 1, 2 in that order, e.g. '01' -> '110'.
        region = "".join("1" if str(i) in key else "0" for i in range(3))
        if region in venn_counts:  # '000' (the "" key) has no region
            venn_counts[region] += value
    # The centre region is not sized from the data; its count is written onto
    # the diagram label instead by plot_venn3_from_index_dict.
    del venn_counts["111"]
    return venn_counts, max_value


def plot_venn3_from_index_dict(ax, data_dict, set_labels=('Set 0', 'Set 1', 'Set 2'), set_colors=("orange", "purple", "gray"), remove_max=True):
    """
    Draw a 3-set Venn diagram on ``ax`` from a dictionary whose keys are
    strings of '0', '1' and '2' indicating set membership and whose values are
    counts.

    Parameters:
    - ax: matplotlib Axes passed to ``venn3``.
    - data_dict (dict): dictionary with keys like '0', '01', '012', etc.
    - set_labels (tuple): labels for the three sets.
    - set_colors (tuple): colours for the three sets.
    - remove_max (bool): if True, remove every entry whose count equals the
      largest count before sizing the regions, so one dominant entry does not
      swamp the rest of the diagram.

    The central '111' region is never sized from the data; its label (when
    matplotlib_venn draws one) is overwritten with the largest count.
    NOTE(review): that label is only meaningful if the largest count belongs
    to the '012' key — presumably the intended use; confirm for new inputs.

    Returns:
    - The ``venn3`` diagram object.
    """
    venn_counts, max_value = _venn3_region_counts(data_dict, remove_max)
    venn = venn3(subsets=venn_counts, set_labels=set_labels, set_colors=set_colors, alpha=0.5, ax=ax)
    centre_label = venn.get_label_by_id("111")
    # matplotlib_venn returns None for regions it did not draw.
    if centre_label is not None:
        centre_label.set_text(max_value)
    return venn
### Change this to make custom plots highlighting differences between different models (the histograms of pt_pred/pt_true, eta_pred-eta_true, and phi_pred-phi_true)
# Each value is [model -> integer (presumably the number of evaluated events;
# confirm), model -> plot colour].
histograms_dict = {
    "": [
        dict.fromkeys(["base_LGATr", "base_Tr", "base_GATr", "AK8"], 50000),
        {"base_LGATr": "orange", "base_Tr": "blue", "base_GATr": "green", "AK8": "gray"},
    ],
    "LGATr_comparison": [
        {
            "base_LGATr": 50000,
            "LGATr_GP_IRC_S_50k": 9960,
            "LGATr_GP_50k": 9960,
            "AK8": 50000,
            "LGATr_GP_IRC_SN_50k": 24000,
        },
        {
            "base_LGATr": "orange",
            "LGATr_GP_IRC_S_50k": "red",
            "LGATr_GP_50k": "purple",
            "LGATr_GP_IRC_SN_50k": "pink",
            "AK8": "gray",
        },
    ],
    # NOTE(review): "comparsion" (sic) does not match the spelling used in
    # results_dict; kept unchanged because the key is runtime data.
    "LGATr_comparsion_DifferentTrainingDS": [
        dict.fromkeys(
            [
                "base_LGATr",
                "LGATr_700_07",
                "LGATr_QCD",
                "LGATr_700_07+900_03",
                "LGATr_700_07+900_03+QCD",
                "AK8",
            ],
            50000,
        ),
        {
            "base_LGATr": "orange",
            "LGATr_700_07": "red",
            "LGATr_QCD": "purple",
            "LGATr_700_07+900_03": "blue",
            "LGATr_700_07+900_03+QCD": "green",
            "AK8": "gray",
        },
    ],
}

# Models (and their colours) whose F1 scores etc. are plotted together.
# Each value is [model -> plot colour, model -> display name (rename dict)].
results_dict = {
    "LGATr_comparison_DifferentTrainingDS": [
        {
            "base_LGATr": "orange",
            "LGATr_700_07": "red",
            "LGATr_QCD": "purple",
            "LGATr_700_07+900_03": "blue",
            "LGATr_700_07+900_03+QCD": "green",
            "AK8": "gray",
        },
        {"base_LGATr": "LGATr_900_03"},
    ],
    "LGATr_comparison": [
        {
            "base_LGATr": "orange",
            "LGATr_GP_IRC_S_50k": "red",
            "LGATr_GP_50k": "purple",
            "LGATr_GP_IRC_SN_50k": "pink",
            "AK8": "gray",
        },
        {
            "base_LGATr": "LGATr",
            "LGATr_GP_IRC_S_50k": "LGATr_GP_IRC_S",
            "LGATr_GP_50k": "LGATr_GP",
            "LGATr_GP_IRC_SN_50k": "LGATr_GP_IRC_SN",
        },
    ],
    "LGATr_comparison_QCDtrain": [
        {
            "LGATr_QCD": "orange",
            "LGATr_GP_IRC_S_QCD": "red",
            "LGATr_GP_QCD": "purple",
            "LGATr_GP_IRC_SN_QCD": "pink",
            "AK8": "gray",
        },
        {
            "LGATr_QCD": "LGATr",
            "LGATr_GP_IRC_S_QCD": "LGATr_GP_IRC_S",
            "LGATr_GP_QCD": "LGATr_GP",
            "LGATr_GP_IRC_SN_QCD": "LGATr_GP_IRC_SN",
        },
    ],
    "LGATr_comparison_GP_training": [
        {
            "LGATr_GP_QCD": "purple",
            "LGATr_GP_700_07": "red",
            "LGATr_GP_700_07+900_03": "blue",
            "LGATr_GP_700_07+900_03+QCD": "green",
            "LGATr_GP_50k": "orange",
            "AK8": "gray",
        },
        {
            "LGATr_GP_QCD": "QCD",
            "LGATr_GP_700_07": "700_07",
            "LGATr_GP_700_07+900_03": "700_07+900_03",
            "LGATr_GP_50k": "900_03",
            "LGATr_GP_700_07+900_03+QCD": "700_07+900_03+QCD",
        },
    ],
    "LGATr_comparison_GP_IRC_S_training": [
        {
            "LGATr_GP_IRC_S_QCD": "purple",
            "LGATr_GP_IRC_S_700_07": "red",
            "LGATr_GP_IRC_S_700_07+900_03": "blue",
            "LGATr_GP_IRC_S_700_07+900_03+QCD": "green",
            "LGATr_GP_IRC_S_50k": "orange",
            "AK8": "gray",
        },
        {
            "LGATr_GP_IRC_S_QCD": "QCD",
            "LGATr_GP_IRC_S_700_07": "700_07",
            "LGATr_GP_IRC_S_700_07+900_03": "700_07+900_03",
            "LGATr_GP_IRC_S_50k": "900_03",
            "LGATr_GP_IRC_S_700_07+900_03+QCD": "700_07+900_03+QCD",
        },
    ],
    "LGATr_comparison_GP_IRC_SN_training": [
        {
            "LGATr_GP_IRC_SN_QCD": "purple",
            "LGATr_GP_IRC_SN_700_07": "red",
            "LGATr_GP_IRC_SN_700_07+900_03": "blue",
            "LGATr_GP_IRC_SN_700_07+900_03+QCD": "green",
            "LGATr_GP_IRC_SN_50k": "orange",
            "AK8": "gray",
        },
        {
            "LGATr_GP_IRC_SN_QCD": "QCD",
            "LGATr_GP_IRC_SN_700_07": "700_07",
            "LGATr_GP_IRC_SN_700_07+900_03": "700_07+900_03",
            "LGATr_GP_IRC_SN_50k": "900_03",
            "LGATr_GP_IRC_SN_700_07+900_03+QCD": "700_07+900_03+QCD",
        },
    ],
}
# Reference: training-run name -> label for the GP / GP_IRC_S per-dataset runs
#   "GP_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_19_21_29_06_946": "LGATr_GP_QCD",
#   "GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_19_21_38_20_376": "LGATr_GP_700_07",
#   "GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_20_13_12_54_359": "LGATr_GP_700_07+900_03+QCD",
#   "GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_20_13_13_00_503": "LGATr_GP_700_07+900_03",
#   "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_20_15_29_30_29": "LGATr_GP_IRC_S_700_07+900_03",
#   "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_20_15_29_28_959": "LGATr_GP_IRC_S_700_07+900_03+QCD",
#   "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_20_15_11_35_476": "LGATr_GP_IRC_S_700_07",
#   "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_20_15_11_20_735": "LGATr_GP_IRC_S_QCD",
# ---------------------------------------------------------------------------
# Command line: --input is resolved to the results directory that contains one
# sub-directory per evaluated model (AKX* baseline directories excluded).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--input",
    type=str,
    required=False,
    default="scouting_PFNano_signals2/SVJ_hadronic_std/batch_eval/objectness_score",
)
parser.add_argument("--threshold-obj-score", "-os-threshold", type=float, default=-1)
# Objectness-score scan: 100 fine points on [0, 0.1] followed by 20 coarse
# points on [0.1, 1] (0.1 appears twice). get_plots_for_params indexes stored
# per-threshold results by position in this array.
thresholds = np.concatenate([np.linspace(0, 0.1, 100), np.linspace(0.1, 1, 20)])
args = parser.parse_args()
path = get_path(args.input, "results")
models = sorted(
    entry
    for entry in os.listdir(path)
    if "AKX" not in entry and not os.path.isfile(os.path.join(path, entry))
)
figures_all = {}  # title -> precision/recall result used for the F1 plots
figures_all_sorted = {}  # "R=<GT_R> <dataset>_s<step>" -> level index -> result
print("Models:", models)
# Jet radius per model label. NOTE: `radius` is rebound to [0.8] further below.
radius = {
    "LGATr_R10": 1.0,
    "LGATr_R09": 0.9,
    "LGATr_rinv_03_m_900": 0.8,
    "LGATr_R06": 0.6,
    "LGATr_R07": 0.7,
    "LGATr_R11": 1.1,
    "LGATr_R12": 1.2,
    "LGATr_R13": 1.3,
    "LGATr_R14": 1.4,
    "LGATr_R20": 2.0,
    "LGATr_R25": 2.5,
}
# Human-readable label suffixes for legacy evaluation run names.
comments = {
    "Eval_params_study_2025_02_17_13_30_50": ", tr. on 07_700",
    "Eval_objectness_score_2025_02_12_15_34_33": ", tr. on 03_900, GT=all",
    "Eval_objectness_score_2025_02_18_08_48_13": ", tr. on 03_900, GT=closest",
    "Eval_objectness_score_2025_02_14_11_10_14": ", tr. on 03_900, GT=closest",
    "Eval_objectness_score_2025_02_21_14_51_07": ", tr. on 07_700",
    "Eval_objectness_score_2025_02_10_14_59_49": ", tr. on 03_900, GT=all",
    "Eval_objectness_score_2025_02_23_19_26_25": ", tr. on all, GT=closest",
    "Eval_objectness_score_2025_02_23_21_04_33": ", tr. on 03_900, GT=closest",
}
# Output files, all written into the results directory.
results_dir = get_path(args.input, "results")
out_file_PR = os.path.join(results_dir, "precision_recall.pdf")
out_file_PRf1 = os.path.join(results_dir, "f1_score_sorted.pdf")
out_file_PG = os.path.join(results_dir, "PLoverGL.pdf")
if args.threshold_obj_score != -1:
    out_file_PR_OS = os.path.join(results_dir, "precision_recall_with_obj_score.pdf")
out_file_avg_number_matched_quarks = os.path.join(results_dir, "avg_number_matched_quarks.pdf")
def get_plots_for_params(mMed, mDark, rInv, result_PR_thresholds):
    """Turn per-threshold match counts into precision, recall and F1 curves.

    ``result_PR_thresholds[mMed][mDark][rInv][i]`` is read for every index
    ``i`` of the module-level ``thresholds`` array. Element 0 of each entry is
    the shared numerator, element 1 the precision denominator and element 2
    the recall denominator. A zero denominator (or P + R == 0 for F1) gives 0.

    Returns:
        (precisions, recalls, f1_scores): three lists of len(thresholds).
    """
    def ratio(numerator, denominator):
        return 0 if denominator == 0 else numerator / denominator

    per_threshold = result_PR_thresholds[mMed][mDark][rInv]
    rows = [per_threshold[idx] for idx in range(len(thresholds))]
    precisions = [ratio(row[0], row[1]) for row in rows]
    recalls = [ratio(row[0], row[2]) for row in rows]
    f1_scores = [
        0 if prec + rec == 0 else 2 * prec * rec / (prec + rec)
        for prec, rec in zip(precisions, recalls)
    ]
    return precisions, recalls, f1_scores
# Panel size (inches) and column count of the threshold-scan figure below.
sz, nplots = 5, 9
# Now make 3 plots, one for mMed=700,r_inv=0.7; one for mMed=700,r_inv=0.5; one for mMed=700,r_inv=0.3
###fig, ax = plt.subplots(3, 3, figsize=(3 * sz, 3 * sz))
# Disabled: legacy precision/recall/F1-vs-objectness-threshold scan, kept as a
# no-op string literal. NOTE(review): it contains a typo ("continue6") that
# would raise NameError if this code were re-enabled.
'''fig, ax = plt.subplots(3, nplots, figsize=(nplots*sz, 3*sz))
for mi, mass in enumerate([700, 900, 1500]):
    start_idx = mi*3
    for i0, rinv in enumerate([0.3, 0.5, 0.7]):
        i = start_idx + i0
        # 0 is precision, 1 is recall, 2 is f1 score
        ax[0, i].set_title(f"r_inv={rinv}, m_med={mass} GeV")
        ax[1, i].set_title(f"r_inv={rinv}, m_med={mass} GeV")
        ax[2, i].set_title(f"r_inv={rinv}, m_med={mass} GeV")
        ax[0, i].set_ylabel("Precision")
        ax[1, i].set_ylabel("Recall")
        ax[2, i].set_ylabel("F1 score")
        ax[0, i].grid()
        ax[1, i].grid()
        ax[2, i].grid()
        ylims = {} # key: j and i
        default_ylims = [1, 0]
        for j, model in enumerate(models):
            result_PR_thresholds = os.path.join(path, model, "count_matched_quarks", "result_PR_thresholds.pkl")
            if not os.path.exists(result_PR_thresholds):
                continue
            run_config = pickle.load(open(os.path.join(path, model, "run_config.pkl"), "rb"))
            result_PR_thresholds = pickle.load(open(result_PR_thresholds, "rb"))
            if mass not in result_PR_thresholds:
                continue
            if rinv not in result_PR_thresholds[mass]:
                continue6
            precisions, recalls, f1_scores = get_plots_for_params(mass, 20, rinv, result_PR_thresholds)
            if not run_config["gt_radius"] == 0.8:
                continue
            label = "R={} gl.f.={} {}".format(run_config["gt_radius"], run_config.get("global_features_obj_score", False), comments.get(run_config["run_name"], run_config["run_name"]))
            scatter_plot(ax[0, i], thresholds, precisions, label=label)
            scatter_plot(ax[1, i], thresholds, recalls, label=label)
            scatter_plot(ax[2, i], thresholds, f1_scores, label=label)
            #ylims[0] = [min(ylims[0][0], min(precisions)), max(ylims[0][1], max(precisions))]
            #ylims[1] = [min(ylims[1][0], min(recalls)), max(ylims[1][1], max(recalls))]
            #ylims[2] = [min(ylims[2][0], min(f1_scores)), max(ylims[2][1], max(f1_scores))]
            filt = thresholds < 0.2
            precisions = np.array(precisions)[filt]
            recalls = np.array(recalls)[filt]
            f1_scores = np.array(f1_scores)[filt]
            if (i, 0) not in ylims:
                ylims[(i, 0)] = default_ylims
            upper_factor = 1.01
            lower_factor = 0.99
            ylims[(i, 0)] = [min(ylims[(i, 0)][0], min(precisions)*lower_factor), max(ylims[(i, 0)][1], max(precisions)*upper_factor)]
            if (i, 1) not in ylims:
                ylims[(i, 1)] = default_ylims
            ylims[(i, 1)] = [min(ylims[(i, 1)][0], min(recalls)*lower_factor), max(ylims[(i, 1)][1], max(recalls)*upper_factor)]
            if (i, 2) not in ylims:
                ylims[(i, 2)] = default_ylims
            ylims[(i, 2)] = [min(ylims[(i, 2)][0], min(f1_scores)*lower_factor), max(ylims[(i, 2)][1], max(f1_scores)*upper_factor)]
        for j in range(3):
            ax[j, i].set_ylim(ylims[(i, j)])
            ax[j, i].legend()
            ax[j, i].set_xlim([0, 0.2])
            ax[j, i].set_xlim([0, 0.2])
            ax[j, i].set_xlim([0, 0.2])
        # now adjust the ylim so that the plots are more readable
fig.tight_layout()
fig.savefig(os.path.join(get_path(args.input, "results"), "precision_recall_thresholds.pdf"))
print("Saved to", os.path.join(get_path(args.input, "results"), "precision_recall_thresholds.pdf"))'''
# Module-level Weights & Biases public-API client; get_run_by_name() uses it to
# look up evaluation runs in the "fcc_ml/svj_clustering" project.
import wandb
api = wandb.Api()
def _split_clustering_suffix(name):
    """Split a trailing clustering-variant tag off an evaluation run name.

    The tags "FT", "FT1" (min-samples 1, min-cluster-size 2, epsilon 0.3) and
    "10_5" are checked in that order; every tag that matches is removed and the
    last one removed is reported.

    Returns:
        (base_name, clust_suffix), with clust_suffix == "" if nothing matched.
    """
    clust_suffix = ""
    for suffix in ("FT", "FT1", "10_5"):
        if name.endswith(suffix):
            name = name[:-len(suffix)]
            clust_suffix = suffix
    return name, clust_suffix


def get_run_by_name(name):
    """Look up one W&B run in ``fcc_ml/svj_clustering`` by its display name.

    A trailing clustering tag (see _split_clustering_suffix) is stripped from
    ``name`` before the query and returned alongside the run.

    Returns:
        (run, clust_suffix) when exactly one run matches, otherwise None.
    """
    name, clust_suffix = _split_clustering_suffix(name)
    # A single query is enough (the original issued the identical request twice).
    runs = api.runs(
        path="fcc_ml/svj_clustering",
        filters={"display_name": {"$eq": name.strip()}},
    )
    if runs.length != 1:
        return None
    return runs[0], clust_suffix
# Training-run name (``load_from_run`` in an evaluation run's config) -> short
# label for the training dataset / model variant. Built once at import time
# instead of on every get_run_config() call.
_TRAINING_DATASETS = {
    "LGATr_training_NoPID_10_16_64_0.8_AllData_2025_02_28_13_42_59": "all",
    "LGATr_training_NoPID_10_16_64_0.8_2025_02_28_12_42_59": "900_03",
    "LGATr_training_NoPID_10_16_64_2.0_2025_02_28_12_48_58": "900_03",
    "LGATr_training_NoPID_10_16_64_0.8_700_07_2025_02_28_13_01_59": "700_07",
    "LGATr_training_NoPIDGL_10_16_64_0.8_2025_03_17_20_05_04": "900_03_GenLevel",
    "LGATr_training_NoPIDGL_10_16_64_2.0_2025_03_17_20_05_04": "900_03_GenLevel",
    "Transformer_training_NoPID_10_16_64_2.0_2025_03_03_17_00_38": "900_03_T",
    "Transformer_training_NoPID_10_16_64_0.8_2025_03_03_15_55_50": "900_03_T",
    "LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_2025_03_27_12_46_12_740": "900_03+SoftAug",
    "LGATr_training_NoPID_10_16_64_2.0_Aug_Finetune_vanishing_momentum_2025_03_28_10_43_36_81": "900_03+SoftAugVM",
    "LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_vanishing_momentum_2025_03_28_10_43_37_44": "900_03+SoftAugVM",
    "LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_vanishing_momentum_QCap05_2025_03_28_17_12_25_820": "900_03+qcap05",
    "LGATr_training_NoPID_10_16_64_2.0_Aug_Finetune_vanishing_momentum_QCap05_2025_03_28_17_12_26_400": "900_03+qcap05",
    "LGATr_training_NoPID_10_16_64_2.0_Aug_Finetune_vanishing_momentum_QCap05_1e-2_2025_03_29_14_58_38_650": "pt 1e-2",
    "LGATr_training_NoPID_10_16_64_0.8_Aug_Finetune_vanishing_momentum_QCap05_1e-2_2025_03_29_14_58_36_446": "pt 1e-2",
    "LGATr_pt_1e-2_500part_2025_04_01_16_49_08_406": "500_pt_1e-2_PLFT",
    "LGATr_pt_1e-2_500part_2025_04_01_21_14_07_350": "500_pt_1e-2_PLFT",
    "LGATr_pt_1e-2_500part_NoQMin_2025_04_03_23_15_17_745": "500_1e-2_scFT",
    "LGATr_pt_1e-2_500part_NoQMin_2025_04_03_23_15_35_810": "500_1e-2_scFT",
    "LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_2025_04_04_12_57_51_536": "10_1000_1e-2_scFT",
    "LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_2025_04_04_12_57_47_788": "10_1000_1e-2_scFT",
    "LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_CW0_2025_04_04_15_30_16_839": "10_1000_1e-2_CW0",
    "LGATr_pt_1e-2_500part_NoQMin_10_to_1000p_CW0_2025_04_04_15_30_20_113": "10_1000_1e-2_CW0",
    "debug_IRC_loss_weighted100_plus_ghosts_2025_04_08_22_40_33_972": "IRC_short_debug",
    "debug_IRC_loss_weighted100_plus_ghosts_2025_04_09_13_48_55_569": "IRC",
    "debug_IRC_loss_weighted100_plus_ghosts_Qmin05_2025_04_09_14_45_51_381": "IRC_qmin05",
    "LGATr_500part_NOQMin_2025_04_09_21_53_37_210": "500part_NOQMin_reprod",
    "IRC_loss_Split_and_Noise_alternate_Aug_2025_04_14_11_10_21_788": "IRC_Aug_S+N",
    "IRC_loss_Split_and_Noise_alternate_NoAug_2025_04_11_16_15_48_955": "IRC_S+N",
    "LGATr_training_NoPID_Delphes_10_16_64_0.8_2025_04_17_18_07_38_405": "DelphesTrain",
    "Delphes_IRC_aug_2025_04_19_11_16_17_130": "DelphesTrain+IRC",
    "LGATr_500part_NOQMin_Delphes_2025_04_19_11_15_24_417": "DelphesTrain+ghosts",
    "Delphes_IRC_aug_SplitOnly_2025_04_20_15_50_33_553": "DelphesTrain+IRC_SplitOnly",
    "Delphes_IRC_NOAug_SplitOnly_2025_04_21_12_58_36_99": "Delphes_IRC_NoAug_SplitOnly",
    "Delphes_IRC_NOAug_SplitAndNoise_2025_04_21_19_32_08_865": "Delphes_IRC_NoAug_S+N",
    "CONT_Delphes_IRC_aug_SplitOnly_2025_04_21_12_53_27_730": "IRC_aug_SplitOnly_ContFrom14k",
    "Transformer_training_NoPID_Delphes_PU_10_16_64_0.8_2025_05_03_18_37_01_188": "base_Tr_Old",
    "LGATr_training_NoPID_Delphes_PU_PFfix_10_16_64_0.8_2025_05_03_18_35_53_134": "base_LGATr",
    "GATr_training_NoPID_Delphes_PU_10_16_64_0.8_2025_05_03_18_35_48_163": "base_GATr_Old",
    "Transformer_training_NoPID_Delphes_PU_CoordFix_10_16_64_0.8_2025_05_05_13_05_20_755": "base_Tr",
    "GATr_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_05_16_24_13_579": "base_GATr_SD",
    "GATr_training_NoPID_Delphes_PU_CoordFix_10_16_64_0.8_2025_05_05_13_06_27_898": "base_GATr",
    "LGATr_Aug_2025_05_06_10_08_05_956": "LGATr_GP",
    "Delphes_Aug_IRCSplit_CONT_2025_05_07_11_00_18_422": "LGATr_GP_IRC_S",
    "Delphes_Aug_IRC_Split_and_Noise_2025_05_07_14_43_13_968": "LGATr_GP_IRC_SN",
    "Transformer_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_05_16_24_19_936": "base_Tr_SD",
    "LGATr_training_NoPID_Delphes_PU_PFfix_SmallDS_10_16_64_0.8_2025_05_05_16_24_16_127": "base_LGATr_SD",
    "Delphes_Aug_IRCSplit_2025_05_06_10_09_00_567": "LGATr_GP_IRC_S",
    "GATr_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_09_15_34_13_531": "base_GATr_SD",
    "Transformer_training_NoPID_Delphes_PU_CoordFix_SmallDS_10_16_64_0.8_2025_05_09_15_56_50_216": "base_Tr_SD",
    "LGATr_training_NoPID_Delphes_PU_PFfix_SmallDS_10_16_64_0.8_2025_05_09_15_56_50_875": "base_LGATr_SD",
    "Delphes_Aug_IRCSplit_50k_from10k_2025_05_11_14_08_49_675": "LGATr_GP_IRC_S_50k",
    "LGATr_Aug_50k_2025_05_09_15_25_32_34": "LGATr_GP_50k",
    "Delphes_Aug_IRCSplit_50k_2025_05_09_15_22_38_956": "LGATr_GP_IRC_S_50k",
    "LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_16_21_04_26_937": "LGATr_700_07+900_03+QCD",
    "LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_16_21_04_26_991": "LGATr_700_07+900_03",
    "LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_16_19_46_57_48": "LGATr_QCD",
    "LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_16_19_44_46_795": "LGATr_700_07",
    "Delphes_Aug_IRCSplit_50k_SN_from3kFT_2025_05_16_14_07_29_474": "LGATr_GP_IRC_SN_50k",
    "GP_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_19_21_29_06_946": "LGATr_GP_QCD",
    "GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_19_21_38_20_376": "LGATr_GP_700_07",
    "GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_20_13_12_54_359": "LGATr_GP_700_07+900_03+QCD",
    "GP_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_20_13_13_00_503": "LGATr_GP_700_07+900_03",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_20_15_29_30_29": "LGATr_GP_IRC_S_700_07+900_03",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_20_15_29_28_959": "LGATr_GP_IRC_S_700_07+900_03+QCD",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_20_15_11_35_476": "LGATr_GP_IRC_S_700_07",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_20_15_11_20_735": "LGATr_GP_IRC_S_QCD",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_QCD_events_10_16_64_0.8_2025_05_24_23_00_54_948": "LGATr_GP_IRC_SN_QCD",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_AND_QCD_10_16_64_0.8_2025_05_24_23_00_56_910": "LGATr_GP_IRC_SN_700_07+900_03+QCD",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_AND_900_03_10_16_64_0.8_2025_05_24_23_01_01_212": "LGATr_GP_IRC_SN_700_07+900_03",
    "GP_IRC_S_LGATr_training_NoPID_Delphes_PU_PFfix_700_07_10_16_64_0.8_2025_05_24_23_01_07_703": "LGATr_GP_IRC_SN_700_07",
}


def _describe_run_config(config, run_name, clust_suffix=""):
    """Build the plot title and summary dict for one evaluation run's config.

    Parameters:
        config: the run's W&B config mapping. Reads "parton_level",
            "gen_level", "gt_radius", "load_from_run", "ckpt_step" and,
            optionally, "augment_soft_particles".
        run_name: evaluation run name; its lower-cased form is searched for
            particle-level filter tags that are appended to the dataset label.
        clust_suffix: clustering-variant tag split off the run name.

    Returns:
        (title, result). ``result`` has "level" ("PL"/"GL"/"scouting"),
        "level_idx" (0/2/1), optional "ghosts", "GT_R", "training_dataset",
        "training_dataset_nostep" and "ckpt_step".
    """
    result = {}
    if config["parton_level"]:
        prefix = "PL"
        result["level"] = "PL"
        result["level_idx"] = 0
    elif config["gen_level"]:
        prefix = "GL"
        result["level"] = "GL"
        result["level_idx"] = 2
    else:
        prefix = "sc."
        result["level"] = "scouting"
        result["level_idx"] = 1
    # Use .get consistently: older configs may not have this key at all.
    if config.get("augment_soft_particles", False):
        result["ghosts"] = True
        prefix += " (aug)"
    gt_r = config["gt_radius"]
    train_name = config["load_from_run"]
    ckpt_step = config["ckpt_step"]
    print("train name", train_name)
    if train_name not in _TRAINING_DATASETS:
        print("!! unknown run", train_name)
    dataset_label = _TRAINING_DATASETS.get(train_name, train_name)
    training_dataset = dataset_label + "_s" + str(ckpt_step) + clust_suffix
    lowered = run_name.lower()
    if "plptfilt01" in lowered:
        training_dataset += "_PLPtFiltMinPt01"  # min pt 0.1
    elif "noplfilter" in lowered:
        training_dataset += "_noPLFilter"
    elif "noplptfilter" in lowered:
        training_dataset += "_noPLPtFilter"  # actually there was a 0.5 pt cut in the ntuplizer, removed by plptfilt01
    elif "nopletafilter" in lowered:
        training_dataset += "_noPLEtaFilter"
    result["GT_R"] = gt_r
    result["training_dataset"] = training_dataset
    result["training_dataset_nostep"] = dataset_label + clust_suffix
    result["ckpt_step"] = ckpt_step
    return f"GT_R={gt_r} {training_dataset}, {prefix}", result


def get_run_config(run_name):
    """Fetch an evaluation run from W&B and summarise its configuration.

    Returns:
        (title, result) as produced by _describe_run_config, or (None, None)
        when the run cannot be found uniquely.
    """
    # get_run_by_name returns None (not a tuple) on failure, so check before
    # unpacking — unpacking directly raised TypeError instead of reaching the
    # failure branch.
    lookup = get_run_by_name(run_name)
    run = None if lookup is None else lookup[0]
    if run is None:
        print("Getting info from run", run_name, "failed")
        return None, None
    return _describe_run_config(run.config, run_name, lookup[1])
def flatten_list(lst):
    """Flatten exactly one level of nesting, e.g. ``[[0, 0], [1, 1]] -> [0, 0, 1, 1]``."""
    return [item for sublist in lst for item in sublist]
sz = 5


def _load_pickle(directory, filename):
    """Unpickle ``directory/filename`` and close the file handle afterwards.

    NOTE(review): unpickling can execute arbitrary code — only point this at
    result files produced by the project's own evaluation jobs.
    """
    with open(os.path.join(directory, filename), "rb") as fh:
        return pickle.load(fh)


# Anti-kT baselines evaluated on PFCands (scouting) level: required.
ak_path = os.path.join(path, "AKX", "count_matched_quarks")
result_PR_AKX = _load_pickle(ak_path, "result_PR_AKX.pkl")
result_jet_props_akx = _load_pickle(ak_path, "result_jet_properties_AKX.pkl")
result_qj_akx = _load_pickle(ak_path, "result_quark_to_jet.pkl")
result_dq_pt_akx = _load_pickle(ak_path, "result_pt_dq.pkl")
result_dq_mc_pt_akx = _load_pickle(ak_path, "result_pt_mc_gt.pkl")
result_dq_props_akx = _load_pickle(ak_path, "result_props_dq.pkl")

# Parton-level (AKX_PL) and gen-level (AKX_GL) baselines are optional. If the
# precision/recall file itself is missing, fall back to the scouting-level one
# so the F1 grid can still be drawn. The auxiliary files are loaded separately,
# so a missing auxiliary file no longer replaces an already-loaded PL/GL
# precision/recall result with the scouting-level one.
ak_path_PL = os.path.join(path, "AKX_PL", "count_matched_quarks")
try:
    result_PR_AKX_PL = _load_pickle(ak_path_PL, "result_PR_AKX.pkl")
except FileNotFoundError as err:
    print("FileNotFoundError", err)
    result_PR_AKX_PL = result_PR_AKX
try:
    result_qj_akx_PL = _load_pickle(ak_path_PL, "result_quark_to_jet.pkl")
    result_dq_mc_pt_akx_PL = _load_pickle(ak_path_PL, "result_pt_mc_gt.pkl")
    result_dq_pt_akx_PL = _load_pickle(ak_path_PL, "result_pt_dq.pkl")
    result_dq_props_akx_PL = _load_pickle(ak_path_PL, "result_props_dq.pkl")
except FileNotFoundError as err:
    print("FileNotFoundError", err)

ak_path_GL = os.path.join(path, "AKX_GL", "count_matched_quarks")
try:
    result_PR_AKX_GL = _load_pickle(ak_path_GL, "result_PR_AKX.pkl")
except FileNotFoundError as err:
    print("FileNotFoundError", err)
    result_PR_AKX_GL = result_PR_AKX
try:
    result_qj_akx_GL = _load_pickle(ak_path_GL, "result_quark_to_jet.pkl")
    result_dq_mc_pt_akx_GL = _load_pickle(ak_path_GL, "result_pt_mc_gt.pkl")
    result_dq_pt_akx_GL = _load_pickle(ak_path_GL, "result_pt_dq.pkl")
    result_dq_props_akx_GL = _load_pickle(ak_path_GL, "result_props_dq.pkl")
except FileNotFoundError as err:
    print("FileNotFoundError", err)

# Restrict the grids to these training-dataset labels; empty means "all".
#plot_only = ["LGATr_GP", "LGATr_GP_IRC_S", "LGATr_GP_IRC_SN", "LGATr_GP_50k", "LGATr_GP_IRC_S_50k"]
plot_only = []
radius = [0.8]
def select_radius(d, radius, depth=3):
    """Drop the jet-radius level from a nested results dictionary.

    The outer ``depth`` levels of nested dicts are rebuilt key-for-key; at the
    bottom, each mapping is replaced by its ``radius`` entry. With
    ``depth=0``, ``d[radius]`` is returned directly.
    """
    if depth != 0:
        return {key: select_radius(sub, radius, depth - 1) for key, sub in d.items()}
    return d[radius]
if models:
    # Collect each evaluated model's precision/recall grid, keyed by its
    # descriptive title (figures_all) and by "R=<GT_R> <dataset>" -> level
    # index (figures_all_sorted). The AK8 baseline row of figures_all_sorted
    # is added further below, outside this block.
    for model in tqdm(sorted(models)):
        output_path = os.path.join(path, model, "count_matched_quarks")
        if not os.path.exists(os.path.join(output_path, "result.pkl")):
            print("Result not exists for model", model)
            continue
        # result / result_bc are not used below but are still loaded (so they
        # must exist), as before; every handle is closed right after reading.
        with open(os.path.join(output_path, "result.pkl"), "rb") as fh:
            result = pickle.load(fh)
        with open(os.path.join(output_path, "result_bc.pkl"), "rb") as fh:
            result_bc = pickle.load(fh)
        with open(os.path.join(output_path, "result_PR.pkl"), "rb") as fh:
            result_PR = pickle.load(fh)
        print("Getting run config for model", model)
        run_config_title, run_config = get_run_config(model)
        print("RC title", run_config_title)
        if run_config is None:
            print("Skipping", model)
            continue
        figures_all[run_config_title] = result_PR
        print(model, run_config_title)
        td = run_config["training_dataset"]
        gtr = run_config["GT_R"]
        level = run_config["level_idx"]
        tdns = run_config["training_dataset_nostep"]
        # An empty plot_only means "plot every model".
        if tdns in plot_only or not plot_only:
            td = "R=" + str(gtr) + " " + td
            figures_all_sorted.setdefault(td, {})[level] = result_PR
        # NOTE: the former `figures_all_sorted["AK8"]: {...}` here was an
        # annotation, not an assignment, so it never stored anything; it only
        # recomputed select_radius() three times per model and was removed.
    # Anti-kT baselines at every radius in `radius`: PFCands ("AK"),
    # parton level ("AK PL") and gen level ("AK GL").
    for ak_label, ak_results in (
        ("AK", result_PR_AKX),
        ("AK PL", result_PR_AKX_PL),
        ("AK GL", result_PR_AKX_GL),
    ):
        for R in radius:
            figures_all[f"{ak_label}, R={R}"] = select_radius(ak_results, R)
with open(out_file_PR.replace(".pdf", ".pkl"), "wb") as fh:
    pickle.dump(figures_all, fh)
# AK8 (R=0.8) baseline as the last row of the F1 grid; level index
# 0 = parton level, 1 = PFCands (scouting), 2 = gen level.
figures_all_sorted["AK8"] = {
    0: select_radius(result_PR_AKX_PL, 0.8),
    1: select_radius(result_PR_AKX, 0.8),
    2: select_radius(result_PR_AKX_GL, 0.8)
}
text_level = ["PL", "PFCands", "GL"]
# Display names for rows of the F1 grid (keys are "R=<GT_R> <dataset>_s<step>").
renames = {
    "R=0.8 base_LGATr_s50000": "LGATr",
    "R=0.8 LGATr_GP_50k_s25020": "LGATr_GP",
    "R=0.8 LGATr_GP_IRC_S_50k_s12900": "LGATr_GP_IRC_S",
    "AK8": "AK8",
    "R=0.8 LGATr_GP_IRC_SN_50k_s22020": "LGATr_GP_IRC_SN"
}
n_rows = len(figures_all_sorted)
fig_f1, ax_f1 = plt.subplots(n_rows, 3, figsize=(sz * 2.5, sz * n_rows))
if n_rows == 1:
    # plt.subplots returns a 1-D axes array for a single row; keep 2-D indexing.
    ax_f1 = np.array([ax_f1])
is_qcd_sample = "qcd" in path.lower()
for i, (model, per_level) in enumerate(figures_all_sorted.items()):
    for j, level_name in enumerate(text_level):
        # The original repeated this `if` at the same indentation, which is an
        # IndentationError; a single check is all that was intended.
        if j not in per_level:
            continue
        matrix_plot(per_level[j], "Purples", r"$F_1$ score",
                    metric_comp_func=lambda r: 2 * r[0] * r[1] / (r[0] + r[1]), ax=ax_f1[i, j], is_qcd=is_qcd_sample)
        ax_f1[i, j].set_title(renames.get(model, model) + " " + level_name)
        ax_f1[i, j].set_xlabel("$m_{Z'}$")
        ax_f1[i, j].set_ylabel("$r_{inv.}$")
fig_f1.tight_layout()
fig_f1.savefig(out_file_PRf1)
import pandas as pd


# Print QCD precision / recall / F1 tables (only for QCD evaluation paths).
def get_qcd_results(i):
    """Tabulate one metric of the QCD sample for every model and level.

    Parameters
    ----------
    i : int
        Metric to return: 0 = precision, 1 = recall, 2 = F1 score.

    Returns
    -------
    pandas.DataFrame
        One row per key of ``figures_all_sorted`` and one column per level name
        from ``text_level``.

    Reads ``figures_all_sorted`` and ``text_level`` from the enclosing scope.
    The QCD point is the ``[0][0][0]`` entry (m_Med = 0, m_Dark = 0, r_inv = 0)
    of each per-level result; its first two entries are precision and recall.
    F1 is reported as 0.0 when precision and recall are both zero instead of
    raising ZeroDivisionError.
    """
    qcd_results = {}
    for model, per_level in figures_all_sorted.items():
        qcd_results[model] = {}
        for level, result in per_level.items():
            values = [float(x) for x in result[0][0][0]]
            precision, recall = values[0], values[1]
            denom = precision + recall
            f1 = 2 * precision * recall / denom if denom else 0.0
            qcd_results[model][text_level[level]] = (precision, recall, f1)[i]
    return pd.DataFrame(qcd_results).T


if "qcd" in path.lower():
    for metric_idx, metric_name in enumerate(("Precision", "Recall", "F1 score")):
        if metric_idx:
            print("----------------")
        print(f"{metric_name}:")
        print(get_qcd_results(metric_idx))
## Now do the GT R vs metrics plots
oranges, reds, purples = (plt.get_cmap(cmap_name) for cmap_name in ("Oranges", "Reds", "Purples"))
# mDark is the middle key of the per-hypothesis result dicts (m_Med -> m_Dark -> r_inv):
# 0 for QCD evaluation paths, 20 otherwise.
if "qcd" in path.lower():
    print("QCD events")
    mDark = 0
else:
    mDark = 20
# Accumulators filled by the per-model loop that follows (key order as indexed there).
to_plot = dict()  # training dataset -> rInv -> mMed -> level -> metric name -> values (only filled by disabled code)
to_plot_steps = dict()  # training dataset (no step) -> mMed -> rInv -> level -> ckpt step -> F1
to_plot_v2 = dict()  # level index -> mMed -> rInv -> model -> [precision, recall]
quark_to_jet = dict()  # level index -> mMed -> rInv -> model -> quark-to-jet assignment list
mc_gt_pt_of_dq = dict()  # level index -> mMed -> rInv -> model -> flattened list
pt_of_dq = dict()  # level index -> mMed -> rInv -> model -> flattened list
props_of_dq = dict(eta=dict(), phi=dict())  # dark-quark eta / phi, each: level index -> mMed -> rInv -> model -> list
results_all = dict()  # training dataset -> mMed -> mDark -> rInv -> level -> GT_R -> F1
results_all_ak = dict()  # mMed -> mDark -> rInv -> level -> R -> F1
jet_properties = dict()  # training dataset (no step) -> mMed -> rInv -> level -> ckpt step -> jet property dict
jet_properties_ak = dict()  # rInv -> mMed -> level -> radius
# [m_Med, r_inv] hypotheses for the GT-R plots; QCD has the single [0, 0] point.
plotting_hypotheses = (
    [[0, 0]]
    if "qcd" in path.lower()
    else [[700, 0.7], [700, 0.5], [700, 0.3], [900, 0.3], [900, 0.7]]
)
sz_small = 5
# Collect the per-model evaluation results: for every model directory, load the
# pickles written by count_matched_quarks and file them into the accumulator
# dicts by training dataset, level, m_Med and r_inv.
for j, model in enumerate(models): | |
_, rc = get_run_config(model) | |
# Models without a run config, and four explicitly excluded evaluation runs, are skipped.
if rc is None or model in ["Eval_eval_19March2025_pt1e-2_500particles_FT_PL_2025_04_02_14_28_33_421FT", "Eval_eval_19March2025_pt1e-2_500particles_FT_PL_2025_04_02_14_47_23_671FT", "Eval_eval_19March2025_small_aug_vanishing_momentum_2025_03_28_11_45_16_582", "Eval_eval_19March2025_small_aug_vanishing_momentum_2025_03_28_11_46_26_326"]: | |
print("Skipping", model) | |
continue | |
td = rc["training_dataset"] | |
td_raw = rc["training_dataset_nostep"] | |
level = rc["level"] | |
r = rc["GT_R"] | |
output_path = os.path.join(path, model, "count_matched_quarks") | |
# Only result_PR.pkl is checked; a missing companion pickle below raises FileNotFoundError.
if not os.path.exists(os.path.join(output_path, "result_PR.pkl")): | |
continue | |
# NOTE(review): these files are opened without a context manager, so the
# handles are closed only when the file objects are garbage-collected.
result_PR = pickle.load(open(os.path.join(output_path, "result_PR.pkl"), "rb")) | |
result_QJ = pickle.load(open(os.path.join(output_path, "result_quark_to_jet.pkl"), "rb")) | |
result_jet_props = pickle.load(open(os.path.join(output_path, "result_jet_properties.pkl"), "rb")) | |
result_MC_PT = pickle.load(open(os.path.join(output_path, "result_pt_mc_gt.pkl"), "rb")) | |
result_PT_DQ = pickle.load(open(os.path.join(output_path, "result_pt_dq.pkl"), "rb")) | |
result_DQ_props = pickle.load(open(os.path.join(output_path, "result_props_dq.pkl"), "rb")) | |
print(level) | |
if td not in to_plot: | |
to_plot[td] = {} | |
results_all[td] = {} | |
if td_raw not in to_plot_steps: | |
to_plot_steps[td_raw] = {} | |
jet_properties[td_raw] = {} | |
# level_idx: 0 = PL, 1 = scouting (labelled "PFCands" in plot titles), 2 = GL.
# list.index raises ValueError for any other level string.
level_idx = ["PL", "scouting", "GL"].index(level) | |
if level_idx not in to_plot_v2: | |
to_plot_v2[level_idx] = {} | |
quark_to_jet[level_idx] = {} | |
pt_of_dq[level_idx] = {} | |
mc_gt_pt_of_dq[level_idx] = {} | |
for prop in props_of_dq: | |
props_of_dq[prop][level_idx] = {} | |
# Create the nested m_Med -> r_inv -> ... entries on first use.
for mMed_h in result_PR: | |
if mMed_h not in to_plot_v2[level_idx]: | |
to_plot_v2[level_idx][mMed_h] = {} | |
quark_to_jet[level_idx][mMed_h] = {} | |
pt_of_dq[level_idx][mMed_h] = {} | |
mc_gt_pt_of_dq[level_idx][mMed_h] = {} | |
for prop in props_of_dq: | |
props_of_dq[prop][level_idx][mMed_h] = {} | |
if mMed_h not in to_plot_steps[td_raw]: | |
to_plot_steps[td_raw][mMed_h] = {} | |
jet_properties[td_raw][mMed_h] = {} | |
if mMed_h not in results_all[td]: | |
results_all[td][mMed_h] = {mDark: {}} | |
for rInv_h in result_PR[mMed_h][mDark]: | |
if rInv_h not in to_plot_v2[level_idx][mMed_h]: | |
to_plot_v2[level_idx][mMed_h][rInv_h] = {} | |
quark_to_jet[level_idx][mMed_h][rInv_h] = {} | |
pt_of_dq[level_idx][mMed_h][rInv_h] = {} | |
mc_gt_pt_of_dq[level_idx][mMed_h][rInv_h] = {} | |
for prop in props_of_dq: | |
props_of_dq[prop][level_idx][mMed_h][rInv_h] = {} | |
if rInv_h not in to_plot_steps[td_raw][mMed_h]: | |
to_plot_steps[td_raw][mMed_h][rInv_h] = {} | |
jet_properties[td_raw][mMed_h][rInv_h] = {} | |
if level not in to_plot_steps[td_raw][mMed_h][rInv_h]: | |
to_plot_steps[td_raw][mMed_h][rInv_h][level] = {} | |
jet_properties[td_raw][mMed_h][rInv_h][level] = {} | |
if rInv_h not in results_all[td][mMed_h][mDark]: | |
results_all[td][mMed_h][mDark][rInv_h] = {} | |
#for level in ["PL+ghosts", "GL+ghosts", "scouting+ghosts"]: | |
if level not in results_all[td][mMed_h][mDark][rInv_h]: | |
results_all[td][mMed_h][mDark][rInv_h][level] = {} | |
# Entry [0] of result_PR[m_Med][m_Dark][r_inv] is precision and [1] is recall;
# F1 is their harmonic mean.
# NOTE(review): not guarded against precision + recall == 0.
precision = result_PR[mMed_h][mDark][rInv_h][0] | |
recall = result_PR[mMed_h][mDark][rInv_h][1] | |
f1score = 2 * precision * recall / (precision + recall) | |
to_plot_v2[level_idx][mMed_h][rInv_h][td_raw] = [precision, recall] | |
quark_to_jet[level_idx][mMed_h][rInv_h][td_raw] = result_QJ[mMed_h][mDark][rInv_h] | |
pt_of_dq[level_idx][mMed_h][rInv_h][td_raw] = flatten_list(result_PT_DQ[mMed_h][mDark][rInv_h]) | |
mc_gt_pt_of_dq[level_idx][mMed_h][rInv_h][td_raw] = flatten_list(result_MC_PT[mMed_h][mDark][rInv_h]) | |
for prop in props_of_dq: | |
props_of_dq[prop][level_idx][mMed_h][rInv_h][td_raw] = flatten_list(result_DQ_props[prop][mMed_h][mDark][rInv_h]) | |
#print("qj", quark_to_jet[level_idx][mMed_h][rInv_h][td_raw]) | |
# results_all keeps the first F1 seen for a given GT_R (r); later models do not overwrite it.
if r not in results_all[td][mMed_h][mDark][rInv_h][level]: | |
results_all[td][mMed_h][mDark][rInv_h][level][r] = f1score | |
# Per-checkpoint F1 and jet properties feed the score-vs-step plots and histograms.
ckpt_step = rc["ckpt_step"] | |
to_plot_steps[td_raw][mMed_h][rInv_h][level][ckpt_step] = f1score | |
jet_properties[td_raw][mMed_h][rInv_h][level][ckpt_step] = result_jet_props[mMed_h][mDark][rInv_h] | |
# Union of every m_Med and every r_inv seen across training datasets, sorted;
# they define the rows / columns of the per-(m_Med, r_inv) grids below.
m_Meds = []
r_invs = []
for key, per_mMed in to_plot_steps.items():
    m_Meds.extend(per_mMed)
    for key2, per_rInv in per_mMed.items():
        r_invs.extend(per_rInv)
m_Meds = sorted(set(m_Meds))
r_invs = sorted(set(r_invs))
# Select the AK8 baseline (R = 0.8) at all three levels and add it under the
# "AK8" key of the per-level dicts, next to the trained models.
result_AKX_current = select_radius(result_PR_AKX, 0.8) | |
result_AKX_PL = select_radius(result_PR_AKX_PL, 0.8) | |
result_AKX_GL = select_radius(result_PR_AKX_GL, 0.8) | |
result_AKX_jet_properties = select_radius(result_jet_props_akx, 0.8) | |
jet_properties["AK8"] = {} | |
result_AKX_current_QJ = select_radius(result_qj_akx, 0.8) | |
result_AKX_PL_QJ = select_radius(result_qj_akx_PL, 0.8) | |
result_AKX_GL_QJ = select_radius(result_qj_akx_GL, 0.8) | |
result_AKX_current_pt_dq = select_radius(result_dq_pt_akx, 0.8) | |
result_AKX_PL_pt_dq = select_radius(result_dq_pt_akx_PL, 0.8) | |
result_AKX_GL_pt_dq = select_radius(result_dq_pt_akx_GL, 0.8) | |
result_AKX_current_MCpt_dq = select_radius(result_dq_mc_pt_akx, 0.8) | |
result_AKX_PL_MCpt_dq = select_radius(result_dq_mc_pt_akx_PL, 0.8) | |
result_AKX_GL_MCpt_dq = select_radius(result_dq_mc_pt_akx_GL, 0.8) | |
result_AKX_current_props_dq = select_radius(result_dq_props_akx, 0.8, depth=4) | |
result_AKX_PL_props_dq = select_radius(result_dq_props_akx_PL, 0.8, depth=4) | |
result_AKX_GL_props_dq = select_radius(result_dq_props_akx_GL, 0.8, depth=4) | |
# NOTE(review): redundant; tqdm is already imported at the top of the file.
from tqdm import tqdm | |
for mMed_h in result_AKX_jet_properties: | |
for rInv_h in result_AKX_jet_properties[mMed_h][mDark]: | |
# Guarded only on level index 0 (PL) being present; indices 1 (scouting) and
# 2 (GL), and every (m_Med, r_inv) of the AK results, are assumed to already
# have entries created by the model loop, otherwise this raises KeyError.
if 0 in to_plot_v2: | |
to_plot_v2[0][mMed_h][rInv_h]["AK8"] = result_AKX_PL[mMed_h][mDark][rInv_h] | |
to_plot_v2[1][mMed_h][rInv_h]["AK8"] = result_AKX_current[mMed_h][mDark][rInv_h] | |
to_plot_v2[2][mMed_h][rInv_h]["AK8"] = result_AKX_GL[mMed_h][mDark][rInv_h] | |
quark_to_jet[0][mMed_h][rInv_h]["AK8"] = result_AKX_PL_QJ[mMed_h][mDark][rInv_h] | |
quark_to_jet[1][mMed_h][rInv_h]["AK8"] = result_AKX_current_QJ[mMed_h][mDark][rInv_h] | |
quark_to_jet[2][mMed_h][rInv_h]["AK8"] = result_AKX_GL_QJ[mMed_h][mDark][rInv_h] | |
pt_of_dq[0][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_PL_pt_dq[mMed_h][mDark][rInv_h]) | |
pt_of_dq[1][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_current_pt_dq[mMed_h][mDark][rInv_h]) | |
pt_of_dq[2][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_GL_pt_dq[mMed_h][mDark][rInv_h]) | |
mc_gt_pt_of_dq[0][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_PL_MCpt_dq[mMed_h][mDark][rInv_h]) | |
mc_gt_pt_of_dq[1][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_current_MCpt_dq[mMed_h][mDark][rInv_h]) | |
mc_gt_pt_of_dq[2][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_GL_MCpt_dq[mMed_h][mDark][rInv_h]) | |
for k in props_of_dq: | |
props_of_dq[k][0][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_PL_props_dq[k][mMed_h][mDark][rInv_h]) | |
props_of_dq[k][1][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_current_props_dq[k][mMed_h][mDark][rInv_h]) | |
props_of_dq[k][2][mMed_h][rInv_h]["AK8"] = flatten_list(result_AKX_GL_props_dq[k][mMed_h][mDark][rInv_h]) | |
if mMed_h not in jet_properties["AK8"]: | |
jet_properties["AK8"][mMed_h] = {} | |
if rInv_h not in jet_properties["AK8"][mMed_h]: | |
jet_properties["AK8"][mMed_h][rInv_h] = {} | |
# AK8 jet properties are stored only for the "scouting" level, under a fixed
# pseudo checkpoint step 50000; this replaces any dict set just above.
jet_properties["AK8"][mMed_h][rInv_h] = {"scouting": {50000: result_AKX_jet_properties[mMed_h][mDark][rInv_h]}} | |
# Short display names for the result groups used in the per-level bar charts.
rename_results_dict = dict(
    LGATr_comparison_DifferentTrainingDS="base",
    LGATr_comparison_GP_training="GP",
    LGATr_comparison_GP_IRC_S_training="GP_IRC_S",
    LGATr_comparison_GP_IRC_SN_training="GP_IRC_SN",
)
# [m_Med, r_inv] points for the Venn diagrams; [0, 0] is the QCD sample.
hypotheses_to_plot = [[0, 0]] + [[700, rinv] for rinv in (0.7, 0.5, 0.3)]


def powerset(iterable):
    """Return an iterator over all subsets of *iterable* as tuples, smallest first.

    The input is consumed immediately. Example:
    powerset([1, 2, 3]) -> () (1,) (2,) (3,) (1, 2) (1, 3) (2, 3) (1, 2, 3)
    """
    items = tuple(iterable)
    return chain.from_iterable(combinations(items, size) for size in range(len(items) + 1))


def get_label_from_superset(lbl, labels_rename, labels):
    """Turn a Venn-region key into a legend label.

    *lbl* is a string of label indices (e.g. "02"); the empty string means the
    quark was matched by none of *labels*. Each index is mapped to its entry in
    *labels* and then through *labels_rename*. Three-member regions become
    "Found by all", and the two-member region made of exactly the renamed
    "QCD" and "900_03" entries becomes "Found by both models but not AK".
    Anything else is the comma-separated list of (renamed) labels.
    """
    if lbl == '':
        return "Missed by all"
    members = []
    for digit in lbl:
        original_name = labels[int(digit)]
        members.append(labels_rename.get(original_name, original_name))
    if len(members) == 3:
        return "Found by all"
    if len(members) == 2 and {"QCD", "900_03"} <= set(members):
        return "Found by both models but not AK"
    return ", ".join(members)
# For each hypothesis in hypotheses_to_plot ([0, 0] = QCD), compare which dark
# quarks are matched to a jet by each of the three models in `labels`, per
# level: Venn diagrams of the counts plus histograms of quark properties split
# by Venn region. Afterwards, per-level precision / recall / F1 bar charts are
# saved for every result group in rename_results_dict.
for hyp_m, hyp_rinv in hypotheses_to_plot: | |
if 0 not in to_plot_v2: | |
continue # Not for the lower-pt thresholds, where only GL and PL are available | |
if hyp_m not in to_plot_v2[0] or hyp_rinv not in to_plot_v2[0][hyp_m]: | |
continue | |
# plot here the venn diagrams | |
labels = ["LGATr_GP_IRC_S_QCD", "AK8", "LGATr_GP_IRC_S_50k"] | |
labels_global = ["LGATr_GP_IRC_S_QCD", "AK8", "LGATr_GP_IRC_S_50k"] | |
labels_rename = {"LGATr_GP_IRC_S_QCD": "QCD", "LGATr_GP_IRC_S_50k": "900_03"} | |
fig_venn, ax_venn = plt.subplots(6, 3, figsize=(5 * 3, 5 * 6)) # the bottom ones are for pt of the DQ, pt of the MC GT, pt of MC GT / pt of DQ, eta, and phi distributions | |
fig_venn1, ax_venn1 = plt.subplots(6, 2, figsize=(5*2, 5*6)) # Only the PFCands-level, with full histogram on the left and density on the right | |
for level in range(3): | |
#labels = list(results_dict["LGATr_comparison_GP_IRC_S_training"][0].keys()) | |
# NOTE(review): label_combination_to_number is never read; the subsets are
# enumerated through powerset_str instead.
label_combination_to_number = {} # fill it with all possible label combinations e.g. if there are 3 labels: "NA", "0", "1", "2", "01", "012", "12", "02" | |
powerset_str = ["".join([str(x) for x in sorted(list(a))]) for a in powerset(range(len(labels)))] | |
set_to_count = {key: 0 for key in powerset_str} | |
set_to_stats = {key: {"pt_dq": [], "pt_mc_t": [], "pt_mc_t_dq_ratio": [], "eta": [], "phi": []} for key in powerset_str} | |
label_to_result = {} | |
#label_to_stats = {"pt_dq": , "pt_mc_t": [], "pt_mc_t_dq_ratio": [], "eta": [], "phi": []} | |
# Truncate every label to the shortest quark list so that index c refers to the
# same quark for all labels (presumably the lists differ only by a cut-off batch).
n_dq = 999999999 | |
for j, label in enumerate(labels): | |
r = flatten_list(quark_to_jet[level][hyp_m][hyp_rinv][label]) | |
n_dq = min(n_dq, len(r)) # Find the minimum number of dark quarks in all labels | |
for j, label in enumerate(labels): | |
r = torch.tensor(flatten_list(quark_to_jet[level][hyp_m][hyp_rinv][label])) | |
r = (r != -1) # Whether quark no. X is caught or not | |
label_to_result[j] = r.tolist()[:n_dq] | |
#r = torch.tensor(flatten_list(pt_of_dq[level][hyp_m][hyp_rinv][label])) | |
#r = r[:n_dq] | |
#label_to_stats["pt_dq"].append(r.tolist()) | |
#r1 = torch.tensor(flatten_list(mc_gt_pt_of_dq[level][hyp_m][hyp_rinv][label])) | |
#r1 = r1[:n_dq] | |
#label_to_stats["pt_mc_t"].append(r1.tolist()) | |
#r2 = r1 / r | |
#r2 = r2[:n_dq] | |
#label_to_stats["pt_mc_t_dq_ratio"].append(r2.tolist()) | |
#r_eta = torch.tensor(flatten_list(props_of_dq["eta"][level][hyp_m][hyp_rinv][label])) | |
#r_eta = r_eta[:n_dq] | |
#label_to_stats["eta"].append(r_eta.tolist()) | |
##r_phi = torch.tensor(flatten_list(props_of_dq["phi"][level][hyp_m][hyp_rinv][label])) | |
#r_phi = r_phi[:n_dq] | |
#label_to_stats["phi"].append(r_phi.tolist()) | |
assert len(label_to_result[j]) == n_dq, f"Label {label} has different number of quarks than others {n_dq} != {len(label_to_result[j])}" | |
#n_dq = min(n_dq, len(r)) | |
#for j, label in enumerate(labels): | |
# assert len(label_to_result[j]) == n_dq, f"Label {label} has different number of quarks than others {n_dq} != {len(label_to_result[j])}" | |
# Venn-region key for quark c: concatenated indices of the labels that matched it ("" = none).
for c in tqdm(range(n_dq)): | |
belonging_to_set = "" | |
for j, label in enumerate(labels): | |
if label_to_result[j][c] == 1: | |
belonging_to_set += str(j) | |
set_to_count[belonging_to_set] += 1 | |
#for key in label_to_stats: | |
# for idx in belonging_to_set: | |
# idx_int = int(idx) # e.g. "0", "1" etc. | |
# set_to_stats[belonging_to_set] | |
# NOTE(review): this loop appends quark c's properties once per label, so each
# quark contributes len(labels) entries to its region. Confirm this is intended;
# it scales the raw-count histograms. The p_T ratio divides by zero if a quark has p_T == 0.
for j, label in enumerate(labels): | |
current_dq_pt = pt_of_dq[level][hyp_m][hyp_rinv][label][c] | |
current_mc_gt_pt = mc_gt_pt_of_dq[level][hyp_m][hyp_rinv][label][c] | |
current_dq_eta = props_of_dq["eta"][level][hyp_m][hyp_rinv][label][c] | |
current_dq_phi = props_of_dq["phi"][level][hyp_m][hyp_rinv][label][c] | |
set_to_stats[belonging_to_set]["pt_dq"].append(current_dq_pt) | |
set_to_stats[belonging_to_set]["pt_mc_t"].append(current_mc_gt_pt) | |
set_to_stats[belonging_to_set]["pt_mc_t_dq_ratio"].append(current_mc_gt_pt/current_dq_pt) | |
set_to_stats[belonging_to_set]["eta"].append(current_dq_eta) | |
set_to_stats[belonging_to_set]["phi"].append(current_dq_phi) | |
#print("set_to_count for level", level, ":", set_to_count, "labels:", labels) | |
title = f"$m_{{Z'}}={hyp_m}$ GeV, $r_{{inv.}}={hyp_rinv}$, {text_level[level]} (missed by all: {set_to_count['']}) " | |
if hyp_m == 0 and hyp_rinv == 0: | |
title = f"QCD, {text_level[level]} (missed by all: {set_to_count['']})" | |
ax_venn[0, level].set_title(title) | |
plot_venn3_from_index_dict(ax_venn[0, level], set_to_count, set_labels=[labels_rename.get(l,l) for l in labels], set_colors=["orange", "gray", "red"]) | |
if level == 1: #reco-level | |
plot_venn3_from_index_dict(ax_venn1[0, 1], set_to_count, | |
set_labels=[labels_rename.get(l,l) for l in labels], | |
set_colors=["orange", "gray", "red"]) | |
bins = { | |
"pt_dq": np.linspace(90, 250, 50), | |
"pt_mc_t": np.linspace(0, 200, 50), | |
"pt_mc_t_dq_ratio": np.linspace(0, 1.3, 30), | |
"eta": np.linspace(-4, 4, 20), | |
"phi": np.linspace(-np.pi, np.pi, 20) | |
} | |
# 8 colours: one per Venn region (all subsets of 3 labels), indexed by s_idx
clrs = ["green", "red", "orange", "pink", "blue", "purple", "cyan", "magenta"] | |
# NOTE(review): "$\eta$" / "$\phi$" are not raw strings; \e and \p are invalid
# escape sequences (kept verbatim, but warned about; SyntaxWarning since Python 3.12).
key_rename_dict = {"pt_dq": "$p_T$ of quark", "pt_mc_t": "$p_T$ of particles within radius of R=0.8 of quark", "pt_mc_t_dq_ratio": "$p_T$ (part. within R=0.8 of quark) / $p_T$ (quark) ", "eta": "$\eta$ of quark", "phi": "$\phi$ of quark" } | |
for k, key in enumerate(["pt_dq", "pt_mc_t", "pt_mc_t_dq_ratio", "eta", "phi"]): | |
for s_idx, s in enumerate(sorted(set_to_stats.keys())): | |
if len(set_to_stats[s][key]) == 0: | |
continue | |
lbl = s | |
#if s == "": | |
# lbl = "none" | |
lbl1 = get_label_from_superset(lbl, labels_rename, labels) | |
# Only these four regions get histograms.
if lbl1 not in ["Missed by all", "Found by both models but not AK", "AK8", "Found by all"]: | |
continue | |
if level == 1: | |
ax_venn1[k + 1, 1].hist(set_to_stats[s][key], bins=bins[key], histtype="step", | |
label=lbl1, color=clrs[s_idx], density=True) | |
ax_venn1[k + 1, 0].set_title(f"{key_rename_dict[key]}") | |
ax_venn1[k+1, 1].set_title(f"{key_rename_dict[key]}") | |
ax_venn1[k + 1, 1].set_ylabel("Density") | |
# The full figure skips the "missed by all" ("") and "found by all" ("012") regions.
if lbl not in ["", "012"]: | |
# We are only interested in the differences... | |
ax_venn[k+1, level].hist(set_to_stats[s][key], bins=bins[key], histtype="step", label=lbl1, color=clrs[s_idx]) | |
ax_venn[k+1, level].set_title(f"{key_rename_dict[key]}") | |
if level == 1: | |
ax_venn1[k + 1, 0].hist(set_to_stats[s][key], bins=bins[key], histtype="step", | |
label=lbl1, | |
color=clrs[s_idx]) | |
#ax_venn[k+1, level].set_xlabel(key) | |
#ax_venn[k+1, level].set_ylabel("Count") | |
for k in range(5): | |
ax_venn[k+1, level].legend() | |
fig_venn.tight_layout() | |
for k in range(5): | |
ax_venn1[k+1, 0].legend() | |
ax_venn1[k+1, 1].legend() | |
fig_venn1.tight_layout() | |
# NOTE(review): fig_venn / fig_venn1 (and the bar-chart figures below) are never
# closed, so open figures accumulate across hypotheses.
f = os.path.join(get_path(args.input, "results"), f"venn_diagram_{hyp_m}_{hyp_rinv}.pdf") | |
fig_venn.savefig(f) | |
f1 = os.path.join(get_path(args.input, "results"), f"venn_diagram_{hyp_m}_{hyp_rinv}_reco_level_only.pdf") | |
fig_venn1.savefig(f1) | |
# Bar charts of precision / recall / F1 (one figure per metric): rows are the
# result groups of rename_results_dict, columns are the three levels.
for i, lbl in enumerate(["precision", "recall", "F1"]): # 0=precision, 1=recall, 2=F1 | |
sz_small1 = 2.5 | |
fig, ax = plt.subplots(len(rename_results_dict), 3, figsize=(sz_small1 * 3, sz_small1 * len(rename_results_dict))) | |
for i1, key in enumerate(list(rename_results_dict.keys())): | |
for level in range(3): | |
level_text = text_level[level] | |
labels = list(results_dict[key][0].keys()) | |
colors = [results_dict[key][0][l] for l in labels] | |
res_precision = np.array([to_plot_v2[level][hyp_m][hyp_rinv][l][0] for l in labels]) | |
res_recall = np.array([to_plot_v2[level][hyp_m][hyp_rinv][l][1] for l in labels]) | |
res_f1 = 2 * res_precision * res_recall / (res_precision + res_recall) | |
if i == 0: | |
values = res_precision | |
elif i == 1: | |
values = res_recall | |
else: | |
values = res_f1 | |
rename_dict = results_dict[key][1] | |
labels_renamed = [rename_dict.get(l,l) for l in labels] | |
print(i1, level) | |
ax_tiny_histogram(ax[i1, level], labels_renamed, colors, values) | |
ax[i1, level].set_title(f"{rename_results_dict[key]} {level_text}") | |
fig.tight_layout() | |
fig.savefig(os.path.join(get_path(args.input, "results"), f"{lbl}_results_by_level_{hyp_m}_{hyp_rinv}_{key}.pdf")) | |
# Venn diagrams of agreement *between levels*: for each model in `labels`, count
# the dark quarks matched to a jet at each combination of levels
# (0 = PL, 1 = PFCands, 2 = GL). One panel per model, one figure per hypothesis.
# Region keys: "" (missed at every level), "0", "1", "2", "01", "02", "12", "012".
level_subsets = ["".join(str(x) for x in sorted(a)) for a in powerset(range(3))]
for hyp_m, hyp_rinv in hypotheses_to_plot:
    if 0 not in to_plot_v2:
        continue  # Not for the lower-pt thresholds, where only GL and PL are available
    if hyp_m not in to_plot_v2[0] or hyp_rinv not in to_plot_v2[0][hyp_m]:
        continue
    labels = ["LGATr_GP_IRC_S_QCD", "AK8", "LGATr_GP_IRC_S_50k"]
    fig_venn2, ax_venn2 = plt.subplots(1, len(labels), figsize=(4 * len(labels), 4))
    for j, label in enumerate(labels):
        set_to_count = {key: 0 for key in level_subsets}
        # Flatten each level's quark-to-jet list once; -1 marks an unmatched quark.
        flat_per_level = [flatten_list(quark_to_jet[level][hyp_m][hyp_rinv][label]) for level in range(3)]
        # Sometimes the last batch gets cut off etc., so compare only the common prefix.
        n_dq = min(len(flat) for flat in flat_per_level)
        label_to_result = {}
        for level, flat in enumerate(flat_per_level):
            label_to_result[level] = (torch.tensor(flat) != -1).tolist()[:n_dq]
            assert len(label_to_result[level]) == n_dq, f"Label {label} has different number of quarks than others {n_dq} != {len(label_to_result[level])}"
        for c in tqdm(range(n_dq)):
            belonging_to_set = "".join(str(lvl) for lvl in range(3) if label_to_result[lvl][c] == 1)
            set_to_count[belonging_to_set] += 1
        if hyp_m == 0 and hyp_rinv == 0:
            title = f"QCD, {label} (missed by all: {set_to_count['']}) "
        else:
            title = f"$m_{{Z'}}={hyp_m}$ GeV, $r_{{inv.}}={hyp_rinv}$, {label} (miss: {set_to_count['']}) "
        ax_venn2[j].set_title(title)
        plot_venn3_from_index_dict(ax_venn2[j], set_to_count, set_labels=text_level,
                                   set_colors=["orange", "gray", "red"], remove_max=True)
    fig_venn2.tight_layout()
    venn2_path = os.path.join(get_path(args.input, "results"), f"venn_diagram_{hyp_m}_{hyp_rinv}_Agreement_between_levels.pdf")
    fig_venn2.savefig(venn2_path)
    # Release the figure: one is created per hypothesis and never reused.
    plt.close(fig_venn2)
# One stacked F1 grid per result group in results_dict and per level.
# results_dict[key] is used as (label -> colour dict, rename dict).
for key in results_dict:
    for level in range(3):
        level_text = text_level[level]
        labels = list(results_dict[key][0].keys())
        if level not in to_plot_v2:
            continue  # level not evaluated for any model
        f, a = multiple_matrix_plot(to_plot_v2[level], labels=labels,
                                    colors=[results_dict[key][0][l] for l in labels],
                                    rename_dict=results_dict[key][1])
        if f is None:
            print("No figure for", key, level)
            continue
        out_file = os.path.join(get_path(args.input, "results"), f"grid_stack_F1_{level_text}_{key}.pdf")
        f.savefig(out_file)
        # Close after saving so one figure per (group, level) does not stay open.
        plt.close(f)
        print("Saved to", out_file)
from matplotlib.lines import Line2D

# Legend proxies for the score-vs-step plots: the colour encodes the model family,
# the line style encodes the level (solid = reco, dotted = gen, dashed = parton).
custom_lines = [
    Line2D([0], [0], color=line_color, linestyle=line_style, label=line_label)
    for line_color, line_style, line_label in (
        ("orange", "-", "LGATr"),
        ("green", "-", "GATr"),
        ("blue", "-", "Transformer"),
        ("gray", "-", "AK8"),
        ("black", "-", "reco"),
        ("black", ":", "gen"),
        ("black", "--", "parton"),
    )
]
# Score-vs-training-step plots (one panel per (m_Med, r_inv)) and, for each
# subset in histograms_dict, histograms of jet pt / eta / phi resolution
# (prediction vs. gen particle) at the "scouting" level.
if len(models): | |
fig_steps, ax_steps = plt.subplots(len(m_Meds), len(r_invs), figsize=(sz_small * len(r_invs), sz_small * len(m_Meds))) | |
if len(m_Meds) == 1 and len(r_invs) == 1: | |
ax_steps = np.array([[ax_steps]]) | |
histograms = {} | |
for key in histograms_dict: | |
if key not in histograms: | |
histograms[key] = {} | |
for i in ["pt", "eta", "phi"]: | |
f, a = plt.subplots(len(m_Meds), len(r_invs), figsize=(sz_small * len(r_invs), sz_small * len(m_Meds))) | |
if len(r_invs) == 1 and len(m_Meds) == 1: | |
a = np.array([[a]]) | |
histograms[key][i] = f, a | |
# Only models listed in `colors` are drawn in the score-vs-step panels; others are skipped.
colors = {"base_LGATr": "orange", "base_Tr": "blue", "base_GATr": "green", "AK8": "gray"} # THE COLORS FOR THE STEP VS. F1 SCORE | |
#colors_small_dataset = {"base_LGATr_SD": "orange", "base_Tr_SD": "blue", "base_GATr_SD": "green", "AK8": "gray"} | |
#colors = colors_small_dataset | |
level_styles = {"scouting": "solid", "PL": "dashed", "GL": "dotted"} | |
#step_to_plot_histograms = 50000 # phi, eta, pt histograms... | |
level_to_plot_histograms = "scouting" | |
for i, mMed_h in enumerate(m_Meds): | |
for j, rInv_h in enumerate(r_invs): | |
ax_steps[i, j].set_title("$m_{{Z'}} = {}$ GeV, $r_{{inv.}} = {}$".format(mMed_h, rInv_h)) | |
ax_steps[i, j].set_xlabel("Training step") | |
ax_steps[i, j].set_ylabel("Test $F_1$ score") | |
#if j == 0: | |
#ax_steps[i, j].set_ylabel("$m_{{Z'}} = {}$".format(mMed_h)) | |
#for subset in histograms: | |
#for key in histograms[subset]: | |
#histograms[subset][key][1][i, j].set_ylabel("$m_{{Z'}} = {}$".format(mMed_h)) | |
if i == len(m_Meds)-1: | |
ax_steps[i, j].set_xlabel("$r_{{inv.}} = {}$".format(rInv_h)) | |
for subset in histograms: | |
for key in histograms[subset]: | |
histograms[subset][key][1][i, j].set_xlabel("$r_{{inv.}} = {}$".format(rInv_h)) | |
# Resolution histograms: use the step chosen per model in histograms_dict,
# falling back to the smallest available step when it was not evaluated.
for model in jet_properties: | |
if level_to_plot_histograms not in jet_properties[model][mMed_h][rInv_h]: | |
print("Skipping", model, level_to_plot_histograms, " - levels:", jet_properties[model][mMed_h][rInv_h].keys()) | |
continue | |
for subset in histograms: | |
for key in histograms[subset]: | |
if model not in histograms_dict[subset][1]: | |
continue | |
step_to_plot_histograms = histograms_dict[subset][0][model] | |
if step_to_plot_histograms not in jet_properties[model][mMed_h][rInv_h][level_to_plot_histograms]: | |
print("Swapping the step to plot histograms", jet_properties[model][mMed_h][rInv_h][level_to_plot_histograms].keys()) | |
step_to_plot_histograms = sorted(list(jet_properties[model][mMed_h][rInv_h][level_to_plot_histograms].keys()))[0] | |
pred = np.array(jet_properties[model][mMed_h][rInv_h][level_to_plot_histograms][step_to_plot_histograms][key + "_pred"]) | |
truth = np.array(jet_properties[model][mMed_h][rInv_h][level_to_plot_histograms][step_to_plot_histograms][key + "_gen_particle"]) | |
# pt: ratio pred / true; eta and phi: difference, with the phi difference
# shifted by one period of 2*pi when it falls outside [-pi, pi].
# NOTE(review): the "\eta" / "\phi" literals below are not raw strings; \e and \p
# are invalid escape sequences (kept verbatim, but warned about by Python).
if key.startswith("pt"): | |
q = pred/truth | |
symbol = "/" # division instead of subtraction symbol for pt | |
quantity = "p_{T,pred}/p_{T,true}" | |
bins = np.linspace(0, 2.5, 100) | |
elif key.startswith("eta"): | |
q = (pred - truth) | |
symbol = "-" | |
quantity="\eta_{pred}-\eta_{true}" | |
bins = np.linspace(-0.8, 0.8, 50) | |
elif key.startswith("phi"): | |
q = pred - truth | |
symbol = "-" | |
quantity = "\phi_{pred}-\phi_{true}" | |
q[q > np.pi] -= 2 * np.pi | |
q[q< -np.pi] += 2 * np.pi | |
bins = np.linspace(-0.8, 0.8, 50) | |
print("Max", np.max(q), "Min", np.min(q)) | |
rename = {"base_LGATr": "LGATr", | |
"LGATr_GP_IRC_S_50k": "LGATr_GP_IRC_S", | |
"AK8": "AK8", | |
"LGATr_GP_50k": "LGATr_GP"} | |
histograms[subset][key][1][i, j].hist(q, histtype="step", color=histograms_dict[subset][1][model], label=rename.get(model, model), bins=bins, density=True) | |
if mMed_h > 0: | |
histograms[subset][key][1][i, j].set_title(f"${quantity}$ $m_{{Z'}}={mMed_h}$ GeV, $r_{{inv.}}={rInv_h}$") | |
else: | |
histograms[subset][key][1][i, j].set_title(f"${quantity}$") | |
histograms[subset][key][1][i, j].legend() | |
histograms[subset][key][1][i, j].grid(True) | |
# F1 vs. checkpoint step per model; the line style encodes the level, and only
# the solid (scouting) curve carries a label.
for model in to_plot_steps: | |
for lvl in to_plot_steps[model][mMed_h][rInv_h]: | |
if model not in colors: | |
print("Skipping", model) | |
continue | |
print(model) | |
ls = level_styles[lvl] | |
plt_dict = to_plot_steps[model][mMed_h][rInv_h][lvl] | |
x_pts = sorted(list(plt_dict.keys())) | |
y_pts = [plt_dict[k] for k in x_pts] | |
if ls == "solid": | |
ax_steps[i, j].plot(x_pts, y_pts, label=model, marker=".", linestyle=ls, color=colors[model]) | |
else: | |
# No label | |
ax_steps[i, j].plot(x_pts, y_pts, marker=".", linestyle=ls, color=colors[model]) | |
ax_steps[i, j].legend(handles=custom_lines) | |
# now plot a horizontal line for the AKX same level | |
# NOTE(review): this rebinds `rc`, which earlier held a model's run config.
if lvl == "scouting": | |
rc = result_AKX_current | |
elif lvl == "PL": | |
rc = result_AKX_PL | |
elif lvl == "GL": | |
rc = result_AKX_GL | |
else: | |
raise Exception | |
pr = rc[mMed_h][mDark][rInv_h][0] | |
rec = rc[mMed_h][mDark][rInv_h][1] | |
f1ak = 2 * pr * rec / (pr + rec) | |
ax_steps[i, j].axhline(f1ak, color="gray", linestyle=ls, alpha=0.5) | |
ax_steps[i, j].grid(1) | |
path_steps_fig = os.path.join(get_path(args.input, "results"), "score_vs_step_plots.pdf") | |
fig_steps.tight_layout() | |
fig_steps.savefig(path_steps_fig) | |
for subset in histograms: | |
for key in histograms[subset]: | |
fig = histograms[subset][key][0] | |
fig.tight_layout() | |
fig.savefig(os.path.join(get_path(args.input, "results"), "histogram_{}_{}.pdf".format(key, subset))) | |
print("Saved to", path_steps_fig) | |
# Disabled code kept as a no-op string literal. It would fill
# to_plot[td][rInv][mMed][level] with precision / recall / F1 per GT_R, but it
# refers to per-model variables (td, model, level, result_PR, r) outside their
# loop, so as written to_plot only holds the empty per-dataset dicts.
'''for i, h in enumerate(plotting_hypotheses): | |
mMed_h, rInv_h = h | |
if rInv_h not in to_plot[td]: | |
to_plot[td][rInv_h] = {} | |
print("Model", model) | |
if mMed_h not in to_plot[td][rInv_h]: | |
to_plot[td][rInv_h][mMed_h] = {} # level | |
if level not in to_plot[td][rInv_h][mMed_h]: | |
to_plot[td][rInv_h][mMed_h][level] = {"precision": [], "recall": [], "f1score": [], "R": []} | |
precision = result_PR[mMed_h][mDark][rInv_h][0] | |
recall = result_PR[mMed_h][mDark][rInv_h][1] | |
f1score = 2 * precision * recall / (precision + recall) | |
to_plot[td][rInv_h][mMed_h][level]["precision"].append(precision) | |
to_plot[td][rInv_h][mMed_h][level]["recall"].append(recall) | |
to_plot[td][rInv_h][mMed_h][level]["f1score"].append(f1score) | |
to_plot[td][rInv_h][mMed_h][level]["R"].append(r) | |
''' | |
# Load the AK baseline precision/recall for each level (AKX = scouting,
# AKX_PL = PL, AKX_GL = GL) at every available radius R. F1 per radius goes
# into results_all_ak; F1-vs-R curves for plotting_hypotheses go into to_plot_ak.
to_plot_ak = {} # level ("scouting"/"GL"/"PL") -> rInv -> mMed -> {"f1score": [], "R": []} | |
for j, model in enumerate(["AKX", "AKX_PL", "AKX_GL"]): | |
print(model) | |
# NOTE(review): this rebinds the module-level name result_PR_AKX (used earlier in
# the script) to the current model's file, and the file handle is never closed explicitly.
if os.path.exists(os.path.join(path, model, "count_matched_quarks", "result_PR_AKX.pkl")): | |
result_PR_AKX = pickle.load(open(os.path.join(path, model, "count_matched_quarks", "result_PR_AKX.pkl"), "rb")) | |
else: | |
print("Skipping", model) | |
continue | |
level = "scouting" | |
if "PL" in model: | |
level = "PL" | |
elif "GL" in model: | |
level = "GL" | |
if level not in to_plot_ak: | |
to_plot_ak[level] = {} | |
for mMed_h in result_PR_AKX: | |
if mMed_h not in results_all_ak: | |
results_all_ak[mMed_h] = {mDark: {}} | |
for rInv_h in result_PR_AKX[mMed_h][mDark]: | |
if rInv_h not in results_all_ak[mMed_h][mDark]: | |
results_all_ak[mMed_h][mDark][rInv_h] = {} | |
if level not in results_all_ak[mMed_h][mDark][rInv_h]: | |
results_all_ak[mMed_h][mDark][rInv_h][level] = {} | |
for ridx, R in enumerate(result_PR_AKX[mMed_h][mDark][rInv_h]): | |
if R not in results_all_ak[mMed_h][mDark][rInv_h][level]: | |
precision = result_PR_AKX[mMed_h][mDark][rInv_h][R][0] | |
recall = result_PR_AKX[mMed_h][mDark][rInv_h][R][1] | |
f1score = 2 * precision * recall / (precision + recall) | |
results_all_ak[mMed_h][mDark][rInv_h][level][R] = f1score | |
# F1 as a function of R (sorted) for each plotting hypothesis.
# NOTE(review): assumes every hypothesis exists in result_PR_AKX, otherwise KeyError.
for i, h in enumerate(plotting_hypotheses): | |
mMed_h, rInv_h = h | |
if rInv_h not in to_plot_ak[level]: | |
to_plot_ak[level][rInv_h] = {} | |
print("Model", model) | |
if mMed_h not in to_plot_ak[level][rInv_h]: | |
to_plot_ak[level][rInv_h][mMed_h] = {"precision": [], "recall": [], "f1score": [], "R": []} | |
rs = sorted(result_PR_AKX[mMed_h][mDark][rInv_h].keys()) | |
precision = np.array([result_PR_AKX[mMed_h][mDark][rInv_h][i][0] for i in rs]) | |
recall = np.array([result_PR_AKX[mMed_h][mDark][rInv_h][i][1] for i in rs]) | |
f1score = 2 * precision * recall / (precision + recall) | |
to_plot_ak[level][rInv_h][mMed_h]["precision"] = precision | |
to_plot_ak[level][rInv_h][mMed_h]["recall"] = recall | |
to_plot_ak[level][rInv_h][mMed_h]["f1score"] = f1score | |
to_plot_ak[level][rInv_h][mMed_h]["R"] = rs | |
print("AK:", to_plot_ak) | |
# F1-vs-GT-R grid: one row per training dataset plus a last row for the AK
# baseline, one column per plotting hypothesis. The height now accounts for the
# extra AK row, and squeeze=False always returns a 2-D axes array, so
# ax[i, j] / ax[-1, j] work for any grid size (including a single column).
n_rows = len(to_plot) + 1
fig, ax = plt.subplots(n_rows, len(plotting_hypotheses),
                       figsize=(sz_small * len(plotting_hypotheses), sz_small * n_rows),
                       squeeze=False)
# Line colours per model level and per AK level.
colors = {
    "PL+ghosts": "green",
    "GL+ghosts": "blue",
    "scouting+ghosts": "red"
}
ak_colors = {
    "PL": "green",
    "GL": "blue",
    "scouting": "red",
}
# Disabled plotting code kept as a no-op string literal: it would draw F1 vs.
# GT R per training dataset plus an AK baseline row on fig/ax, using colors and
# ak_colors, and save score_vs_GT_R_plots_1.pdf.
''' | |
for i, td in enumerate(to_plot): | |
# for each training dataset | |
for j, h in enumerate(plotting_hypotheses): | |
ax[i, j].set_title(f"r_inv={h[1]}, m={h[0]}, tr. on {td}") | |
ax[i, j].set_ylabel("F1 score") | |
ax[i, j].set_xlabel("GT R") | |
ax[i, j].grid() | |
for level in sorted(list(to_plot[td][h[1]][h[0]].keys())): | |
print("level", level) | |
print("Plotting", td, h[1], h[0], level) | |
if level in colors: | |
ax[i, j].plot(to_plot[td][h[1]][h[0]][level]["R"], to_plot[td][h[1]][h[0]][level]["f1score"], ".-", label=level, color=colors[level]) | |
ax[i, j].legend() | |
for j, h in enumerate(plotting_hypotheses): # for to_plot_AK | |
ax[-1, j].set_title(f"r_inv={h[1]}, m={h[0]}, AK baseline") | |
ax[-1, j].set_ylabel("F1 score") | |
ax[-1, j].set_xlabel("GT R") | |
ax[-1, j].grid() | |
for i, ak_level in enumerate(sorted(list(to_plot_ak.keys()))): | |
mMed_h, rInv_h = h | |
if ak_level in ak_colors: | |
ax[-1, j].plot(to_plot_ak[ak_level][rInv_h][mMed_h]["R"], to_plot_ak[ak_level][rInv_h][mMed_h]["f1score"], ".-", label=ak_level, color=ak_colors[ak_level]) | |
ax[-1, j].legend() | |
fig.tight_layout() | |
fig.savefig(os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_1.pdf")) | |
print("Saved to", os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_1.pdf")) | |
''' | |
# F1 ratio panels: for each training dataset and jet radius R, then one panel
# per R for the AK baseline (PL / GL). Every panel gets the same 7-inch width;
# previously operator precedence gave the AK panels only 1 inch each.
n_panels = len(results_all) * len(radius) + len(radius)
fig, ax = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5))
# With a single panel plt.subplots returns a bare Axes; make it indexable.
ax = np.atleast_1d(ax)
for i, model in enumerate(results_all):
    for j, R in enumerate(radius):
        index = len(radius) * i + j
        ax[index].set_title(model + " R={}".format(R))
        # NOTE(review): results_all is keyed by the run-config level, which the model
        # loop restricts to "PL" / "scouting" / "GL"; the "+ghosts" keys below (and the
        # "PL/GL" label for a PL/scouting ratio) look stale - confirm against matrix_plot.
        matrix_plot(results_all[model], "Greens", r"PL/GL F1 score", ax=ax[index],
                    metric_comp_func=lambda r: r["PL+ghosts"][R] / r["scouting+ghosts"][R])
for i, R in enumerate(radius):
    index = len(radius) * len(results_all) + i
    ax[index].set_title("AK R={}".format(R))
    matrix_plot(results_all_ak, "Greens", r"PL/GL F1 score", ax=ax[index],
                metric_comp_func=lambda r: r["PL"][R] / r["GL"][R])
fig.tight_layout()
fig.savefig(out_file_PG)
print("Saved to", out_file_PG)
# Deliberate stop (previously `1/0`): the sections below are not meant to run for now.
raise SystemExit("Stopping after the PL/GL ratio plot; the remaining sections are disabled.")
#print("Saved to", os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_AK.pdf")) | |
#print("Saved to", os.path.join(get_path(args.input, "results"), "score_vs_GT_R_plots_AK_ratio.pdf")) | |
########### Now save the above plot with objectness score applied
# Precision / recall / F1 matrices at a single objectness-score threshold,
# one column per model. Skipped when args.threshold_obj_score == -1.
if args.threshold_obj_score != -1:

    def wrap(r):
        """Compute [precision, recall] from [n_relevant_retrieved, all_retrieved, all_relevant].

        Returns [0, 0] when either denominator is zero.
        """
        if r[1] == 0 or r[2] == 0:
            return [0, 0]
        return [r[0] / r[1], r[0] / r[2]]

    def f1_from_counts(r):
        """Return the F1 score for [n_relevant_retrieved, all_retrieved, all_relevant].

        Returns 0 when precision + recall == 0. That case used to raise
        ZeroDivisionError; 0 is consistent with wrap()'s zero convention.
        """
        precision, recall = wrap(r)
        if precision + recall == 0:
            return 0
        return 2 * precision * recall / (precision + recall)

    # squeeze=False keeps `ax` 2-D so ax[row, i] also works with a single model.
    fig, ax = plt.subplots(3, len(models), figsize=(sz * len(models), sz * 3), squeeze=False)
    for i, model in enumerate(tqdm(models)):
        output_path = os.path.join(path, model, "count_matched_quarks")
        if not os.path.exists(os.path.join(output_path, "result.pkl")):
            continue
        # Only the per-threshold counts are used here. The other result_*.pkl
        # files used to be loaded as well but were never read.
        with open(os.path.join(output_path, "result_PR_thresholds.pkl"), "rb") as f:
            result_PR_thresholds = pickle.load(f)
        # NOTE(review): `thresholds` is not built in this block. The code that
        # derived it (shown below) is commented out, so it must be defined earlier
        # in the file -- confirm.
        # thresholds = np.array(sorted(result_PR_thresholds[900][20][0.3].keys()))
        # Index of the threshold closest to the requested objectness score.
        j = int(np.argmin(np.abs(np.asarray(thresholds) - args.threshold_obj_score)))
        print("Thresholds", thresholds)
        print("Chose j=", j, "for threshold", args.threshold_obj_score, "(effectively it's", thresholds[j], ")")
        # j is bound as a default argument so each lambda uses this model's index.
        matrix_plot(result_PR_thresholds, "Oranges", "Precision (N matched dark quarks / N predicted jets)", metric_comp_func=lambda r, j=j: wrap(r[j])[0], ax=ax[0, i])
        matrix_plot(result_PR_thresholds, "Reds", "Recall (N matched dark quarks / N dark quarks)", metric_comp_func=lambda r, j=j: wrap(r[j])[1], ax=ax[1, i])
        matrix_plot(result_PR_thresholds, "Purples", r"$F_1$ score", metric_comp_func=lambda r, j=j: f1_from_counts(r[j]), ax=ax[2, i])
        for row in range(3):
            ax[row, i].set_title(model)
    fig.tight_layout()
    fig.savefig(out_file_PR_OS)
    print("Saved to", out_file_PR_OS)
################
# UNUSED PLOTS #
################
# Disabled plot of average matched dark quarks / unmatched jets per event.
# Kept as an inert string for reference; it is never executed.
'''fig, ax = plt.subplots(2, len(models), figsize=(sz * len(models), sz * 2))
for i, model in tqdm(enumerate(models)):
    output_path = os.path.join(path, model, "count_matched_quarks")
    if not os.path.exists(os.path.join(output_path, "result.pkl")):
        continue
    result = pickle.load(open(os.path.join(output_path, "result.pkl"), "rb"))
    #result_unmatched = pickle.load(open(os.path.join(output_path, "result_unmatched.pkl"), "rb"))
    result_fakes = pickle.load(open(os.path.join(output_path, "result_fakes.pkl"), "rb"))
    result_bc = pickle.load(open(os.path.join(output_path, "result_bc.pkl"), "rb"))
    result_PR = pickle.load(open(os.path.join(output_path, "result_PR.pkl"), "rb"))
    matrix_plot(result, "Blues", "Avg. matched dark quarks / event", ax=ax[0, i])
    matrix_plot(result_fakes, "Greens", "Avg. unmatched jets / event", ax=ax[1, i])
    ax[0, i].set_title(model)
    ax[1, i].set_title(model)
fig.tight_layout()
fig.savefig(out_file_avg_number_matched_quarks)
print("Saved to", out_file_avg_number_matched_quarks)'''
# ---- Precision / recall / F1 score vs GT R ----
# One column per r_inv, one row per metric, one curve per mediator mass:
#   fig          - trained models (one point per model, at that model's GT R)
#   fig_AK       - AK baseline (one point per AK radius)
#   fig_AK_ratio - AK baseline divided by the trained model at the same GT R
rinvs = [0.3, 0.5, 0.7]
sz = 4
# Axes are indexed as [metric_row, rinv_column], so the grid is 3 x len(rinvs).
# The old (len(rinvs), 3) layout was only correct because len(rinvs) == 3.
# squeeze=False keeps the arrays 2-D even for a single r_inv.
fig, ax = plt.subplots(3, len(rinvs), figsize=(sz * len(rinvs), 3 * sz), squeeze=False)
fig_AK, ax_AK = plt.subplots(3, len(rinvs), figsize=(sz * len(rinvs), 3 * sz), squeeze=False)
fig_AK_ratio, ax_AK_ratio = plt.subplots(3, len(rinvs), figsize=(sz * len(rinvs), 3 * sz), squeeze=False)
to_plot = {}  # r_inv -> m_med -> {"precision", "recall", "f1score", "R"} lists (trained models)
to_plot_ak = {}  # same layout, for the AK baseline
oranges = plt.get_cmap("Oranges")
reds = plt.get_cmap("Reds")
purples = plt.get_cmap("Purples")
# (title prefix, to_plot key, colour map), in axes-row order.
metric_specs = (("Precision", "precision", oranges), ("Recall", "recall", reds), ("F1 score", "f1score", purples))
mDark = 20


def _f1_score(precision, recall):
    """Elementwise F1 = 2PR / (P + R), defined as 0 where P + R == 0.

    Accepts scalars or array-likes and returns a float ndarray (0-d for scalars).
    This avoids ZeroDivisionError / NaN when both precision and recall are 0.
    """
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    denom = precision + recall
    return np.divide(2 * precision * recall, denom, out=np.zeros_like(denom), where=denom > 0)


def _mmed_shade(mMed, lo=500, hi=3000):
    """Map a mediator mass linearly from [lo, hi] onto [0, 1] for colour-map lookup."""
    return (mMed - lo) / (hi - lo)


# Per-model results do not depend on r_inv, so load them once.
# Each entry is (GT R of the model, result_PR), in `models` order.
model_results_PR = []
for model in models:
    print("Model", model)
    if model not in radius:
        continue
    pr_file = os.path.join(path, model, "count_matched_quarks", "result_PR.pkl")
    if not os.path.exists(pr_file):
        continue
    with open(pr_file, "rb") as f:
        model_results_PR.append((radius[model], pickle.load(f)))

# AK baseline results (also r_inv independent). None when the file is missing:
# only the AK panels are skipped then, not the titles/legends of the model panels.
ak_file = os.path.join(ak_path, "result_PR_AKX.pkl")
result_PR_AKX = None
if os.path.exists(ak_file):
    with open(ak_file, "rb") as f:
        result_PR_AKX = pickle.load(f)

for i, rinv in enumerate(rinvs):
    model_points = to_plot.setdefault(rinv, {})
    ak_points = to_plot_ak.setdefault(rinv, {})

    # Trained models: one (R, precision, recall, F1) point per model and m_med.
    for model_R, result_PR in model_results_PR:
        for mMed in sorted(result_PR):
            precision = result_PR[mMed][mDark][rinv][0]
            recall = result_PR[mMed][mDark][rinv][1]
            if mMed not in model_points:
                model_points[mMed] = {"precision": [], "recall": [], "f1score": [], "R": []}
            entry = model_points[mMed]
            entry["precision"].append(precision)
            entry["recall"].append(recall)
            entry["f1score"].append(float(_f1_score(precision, recall)))
            entry["R"].append(model_R)
    for mMed in sorted(model_points):
        pts = model_points[mMed]
        print("Model R", pts["R"])
        # Plot in increasing R so curves do not zig-zag when `models` is unordered.
        order = np.argsort(pts["R"], kind="stable")
        label = "m={} GeV".format(round(mMed))
        for row, (_, key, cmap) in enumerate(metric_specs):
            scatter_plot(ax[row, i], np.asarray(pts["R"])[order], np.asarray(pts[key])[order], label=label, color=cmap(_mmed_shade(mMed)))

    if result_PR_AKX is not None:
        for mMed in sorted(result_PR_AKX):
            by_R = result_PR_AKX[mMed][mDark][rinv]
            rs = sorted(by_R)
            precision = np.array([by_R[R][0] for R in rs])
            recall = np.array([by_R[R][1] for R in rs])
            if mMed not in ak_points:
                ak_points[mMed] = {"precision": [], "recall": [], "f1score": [], "R": []}
            entry = ak_points[mMed]
            entry["precision"] += list(precision)
            entry["recall"] += list(recall)
            entry["f1score"] += list(_f1_score(precision, recall))
            entry["R"] += rs
        for mMed in sorted(ak_points):
            ak = ak_points[mMed]
            print("AK R", ak["R"])
            shade = _mmed_shade(mMed)
            label = "m={} GeV AK".format(round(mMed))
            for row, (_, key, cmap) in enumerate(metric_specs):
                scatter_plot(ax_AK[row, i], ak["R"], ak[key], label=label, color=cmap(shade), pattern=".--")
            # AK / model ratio, only at GT R values present in both. Points are
            # paired by R value: model points are in `models` order, AK points in
            # sorted-R order, so plain elementwise division would misalign them.
            model_pts = model_points.get(mMed)
            if not model_pts:
                continue
            ak_index = {R: n for n, R in enumerate(ak["R"])}
            matched = sorted((R, k) for k, R in enumerate(model_pts["R"]) if R in ak_index)
            if not matched:
                continue
            ratio_R = np.array([R for R, _ in matched])
            for row, (_, key, cmap) in enumerate(metric_specs):
                ak_vals = np.array([ak[key][ak_index[R]] for R, _ in matched], dtype=float)
                model_vals = np.array([model_pts[key][k] for _, k in matched], dtype=float)
                scatter_plot(ax_AK_ratio[row, i], ratio_R, ak_vals / model_vals, label=label, color=cmap(shade), pattern=".--")

    for axes in (ax, ax_AK, ax_AK_ratio):
        for row, (title, _, _) in enumerate(metric_specs):
            axes[row, i].set_title(f"{title} r_inv = {rinv}")
            axes[row, i].legend()
            axes[row, i].grid()
            axes[row, i].set_xlabel("GT R")
    for row, (title, _, _) in enumerate(metric_specs):
        ax_AK_ratio[row, i].set_ylabel(f"{title} (model=1)")
# Lay out and write the three score-vs-GT-R figures to the results directory,
# then report where the AK and AK-ratio PDFs were written.
for _figure, _pdf_name in (
    (fig, "score_vs_GT_R_plots.pdf"),
    (fig_AK, "score_vs_GT_R_plots_AK.pdf"),
    (fig_AK_ratio, "score_vs_GT_R_plots_AK_ratio.pdf"),
):
    _figure.tight_layout()
    _figure.savefig(os.path.join(get_path(args.input, "results"), _pdf_name))
for _pdf_name in ("score_vs_GT_R_plots_AK.pdf", "score_vs_GT_R_plots_AK_ratio.pdf"):
    print("Saved to", os.path.join(get_path(args.input, "results"), _pdf_name))