jetclustering / scripts /analysis /count_matched_quarks.py
gregorkrzmanc's picture
.
e75a247
import os
from tqdm import tqdm
import argparse
import numpy as np
import pandas as pd
import pickle
import torch
import time
from src.utils.utils import CPU_Unpickler
from src.dataset.get_dataset import get_iter
from src.plotting.eval_matrix import matrix_plot
from src.utils.paths import get_path
from pathlib import Path
import matplotlib.pyplot as plt
from src.dataset.dataset import EventDataset
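# This script matches the dark quarks of every event to the closest jet (within delta R <= R),
# accumulates precision/recall and related per-event quantities for each subdataset,
# pickles the results and, for non-fastjet jets, produces summary plots.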
# Matching radius: a dark quark counts as matched when its closest jet lies within delta R <= R.
R = 0.8

parser = argparse.ArgumentParser(
    description="Match dark quarks to jets and store precision/recall and related per-event quantities.",
)
parser.add_argument(
    "--input", type=str, required=True,
    help="Dataset directory (resolved relative to 'preprocessed_data'); every entry in it is one subdataset.",
)
parser.add_argument(
    "--dataset-cap", type=int, default=-1,
    help="Maximum number of events to process per subdataset (-1 means no limit).",
)
parser.add_argument(
    "--output", type=str, default="",
    help="Output location (resolved relative to 'results'); defaults to the value of --input.",
)
parser.add_argument(
    "--augment-soft-particles", "-aug-soft", action="store_true",
    help="Forwarded to EventDataset.from_directory as aug_soft.",
)
parser.add_argument(
    "--plot-only", action="store_true",
    help="Skip the evaluation and only produce plots from previously saved pickles.",
)
parser.add_argument(
    "--jets-object", type=str, default="fatjets",
    help="Name of the event attribute holding the jets to evaluate (e.g. fatjets, fastjet_jets).",
)
parser.add_argument(
    "--eval-dir", type=str, default="",
    help="Directory (resolved relative to 'results') with eval_<n>.pkl / clustering_<n>.pkl model outputs.",
)
# default: 1020, also want to try 1010 or others...?
parser.add_argument(
    "--clustering-suffix", type=str, default="",
    help="Use clustering_<suffix>_<n>.pkl instead of clustering_<n>.pkl.",
)
parser.add_argument(
    "--pt-jet-cutoff", type=float, default=100.0,
    help="Forwarded to EventDataset.from_directory as pt_jet_cutoff.",
)
parser.add_argument(
    "--high-eta-only", action="store_true",
    help="Skip events in which every quark has |eta| below 1.5.",
)
parser.add_argument(
    "--low-eta-only", action="store_true",
    help="Skip events in which any quark has |eta| above 1.5.",
)
parser.add_argument(
    "--parton-level", "-pl", action="store_true",
    help="Forwarded to EventDataset.from_directory as parton_level; to be used together with 'fastjet_jets'.",
)
parser.add_argument(
    "--gen-level", "-gl", action="store_true",
    help="Forwarded to EventDataset.from_directory as gen_level.",
)
args = parser.parse_args()

# Directory whose entries are iterated as subdatasets during the evaluation.
path = get_path(args.input, "preprocessed_data")

import wandb

# Shared wandb API client used by get_run_by_name().
api = wandb.Api()
def get_run_by_name(name):
    """Return the unique wandb run whose display name equals ``name``.

    Args:
        name: Display name of the run; surrounding whitespace is ignored.

    Returns:
        The matching run of the ``fcc_ml/svj_clustering`` project, or ``None``
        when zero or several runs carry that name.
    """
    # The original issued this identical query twice; one request is enough.
    runs = api.runs(
        path="fcc_ml/svj_clustering",
        filters={"display_name": {"$eq": name.strip()}},
    )
    # len() makes sure the first page has been fetched before the count is read.
    if len(runs) != 1:
        return None
    return runs[0]
def resolve_preproc_data_path(path):
    """Re-root a dataset path onto the local ``preprocessed_data`` directory.

    Everything after the last ``/preprocessed_data/`` component of ``path`` (or
    the whole string when that component is absent) is passed to ``get_path``.
    """
    _, _, relative_part = path.rpartition("/preprocessed_data/")
    return get_path(relative_part, "preprocessed_data")
if args.eval_dir:
    eval_dir = get_path(args.eval_dir, "results", fallback=True)
    # Maps a resolved dataset path to [eval file path, clustering file path or None].
    dataset_path_to_eval_file = {}
    top_folder_name = eval_dir.split("/")[-1]
    run = get_run_by_name(top_folder_name)
    if run is None:
        # This value is overwritten before it is ever read, so a missing run is not fatal.
        print("No unique wandb run found for", top_folder_name)
        config = {}
    else:
        config = run.config
    for file in os.listdir(eval_dir):
        if not (file.startswith("eval_") and file.endswith(".pkl")):
            continue
        file_number = file.split("_")[1].split(".")[0]
        if args.clustering_suffix:
            clustering_file = "clustering_{}_{}.pkl".format(args.clustering_suffix, file_number)
        else:
            clustering_file = "clustering_{}.pkl".format(file_number)
        clustering_file = os.path.join(eval_dir, clustering_file)
        eval_file = os.path.join(eval_dir, file)
        # Close the handle deterministically instead of leaking it.
        with open(eval_file, "rb") as handle:
            f = CPU_Unpickler(handle).load()
        # When the eval file already has a "model_cluster" entry and no suffix was
        # requested, no separate clustering file is passed on.
        if "model_cluster" in f and not args.clustering_suffix:
            clustering_file = None
        dataset_path_to_eval_file[resolve_preproc_data_path(f["filename"])] = [eval_file, clustering_file]
    print(dataset_path_to_eval_file)

if args.output == "":
    args.output = args.input
# All pickles, plots and marker files of this script are written here.
output_path = os.path.join(get_path(args.output, "results"), "count_matched_quarks")
Path(output_path).mkdir(parents=True, exist_ok=True)
def get_bc_scores_for_jets(event):
    """Collect the per-candidate scores of every model jet with pt above 100.

    Returns a list with one entry per selected jet; each entry holds the
    ``bc_scores_pfcands`` of the candidates whose ``bc_labels_pfcands`` equals
    that jet's index.
    """
    candidate_scores = event.pfcands.bc_scores_pfcands
    candidate_labels = event.pfcands.bc_labels_pfcands
    hard_jet_indices = torch.where(event.model_jets.pt > 100)[0]
    return [
        candidate_scores[candidate_labels == jet_index.item()]
        for jet_index in hard_jet_indices
    ]
def calculate_m(objects, mt=False):
    """Return the (transverse) mass of the two highest-pt objects.

    Args:
        objects: Collection exposing torch tensors ``pt``, ``E`` and ``pxyz``
            (one row of three momentum components per object).
        mt: When true, only the first two momentum components are subtracted
            (transverse mass); otherwise all three (invariant mass).

    Returns:
        The mass as a Python float; NaN if the squared mass is negative.
    """
    # Indices of the two leading-pt objects (fewer if fewer objects exist).
    leading = objects.pt.argsort(descending=True)[:2]
    total_E = objects.E[leading].sum()
    total_pxyz = objects.pxyz[leading].sum(dim=0)
    if mt:
        squared = total_E**2 - total_pxyz[0]**2 - total_pxyz[1]**2
    else:
        squared = total_E**2 - total_pxyz[2]**2 - total_pxyz[1]**2 - total_pxyz[0]**2
    # Stay in torch: NumPy ufuncs on tensors depend on implicit array conversion.
    return torch.sqrt(squared).item()
# Objectness-score thresholds: 100 finely spaced points on [0, 0.1] followed by
# 20 coarser points on [0.1, 1] (the value 0.1 therefore appears twice).
thresholds = np.concatenate((np.linspace(0, 0.1, 100), np.linspace(0.1, 1, 20)))
def get_mc_gt_per_event(event, radius=0.8):
    """Return, per dark quark, the pt of the vector sum of nearby candidates.

    For every entry of ``event.matrix_element_gen_particles`` the candidates in
    ``event.pfcands`` closer than ``radius`` in (eta, phi) are selected and the
    transverse momentum of their summed (px, py) is computed.

    Args:
        event: Object exposing ``matrix_element_gen_particles`` (``eta``, ``phi``)
            and ``pfcands`` (``eta``, ``phi``, ``pt``) as torch tensors.
        radius: Cone size; the default reproduces the previously hard-coded 0.8.

    Returns:
        A list of floats, one per quark (0.0 when the cone is empty).
    """
    quarks = event.matrix_element_gen_particles
    cands = event.pfcands
    result = []
    for quark_eta, quark_phi in zip(quarks.eta, quarks.phi):
        d_eta = cands.eta - quark_eta
        d_phi = torch.abs(cands.phi - quark_phi)
        # Wrap the azimuthal difference so candidates across the +-pi seam are
        # not treated as ~2*pi away from the quark.
        d_phi = torch.where(d_phi > np.pi, 2 * np.pi - d_phi, d_phi)
        in_cone = torch.sqrt(d_eta ** 2 + d_phi ** 2) < radius
        pt_cone = cands.pt[in_cone]
        phi_cone = cands.phi[in_cone]
        px_cone = torch.sum(pt_cone * torch.cos(phi_cone))
        py_cone = torch.sum(pt_cone * torch.sin(phi_cone))
        result.append(torch.sqrt(px_cone ** 2 + py_cone ** 2).item())
    return result
def _delta_r_matrix(jets, quarks, strict=False):
    """Return the (n_quarks, n_jets) float64 matrix of delta R between quarks and jets.

    ``jets`` and ``quarks`` must expose ``eta`` and ``phi``. The azimuthal
    difference is wrapped when it exceeds pi. With ``strict=True`` a ValueError
    is raised if a wrapped difference still lies outside [-pi, pi].
    """
    jet_eta = torch.as_tensor(jets.eta)
    jet_phi = torch.as_tensor(jets.phi)
    quark_eta = torch.as_tensor(quarks.eta)
    quark_phi = torch.as_tensor(quarks.phi)
    d_eta = jet_eta[:, None] - quark_eta[None, :]
    d_phi = torch.abs(jet_phi[:, None] - quark_phi[None, :])
    d_phi = torch.where(d_phi > np.pi, d_phi - 2 * np.pi, d_phi)
    if strict and not bool((torch.abs(d_phi) <= np.pi).all()):
        raise ValueError("dphi is not in [-pi, pi] range")
    distances = torch.sqrt(d_eta ** 2 + d_phi ** 2)
    # Transposed so that each row belongs to one quark (row-wise min = closest jet).
    return distances.T.detach().cpu().numpy().astype(np.float64)


def _match_quarks_to_jets(distance_matrix, radius):
    """Return (distance to closest jet or -1 if beyond ``radius``, index of closest jet) per quark."""
    quark_to_jet = np.min(distance_matrix, axis=1)
    quark_to_jet_idx = np.argmin(distance_matrix, axis=1)
    quark_to_jet[quark_to_jet > radius] = -1
    return quark_to_jet, quark_to_jet_idx


def _append_matched_properties(store, quarks, jets, matched, jet_idx):
    """Append pt/eta/phi of the matched quarks and of their closest jets to ``store``."""
    for prop in ("pt", "eta", "phi"):
        store[prop + "_gen_particle"] += getattr(quarks, prop)[matched].tolist()
        store[prop + "_pred"] += getattr(jets, prop)[jet_idx[matched]].tolist()


def _cone_mask(cands, eta0, phi0, radius):
    """Return the boolean mask of candidates within ``radius`` of (eta0, phi0), with wrapped delta phi."""
    d_eta = cands.eta - eta0
    d_phi = torch.abs(cands.phi - phi0)
    d_phi = torch.where(d_phi > np.pi, 2 * np.pi - d_phi, d_phi)
    return torch.sqrt(d_eta ** 2 + d_phi ** 2) < radius


if not args.plot_only:
    # All accumulators below are keyed by subdataset name.
    n_matched_quarks = {}
    unmatched_quarks = {}
    n_fake_jets = {}  # Number of jets that have not been matched to a quark
    bc_scores_matched = {}
    bc_scores_unmatched = {}
    # [n_relevant_retrieved, all_retrieved, all_relevant], i.e. [n_matched_dark_quarks, n_jets, n_dark_quarks]
    precision_and_recall = {}
    precision_and_recall_fastjets = {}
    # Same counters as precision_and_recall, one triple per objectness-score threshold index
    pr_obj_score_thresholds = {}
    # {'m_true': [], 'm_pred': [], 'mt_true': [], 'mt_pred': [], 'n_jets': []}; mt = transverse mass, m = invariant mass
    mass_resolution = {}
    matched_jet_properties = {}
    matched_jet_properties_fastjets = {}
    # Per event: distance of every quark to its matched jet, or -1 when unmatched.
    is_dq_matched_per_event = {}
    dq_pt_per_event = {}
    gt_pt_per_event = {}
    gt_props_per_event = {"eta": {}, "phi": {}}

    is_fastjet = args.jets_object == "fastjet_jets"
    # Defined up front so it exists even when no subdataset is processed.
    # It holds the event count of the most recently processed subdataset.
    n = 0
    subdatasets = os.listdir(path)
    print("LISTING DIRECTORY", path, ":", subdatasets)
    for subdataset in subdatasets:
        print("-----", subdataset, "-----")
        current_path = os.path.join(path, subdataset)
        model_clusters_file = None
        model_output_file = None
        if subdataset not in precision_and_recall:
            # fastjet results are additionally keyed by the entries of the jets dictionary.
            per_event_container = dict if is_fastjet else list
            precision_and_recall[subdataset] = [0, 0, 0]
            precision_and_recall_fastjets[subdataset] = {}
            matched_jet_properties_fastjets[subdataset] = {}
            is_dq_matched_per_event[subdataset] = per_event_container()
            dq_pt_per_event[subdataset] = per_event_container()
            gt_pt_per_event[subdataset] = per_event_container()
            for prop in gt_props_per_event:
                if subdataset not in gt_props_per_event[prop]:
                    gt_props_per_event[prop][subdataset] = per_event_container()
            pr_obj_score_thresholds[subdataset] = {i: [0, 0, 0] for i in range(len(thresholds))}
        if subdataset not in mass_resolution:
            mass_resolution[subdataset] = {'m_true': [], 'm_pred': [], 'mt_true': [], 'mt_pred': [], 'n_jets': []}
        if args.eval_dir:
            if current_path not in dataset_path_to_eval_file:
                print("Skipping", current_path)
                print(dataset_path_to_eval_file)
                continue
            model_output_file, model_clusters_file = dataset_path_to_eval_file[current_path]
        fastjet_R = np.array([0.8]) if is_fastjet else None
        config = {"parton_level": args.parton_level, "gen_level": args.gen_level}
        print("Config:", config)
        dataset = EventDataset.from_directory(current_path, model_clusters_file=model_clusters_file,
                                              model_output_file=model_output_file,
                                              include_model_jets_unfiltered=True, fastjet_R=fastjet_R,
                                              parton_level=config.get("parton_level", False),
                                              gen_level=config.get("gen_level", False),
                                              aug_soft=args.augment_soft_particles, seed=1000000,
                                              pt_jet_cutoff=args.pt_jet_cutoff)
        n = 0
        for x in tqdm(range(len(dataset))):
            # Stop before loading another event once the cap is reached, so that n
            # equals the number of events actually counted.
            if args.dataset_cap != -1 and n >= args.dataset_cap:
                break
            data = dataset[x]
            if data is None:
                print("Skipping", x)
                continue
            jets_object = data.__dict__[args.jets_object]
            n += 1
            quarks = data.matrix_element_gen_particles
            if args.high_eta_only and torch.max(torch.abs(quarks.eta)) < 1.5:
                continue
            if args.low_eta_only and torch.max(torch.abs(quarks.eta)) > 1.5:
                continue
            n_quarks = len(quarks)
            if not is_fastjet:
                dq = [quarks.eta, quarks.phi]
                distance_matrix = _delta_r_matrix(jets_object, quarks)
                n_jets = len(jets_object)
                precision_and_recall[subdataset][1] += n_jets
                precision_and_recall[subdataset][2] += n_quarks
                has_obj_score = "obj_score" in jets_object.__dict__
                if has_obj_score:
                    print("Also evaluating using objectness score")
                    obj_prob = torch.sigmoid(jets_object.obj_score)
                    for i, threshold in enumerate(thresholds):
                        filt = obj_prob >= threshold
                        pr_obj_score_thresholds[subdataset][i][1] += torch.sum(filt).item()
                        pr_obj_score_thresholds[subdataset][i][2] += n_quarks
                masses = mass_resolution[subdataset]
                masses['m_true'].append(calculate_m(quarks))
                masses['m_pred'].append(calculate_m(jets_object))
                masses['mt_true'].append(calculate_m(quarks, mt=True))
                masses['mt_pred'].append(calculate_m(jets_object, mt=True))
                masses['n_jets'].append(n_jets)
                if n_jets:
                    if subdataset not in matched_jet_properties:
                        matched_jet_properties[subdataset] = {'pt_gen_particle': [], 'pt_mc_truth': [], 'pt_pred': [],
                                                              'eta_gen_particle': [], 'eta_pred': [],
                                                              'phi_gen_particle': [], 'phi_pred': []}
                    quark_to_jet, quark_to_jet_idx = _match_quarks_to_jets(distance_matrix, R)
                    matched = quark_to_jet != -1
                    n_matched = np.sum(matched)
                    # Append in place instead of rebuilding the list for every event.
                    n_matched_quarks.setdefault(subdataset, []).append(n_matched)
                    n_fake_jets.setdefault(subdataset, []).append(n_jets - n_matched)
                    _append_matched_properties(matched_jet_properties[subdataset], quarks, jets_object,
                                               matched, quark_to_jet_idx)
                    precision_and_recall[subdataset][0] += n_matched
                    if has_obj_score:
                        for i, threshold in enumerate(thresholds):
                            filt = obj_prob >= threshold
                            if filt.sum() == 0:
                                continue
                            # Redo the matching using only jets passing the threshold.
                            dist_matrix_filt = distance_matrix[:, filt.numpy()]
                            quark_to_jet_filt = np.min(dist_matrix_filt, axis=1)
                            quark_to_jet_filt[quark_to_jet_filt > R] = -1
                            pr_obj_score_thresholds[subdataset][i][0] += np.sum(quark_to_jet_filt != -1)
                    filt = quark_to_jet == -1
                else:
                    # No jets at all: every quark is unmatched.
                    n_matched_quarks.setdefault(subdataset, []).append(0)
                    n_fake_jets.setdefault(subdataset, []).append(n_jets)
                    filt = torch.ones(n_quarks).bool()
                    quark_to_jet = torch.ones(n_quarks).long() * -1
                quark_to_jet_list = quark_to_jet.tolist()
                is_dq_matched_per_event[subdataset].append(quark_to_jet_list)
                dq_pt_per_event[subdataset].append(quarks.pt.tolist())
                gt_pt_per_event[subdataset].append(get_mc_gt_per_event(data))
                gt_props_per_event["eta"][subdataset].append(quarks.eta.tolist())
                gt_props_per_event["phi"][subdataset].append(quarks.phi.tolist())
                if subdataset not in unmatched_quarks:
                    unmatched_quarks[subdataset] = {"pt": [], "eta": [], "phi": [], "pt_all": [],
                                                    "frac_evt_E_matched": [], "frac_evt_E_unmatched": []}
                unmatched = unmatched_quarks[subdataset]
                unmatched["pt"] += quarks.pt[filt].tolist()
                unmatched["pt_all"] += quarks.pt.tolist()
                unmatched["eta"] += quarks.eta[filt].tolist()
                unmatched["phi"] += quarks.phi[filt].tolist()
                visible_E_event = torch.sum(data.pfcands.E)
                for i in range(n_quarks):
                    cone_filter = _cone_mask(data.pfcands, dq[0][i], dq[1][i], R)
                    E_in_cone = data.pfcands.E[cone_filter].sum()
                    # Stored as a plain float so the pickle does not contain tensors.
                    fraction = (E_in_cone / visible_E_event).item()
                    if quark_to_jet_list[i] != -1:
                        unmatched["frac_evt_E_matched"].append(fraction)
                    else:
                        unmatched["frac_evt_E_unmatched"].append(fraction)
            else:
                for key in jets_object:
                    jets = jets_object[key]
                    distance_matrix = _delta_r_matrix(jets, quarks, strict=True)
                    n_jets = len(jets)
                    if key not in precision_and_recall_fastjets[subdataset]:
                        precision_and_recall_fastjets[subdataset][key] = [0, 0, 0]
                    if key not in matched_jet_properties_fastjets[subdataset]:
                        is_dq_matched_per_event[subdataset][key] = []
                        dq_pt_per_event[subdataset][key] = []
                        gt_pt_per_event[subdataset][key] = []
                        for prop in gt_props_per_event:
                            if key not in gt_props_per_event[prop][subdataset]:
                                gt_props_per_event[prop][subdataset][key] = []
                        matched_jet_properties_fastjets[subdataset][key] = {"pt_gen_particle": [], "pt_pred": [],
                                                                            "eta_gen_particle": [], "eta_pred": [],
                                                                            "phi_gen_particle": [], "phi_pred": []}
                    counters = precision_and_recall_fastjets[subdataset][key]
                    counters[1] += n_jets
                    counters[2] += n_quarks
                    if n_jets:
                        quark_to_jet, quark_to_jet_idx = _match_quarks_to_jets(distance_matrix, R)
                        matched = quark_to_jet != -1
                        counters[0] += np.sum(matched)
                        _append_matched_properties(matched_jet_properties_fastjets[subdataset][key], quarks, jets,
                                                   matched, quark_to_jet_idx)
                    else:
                        quark_to_jet = torch.ones(n_quarks).long() * -1
                    is_dq_matched_per_event[subdataset][key].append(quark_to_jet.tolist())
                    dq_pt_per_event[subdataset][key].append(quarks.pt.tolist())
                    gt_pt_per_event[subdataset][key].append(get_mc_gt_per_event(data))
                    gt_props_per_event["eta"][subdataset][key].append(quarks.eta.tolist())
                    gt_props_per_event["phi"][subdataset][key].append(quarks.phi.tolist())
    # Per-subdataset averages over events.
    avg_n_matched_quarks = {name: np.mean(counts) for name, counts in n_matched_quarks.items()}
    avg_n_fake_jets = {name: np.mean(n_fake_jets[name]) for name in n_matched_quarks}
def get_properties(name):
    """Parse (mMed, mDark, rinv) from a subdataset name.

    The last path component of ``name`` is split on ``_``; each parameter is the
    part after the ``-`` of its field. Two layouts are tried: parameters in
    fields 1-3, then parameters in fields 2-4.

    Args:
        name: Subdataset name or path.

    Returns:
        Tuple ``(mMed: int, mDark: int, rinv: float)``; ``(0, 0, 0)`` for names
        containing "qcd" (case-insensitive).

    Raises:
        IndexError, ValueError: If neither layout matches.
    """
    if "qcd" in name.lower():
        print("QCD file! Not using mMed, mDark, rinv")
        return 0, 0, 0
    # get mediator mass, dark quark mass, r_inv from the filename
    parts = name.strip().strip("/").split("/")[-1].split("_")

    def _parse(first_field):
        return (
            int(parts[first_field].split("-")[1]),
            int(parts[first_field + 1].split("-")[1]),
            float(parts[first_field + 2].split("-")[1]),
        )

    try:
        mMed, mDark, rinv = _parse(1)
    except (IndexError, ValueError):
        # another convention: one extra field before the parameters
        mMed, mDark, rinv = _parse(2)
    return mMed, mDark, rinv
def _leaf(tree, *keys):
    """Return the nested dict ``tree[k0][k1]...``, creating missing levels as empty dicts."""
    for nested_key in keys:
        tree = tree.setdefault(nested_key, {})
    return tree


def _dump_pickle(obj, filename):
    """Pickle ``obj`` into ``output_path/filename`` and close the file afterwards."""
    with open(os.path.join(output_path, filename), "wb") as handle:
        pickle.dump(obj, handle)


if not args.plot_only:
    # Every result dictionary is nested as [mMed][mDark][rinv]; the fastjet ones
    # have one more level keyed by the entries of the fastjet jets dictionary.
    result = {}
    result_unmatched = {}
    result_fakes = {}
    result_bc = {}
    result_PR = {}
    result_PR_AKX = {}
    result_PR_thresholds = {}
    result_m = {}
    result_jet_properties = {}
    result_jet_properties_AKX = {}
    result_quark_to_jet = {}
    result_pt_mc_gt = {}
    result_pt_dq = {}
    result_props_dq = {"eta": {}, "phi": {}}
    if args.jets_object != "fastjet_jets":
        all_trees = [result, result_unmatched, result_fakes, result_bc, result_PR, result_PR_AKX,
                     result_PR_thresholds, result_m, result_jet_properties, result_jet_properties_AKX,
                     result_quark_to_jet, result_pt_mc_gt, result_pt_dq]
        for key in avg_n_matched_quarks:
            mMed, mDark, rinv = get_properties(key)
            # Make sure every tree has the [mMed][mDark] level, even those left empty.
            for tree in all_trees:
                _leaf(tree, mMed, mDark)
            for prop in gt_props_per_event:
                _leaf(result_props_dq[prop], mMed, mDark)
            result[mMed][mDark][rinv] = avg_n_matched_quarks[key]
            result_unmatched[mMed][mDark][rinv] = unmatched_quarks[key]
            result_fakes[mMed][mDark][rinv] = avg_n_fake_jets[key]
            result_jet_properties[mMed][mDark][rinv] = matched_jet_properties[key]
            result_quark_to_jet[mMed][mDark][rinv] = is_dq_matched_per_event[key]
            result_pt_mc_gt[mMed][mDark][rinv] = gt_pt_per_event[key]
            result_pt_dq[mMed][mDark][rinv] = dq_pt_per_event[key]
            for prop in gt_props_per_event:
                result_props_dq[prop][mMed][mDark][rinv] = gt_props_per_event[prop][key]
            result_PR_thresholds[mMed][mDark][rinv] = pr_obj_score_thresholds[key]
            n_matched, n_jets_total, n_quarks_total = precision_and_recall[key]
            if n_jets_total == 0 or n_quarks_total == 0:
                result_PR[mMed][mDark][rinv] = [0, 0]
                print(mMed, mDark, rinv)
                print("PR zero", key, precision_and_recall[key])
            else:
                # [precision, recall]
                result_PR[mMed][mDark][rinv] = [n_matched / n_jets_total, n_matched / n_quarks_total]
            result_m[mMed][mDark][rinv] = {name: np.array(val) for name, val in mass_resolution[key].items()}
    else:
        fastjet_trees = [result_PR_AKX, result_jet_properties_AKX, result_quark_to_jet,
                         result_pt_mc_gt, result_pt_dq]
        for key in precision_and_recall_fastjets:
            mMed, mDark, rinv = get_properties(key)
            for tree in fastjet_trees:
                _leaf(tree, mMed, mDark, rinv)
            for prop in result_props_dq:
                _leaf(result_props_dq[prop], mMed, mDark, rinv)
            r = precision_and_recall_fastjets[key]
            for k in r:
                result_quark_to_jet[mMed][mDark][rinv][k] = is_dq_matched_per_event[key][k]
                result_pt_mc_gt[mMed][mDark][rinv][k] = gt_pt_per_event[key][k]
                result_pt_dq[mMed][mDark][rinv][k] = dq_pt_per_event[key][k]
                for prop in result_props_dq:
                    result_props_dq[prop][mMed][mDark][rinv][k] = gt_props_per_event[prop][key][k]
                result_jet_properties_AKX[mMed][mDark][rinv][k] = matched_jet_properties_fastjets[key][k]
                if r[k][1] == 0 or r[k][2] == 0:
                    result_PR_AKX[mMed][mDark][rinv][k] = [0, 0]
                else:
                    result_PR_AKX[mMed][mDark][rinv][k] = [r[k][0] / r[k][1], r[k][0] / r[k][2]]
    _dump_pickle(result_quark_to_jet, "result_quark_to_jet.pkl")
    _dump_pickle(result_pt_mc_gt, "result_pt_mc_gt.pkl")
    _dump_pickle(result_pt_dq, "result_pt_dq.pkl")
    _dump_pickle(result, "result.pkl")
    _dump_pickle(result_unmatched, "result_unmatched.pkl")
    _dump_pickle(result_fakes, "result_fakes.pkl")
    _dump_pickle(result_bc, "result_bc.pkl")
    _dump_pickle(result_props_dq, "result_props_dq.pkl")
    if args.jets_object == "fastjet_jets":
        _dump_pickle(result_PR_AKX, "result_PR_AKX.pkl")
        _dump_pickle(result_jet_properties_AKX, "result_jet_properties_AKX.pkl")
    _dump_pickle(result_PR, "result_PR.pkl")
    _dump_pickle(result_PR_thresholds, "result_PR_thresholds.pkl")
    _dump_pickle(result_m, "result_m.pkl")
    _dump_pickle(result_jet_properties, "result_jet_properties.pkl")
    # Marker file signalling that the evaluation finished.
    with open(os.path.join(output_path, "eval_done.txt"), "w") as f:
        f.write("True")
    # Write the number of events to n_events.txt
    with open(os.path.join(output_path, "n_events.txt"), "w") as f:
        f.write(str(n))
def _load_result(filename):
    """Unpickle ``output_path/filename`` and close the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only load files written by this script.
    with open(os.path.join(output_path, filename), "rb") as handle:
        return pickle.load(handle)


if args.plot_only:
    # Reuse the results of a previous run instead of recomputing them.
    result = _load_result("result.pkl")
    result_unmatched = _load_result("result_unmatched.pkl")
    result_fakes = _load_result("result_fakes.pkl")
    result_bc = _load_result("result_bc.pkl")
    result_PR = _load_result("result_PR.pkl")
    result_PR_thresholds = _load_result("result_PR_thresholds.pkl")
if args.jets_object == "fastjet_jets":
    print("Only computing fastjet jets - exiting now, the metrics have been saved to disk")
    import sys
    sys.exit(0)
# One panel each for precision, recall and F1 versus objectness threshold.
fig, ax = plt.subplots(3, 1, figsize=(4, 12))
def get_plots_for_params(mMed, mDark, rInv):
    """Return per-threshold precision, recall and F1 lists for one signal point.

    The counters are read from ``result_PR_thresholds[mMed][mDark][rInv]``; each
    entry holds [matched quarks, jets passing the threshold, quarks]. A ratio
    with a zero denominator is reported as 0.
    """
    precisions, recalls, f1_scores = [], [], []
    for idx in range(len(thresholds)):
        n_matched, n_selected_jets, n_quarks = result_PR_thresholds[mMed][mDark][rInv][idx]
        precision = 0 if n_selected_jets == 0 else n_matched / n_selected_jets
        recall = 0 if n_quarks == 0 else n_matched / n_quarks
        precisions.append(precision)
        recalls.append(recall)
    for precision, recall in zip(precisions, recalls):
        denominator = precision + recall
        if denominator == 0:
            f1_scores.append(0)
        else:
            f1_scores.append(2*precision*recall / denominator)
    return precisions, recalls, f1_scores
def plot_for_params(a, b, c):
    """Draw precision, recall and F1 versus threshold for (mMed=a, mDark=b, rInv=c).

    The three curves go onto the first three axes of the module-level ``ax``.
    """
    curves = get_plots_for_params(a, b, c)
    legend_label = f"mMed={a},rInv={c}"
    for panel, values in zip((ax[0], ax[1], ax[2]), curves):
        panel.plot(thresholds, values, ".--", label=legend_label)
if "qcd" in args.input.lower():
    print("QCD dataset - not plotting thresholds")
    import sys
    sys.exit(0)
# Threshold scans for a fixed selection of signal points.
plot_for_params(900, 20, 0.3)
plot_for_params(700, 20, 0.7)
#plot_for_params(3000, 20, 0.3)
plot_for_params(900, 20, 0.7)
plot_for_params(1000, 20, 0.3)
for panel, y_label in zip((ax[0], ax[1], ax[2]), ("Precision", "Recall", "F1 score")):
    panel.grid()
    panel.set_ylabel(y_label)
    panel.legend()
    panel.set_xscale("log")
fig.tight_layout()
fig.savefig(os.path.join(output_path, "pr_thresholds.pdf"))


def _f1_score(r):
    """Return the F1 score of a [precision, recall] pair, or 0 when both are zero."""
    denominator = r[0] + r[1]
    if denominator == 0:
        # [0, 0] is stored for subdatasets without jets or quarks; avoid dividing by zero.
        return 0
    return 2 * r[0] * r[1] / denominator


matrix_plot(result, "Blues", "Avg. matched dark quarks / event").savefig(os.path.join(output_path, "avg_matched_dark_quarks.pdf"))
matrix_plot(result_fakes, "Greens", "Avg. unmatched jets / event").savefig(os.path.join(output_path, "avg_unmatched_jets.pdf"))
matrix_plot(result_PR, "Reds", "Precision (N matched dark quarks / N predicted jets)", metric_comp_func = lambda r: r[0]).savefig(os.path.join(output_path, "precision.pdf"))
matrix_plot(result_PR, "Reds", "Recall (N matched dark quarks / N dark quarks)", metric_comp_func = lambda r: r[1]).savefig(os.path.join(output_path, "recall.pdf"))
matrix_plot(result_PR, "Purples", "F_1 score", metric_comp_func = _f1_score).savefig(os.path.join(output_path, "f1_score.pdf"))

# Grid plots: one row per r_inv, one column per mediator mass, at the single dark mass below.
dark_masses = [20]
mediator_masses = sorted(result.keys())
r_invs = sorted({rinv for mMed in result for mDark in result[mMed] for rinv in result[mMed][mDark]})


def _new_grid():
    """Create a figure with one panel per (r_inv, mediator mass) combination."""
    # squeeze=False keeps the axes two-dimensional even for a single row or column.
    return plt.subplots(len(r_invs), len(mediator_masses),
                        figsize=(3*len(mediator_masses), 3 * len(r_invs)), squeeze=False)


def _grid_cells(axes):
    """Yield (panel, title, unmatched-quark record) for every panel of the grid."""
    for i, r_inv in enumerate(r_invs):
        for j, m_med in enumerate(mediator_masses):
            record = result_unmatched[m_med][dark_masses[0]][r_inv]
            yield axes[i, j], f"mMed = {m_med}, rinv = {r_inv}", record


fig, ax = _new_grid()
for panel, title, record in _grid_cells(ax):
    panel.hist(record["pt"], bins=50, histtype="step", label="Unmatched")
    panel.hist(record["pt_all"], bins=50, histtype="step", label="All")
    panel.set_title(title)
    panel.set_xlabel("pt")
    panel.legend()
fig.tight_layout()
fig.savefig(os.path.join(output_path, "unmatched_dark_quarks_pt.pdf"))

fig, ax = _new_grid()
for panel, title, record in _grid_cells(ax):
    # 2d histogram
    panel.hist2d(record["eta"], record["phi"], bins=10, cmap="Blues")
    panel.set_title(title)
    panel.set_xlabel("unmatched dark quark eta")
    panel.set_ylabel("unmatched dark quark phi")
fig.tight_layout()
fig.savefig(os.path.join(output_path, "unmatched_dark_quarks_eta_phi.pdf"))

# Energy fraction inside the cone around each quark: raw counts first, then normalized.
for density, labels, filename in (
    (False, ("Unmatched", "Matched"), "frac_E_in_cone.pdf"),
    (True, ("Unmatched dark quark", "Matched dark quark"), "frac_E_in_cone_density.pdf"),
):
    fig, ax = _new_grid()
    bins = np.linspace(0, 1, 100)
    # The original only passes density= for the normalized plot; keep the calls identical.
    hist_kwargs = {"density": True} if density else {}
    for panel, title, record in _grid_cells(ax):
        panel.hist(record["frac_evt_E_unmatched"], bins=bins, histtype="step", label=labels[0], **hist_kwargs)
        panel.hist(record["frac_evt_E_matched"], bins=bins, histtype="step", label=labels[1], **hist_kwargs)
        panel.set_title(title)
        panel.set_xlabel("E (R<0.8) / event E")
        panel.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(output_path, filename))
'''
fig, ax = plt.subplots(figsize=(5, 5))
unmatched = result_bc[900][20][0.3]["unmatched"]
matched = result_bc[900][20][0.3]["matched"]
bins = np.linspace(0, 1, 100)
ax.hist(unmatched, bins=bins, histtype="step", label="Unmatched jet")
ax.hist(matched, bins=bins, histtype="step", label="Matched jet")
ax.set_title("mMed = 900, mDark = 20, rinv = 0.3")
ax.set_xlabel("BC score")
ax.set_ylabel("count")
ax.set_yscale("log")
ax.legend()
fig.tight_layout()
fig.savefig(os.path.join(output_path, "avg_scores_matched_vs_unmatched_jet.pdf"))
'''