jetclustering / scripts /analysis /dataset_stats.py
gregorkrzmanc's picture
.
e75a247
raw
history blame
4.83 kB
import os
from tqdm import tqdm
import argparse
import numpy as np
import pandas as pd
import pickle
from src.dataset.get_dataset import get_iter
from src.utils.paths import get_path
from pathlib import Path
import torch
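# This script iterates over the sub-datasets of a preprocessed dataset, collects per-event statistics
# (visible energy, visible mass, jet/candidate multiplicities, pt spectra) and plots their distributions.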
# Command-line interface.
parser = argparse.ArgumentParser()
# Name of the dataset; resolved below against the "preprocessed_data" location.
parser.add_argument(
    "--input",
    type=str,
    required=True,
)
# Maximum number of events read per sub-dataset; -1 means no cap.
parser.add_argument(
    "--dataset-cap",
    type=int,
    default=-1,
)
# Name of the results folder; an empty value falls back to the value of --input.
parser.add_argument(
    "--output",
    type=str,
    default="",
)
# Skip the statistics computation and only plot from a previously saved result.pkl.
parser.add_argument(
    "--plot-only",
    action="store_true",
)
# Plotted statistics: total visible energy, visible mass, number of fat jets,
# number of pfcands, and the pt spectra of pfcands, fat jets and gen jets.
args = parser.parse_args()

path = get_path(args.input, "preprocessed_data")
if not args.output:
    args.output = args.input
results_root = get_path(args.output, "results")
output_path = os.path.join(results_root, "dataset_stats")
# Create the output directory (and any missing parents); an existing directory is fine.
os.makedirs(output_path, exist_ok=True)
if not args.plot_only:
    # Per-event statistics for every entry found under `path`, keyed by entry name.
    stats = {}
    for subdataset in os.listdir(path):
        print("-----", subdataset, "-----")
        current_path = os.path.join(path, subdataset)
        dataset = get_iter(current_path)
        n = 0
        subdataset_stats = {"total_visible_E": [], "visible_mass": [], "n_fatjets": [], "n_pfcands": [], "pt": [],
                            "fatjet_pt": [], "genjet_pt": []}
        stats[subdataset] = subdataset_stats
        for data in tqdm(dataset):
            n += 1
            # --dataset-cap limits how many events are read per sub-dataset; -1 disables the cap.
            if args.dataset_cap != -1 and n > args.dataset_cap:
                break
            # NOTE: only `pfcands` enter these statistics; `special_pfcands` are not included.
            n_fatjets = len(data.fatjets)
            n_pfcands = len(data.pfcands)
            total_visible_E = torch.sum(data.pfcands.E)
            total_p = torch.sum(data.pfcands.p)
            # NOTE(review): this subtracts the square of the plain sum of `pfcands.p`. If `p` holds
            # momentum magnitudes, this is not the invariant mass of the vector sum - confirm intent.
            # Clamp at zero so that floating-point rounding cannot make the sqrt argument negative:
            # a NaN mass would later break the integer conversion used in the plot titles.
            mass_squared = torch.clamp(total_visible_E ** 2 - total_p ** 2, min=0)
            visible_mass = torch.sqrt(mass_squared)
            # Store plain Python numbers rather than 0-d tensors, which keeps the pickle small
            # and lets numpy/matplotlib consume the lists directly.
            subdataset_stats["total_visible_E"].append(total_visible_E.item())
            subdataset_stats["visible_mass"].append(visible_mass.item())
            subdataset_stats["n_fatjets"].append(n_fatjets)
            subdataset_stats["n_pfcands"].append(n_pfcands)
            # The pt lists are flattened over all events of the sub-dataset.
            subdataset_stats["pt"] += data.pfcands.pt.tolist()
            subdataset_stats["fatjet_pt"] += data.fatjets.pt.tolist()
            subdataset_stats["genjet_pt"] += data.genjets.pt.tolist()
def get_properties(name):
    """Extract the mediator mass, dark quark mass and r_inv from a sub-dataset name.

    The name is split on "_"; fields 1, 2 and 3 must each have the form
    "<label>-<value>" (the labels themselves are not checked).

    Args:
        name: Sub-dataset name, e.g. "X_mMed-700_mDark-20_rinv-0.3".

    Returns:
        Tuple ``(mMed, mDark, rinv)`` as ``(int, int, float)``.

    Raises:
        IndexError: If the name has fewer than four "_"-separated fields or a
            field contains no "-".
        ValueError: If a value cannot be converted to a number.
    """
    parts = name.split("_")
    # Split only on the first "-" so values that themselves contain a "-"
    # (e.g. "1e-1") are kept whole instead of being truncated.
    mMed = int(parts[1].split("-", 1)[1])
    mDark = int(parts[2].split("-", 1)[1])
    rinv = float(parts[3].split("-", 1)[1])
    return mMed, mDark, rinv
if not args.plot_only:
    # Regroup the per-sub-dataset statistics into result[mMed][mDark][rinv],
    # using the parameters encoded in each sub-dataset name.
    result = {}
    for key, subdataset_stats in stats.items():
        mMed, mDark, rinv = get_properties(key)
        result.setdefault(mMed, {}).setdefault(mDark, {})[rinv] = subdataset_stats
    # Use a context manager so the file is flushed and closed before it is read again.
    with open(os.path.join(output_path, "result.pkl"), "wb") as result_file:
        pickle.dump(result, result_file)
else:
    # Reuse the statistics saved by an earlier run without --plot-only.
    # NOTE: pickle can execute arbitrary code on load; only load result.pkl files you produced yourself.
    with open(os.path.join(output_path, "result.pkl"), "rb") as result_file:
        result = pickle.load(result_file)
import matplotlib.pyplot as plt

# Grid of histograms: one row per mediator mass, one column per r_inv value.
mediator_masses = sorted(result)
# Dark-quark mass values considered for plotting.
dark_masses = [20]
# Every r_inv value that occurs anywhere in `result`, in ascending order.
r_invs = sorted({
    r_inv_value
    for by_dark_mass in result.values()
    for by_r_inv in by_dark_mass.values()
    for r_inv_value in by_r_inv
})
def plot_distribution(result, key_name):
    """Plot a grid of histograms of one statistic and save it as a PDF.

    The grid has one row per entry of the module-level ``mediator_masses`` and
    one column per entry of ``r_invs``; only the first entry of ``dark_masses``
    is plotted. The figure is written to ``<output_path>/<key_name>.pdf``.

    Args:
        result: Nested dict ``result[mMed][mDark][rinv][key_name]`` holding a
            list of values for each parameter combination.
        key_name: Name of the statistic to plot (also used as figure title and
            file name).
    """
    print("---> key:", key_name)
    n_rows, n_cols = len(mediator_masses), len(r_invs)
    if n_rows == 0 or n_cols == 0:
        print("Nothing to plot for", key_name)
        return
    # squeeze=False keeps `ax` two-dimensional even when the grid has a single row or column,
    # so ax[i, j] indexing always works.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    mDark = dark_masses[0]
    for i, mMed in enumerate(mediator_masses):
        for j, rinv in enumerate(r_invs):
            panel = ax[i, j]
            # r_invs is the union over all mediator masses, so a combination may be absent.
            entry = result.get(mMed, {}).get(mDark, {}).get(rinv)
            if entry is None:
                print(f"No data for mMed={mMed}, mDark={mDark}, rinv={rinv}; leaving panel empty")
                panel.set_axis_off()
                continue
            data = entry[key_name]
            if key_name in ("fatjet_pt", "genjet_pt"):
                # default=None avoids a ValueError when no jets were recorded at all.
                print(f"Min {key_name}:", min(data, default=None))
            if key_name in ("n_pfcands", "genjet_pt"):
                number_of_zeros = len(data) - np.count_nonzero(data)
                print(f"Number of zeros in {mMed} {rinv}: {number_of_zeros}")
            panel.hist(data, bins=50)
            # Raw f-string: "\S" is not a valid escape sequence in a normal string literal.
            panel.set_title(rf"$m_{{Z'}}$={mMed},$r_{{inv}}$={rinv} ($\Sigma$={int(sum(data))})")
            if key_name == "pt":
                panel.set_yscale("log")
    # big title
    fig.suptitle(key_name)
    fig.tight_layout()
    fig.savefig(os.path.join(output_path, f"{key_name}.pdf"))
    # Release the figure; one is created per statistic and none is reused after saving.
    plt.close(fig)
# Produce one PDF per collected statistic.
for stat_name in (
    "total_visible_E",
    "visible_mass",
    "n_fatjets",
    "n_pfcands",
    "pt",
    "fatjet_pt",
    "genjet_pt",
):
    plot_distribution(result, stat_name)