# (Scraped web-page status banner "Spaces: Sleeping" -- not part of the script.)
import argparse
import os
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from src.dataset.get_dataset import get_iter
from src.utils.paths import get_path
# This script iterates over every sub-dataset of a preprocessed dataset, collects per-event
# statistics, caches them in a pickle file, and saves histogram grids of each statistic.
# Command-line options, declared as (flag, add_argument keyword arguments).
_CLI_OPTIONS = (
    # Name of the dataset, resolved below under "preprocessed_data".
    ("--input", {"type": str, "required": True}),
    # Maximum number of events read per sub-dataset; -1 disables the cap.
    ("--dataset-cap", {"type": int, "default": -1}),
    # Name used for the results directory; an empty string means "same as --input".
    ("--output", {"type": str, "default": ""}),
    # Skip the statistics pass and plot from the cached result.pkl instead.
    ("--plot-only", {"action": "store_true"}),
)

parser = argparse.ArgumentParser()
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)

# Plots of stats: total visible energy, visible mass, number of AK8 jets,
# number of pfcands + special_pfcands
args = parser.parse_args()

# Directory that holds one entry per sub-dataset.
path = get_path(args.input, "preprocessed_data")

# Without an explicit --output, results are filed under the input's name.
if not args.output:
    args.output = args.input

# All artefacts of this script (result.pkl and the PDFs) go into <results>/dataset_stats.
_results_root = get_path(args.output, "results")
output_path = os.path.join(_results_root, "dataset_stats")
Path(output_path).mkdir(parents=True, exist_ok=True)
if not args.plot_only:
    # Per-event statistics, keyed by sub-dataset directory name.
    stats = {}
    # os.listdir returns entries in arbitrary order; sort for reproducible logs.
    for subdataset in sorted(os.listdir(path)):
        print("-----", subdataset, "-----")
        current_path = os.path.join(path, subdataset)
        dataset = get_iter(current_path)
        sub_stats = {
            "total_visible_E": [],  # one value per event
            "visible_mass": [],     # one value per event
            "n_fatjets": [],        # one value per event
            "n_pfcands": [],        # one value per event
            "pt": [],               # one value per pfcand, all events concatenated
            "fatjet_pt": [],        # one value per fatjet, all events concatenated
            "genjet_pt": [],        # one value per genjet, all events concatenated
        }
        stats[subdataset] = sub_stats
        for n, data in enumerate(tqdm(dataset), start=1):
            # Stop once the requested number of events has been processed.
            if args.dataset_cap != -1 and n > args.dataset_cap:
                break
            pfcands = data.pfcands  # special_pfcands are currently not included
            energy_sum = torch.sum(pfcands.E)
            momentum_sum = torch.sum(pfcands.p)
            # NOTE(review): this subtracts the squared sum of `pfcands.p` from the
            # squared energy sum. If `p` holds momentum magnitudes, that is not the
            # invariant mass of the summed four-vectors (which needs the
            # component-wise momentum sum) -- confirm this is the intended quantity.
            visible_mass = torch.sqrt(energy_sum ** 2 - momentum_sum ** 2)
            # Store plain Python floats rather than 0-dim tensors so the pickled
            # cache stays small and can be plotted without tensor conversion.
            sub_stats["total_visible_E"].append(energy_sum.item())
            sub_stats["visible_mass"].append(visible_mass.item())
            sub_stats["n_fatjets"].append(len(data.fatjets))
            sub_stats["n_pfcands"].append(len(pfcands))
            sub_stats["pt"] += pfcands.pt.tolist()
            sub_stats["fatjet_pt"] += data.fatjets.pt.tolist()
            sub_stats["genjet_pt"] += data.genjets.pt.tolist()
def get_properties(name):
    """Extract (mediator mass, dark quark mass, r_inv) from a sub-dataset name.

    The name is split on "_"; the fields at positions 1, 2 and 3 must each have
    the form "<label>-<value>". The first two values are parsed as int and the
    third as float. Any further fields are ignored.

    Args:
        name: Sub-dataset name, e.g. "SVJ_mMed-700_mDark-20_rinv-0.3_alpha-peak".

    Returns:
        Tuple (mMed, mDark, rinv).

    Raises:
        ValueError: If the name does not follow the layout described above.
    """
    parts = name.split("_")
    try:
        mMed = int(parts[1].split("-")[1])
        mDark = int(parts[2].split("-")[1])
        rinv = float(parts[3].split("-")[1])
    except (IndexError, ValueError) as err:
        # Report the offending name instead of a bare "list index out of range".
        raise ValueError(
            f"Cannot parse mediator mass, dark quark mass and r_inv from {name!r}; "
            "expected '<prefix>_<label>-<int>_<label>-<int>_<label>-<float>[_...]'"
        ) from err
    return mMed, mDark, rinv
if not args.plot_only:
    # Regroup the flat per-sub-dataset stats into result[mMed][mDark][rinv].
    result = {}
    for key, sub_stats in stats.items():
        mMed, mDark, rinv = get_properties(key)
        result.setdefault(mMed, {}).setdefault(mDark, {})[rinv] = sub_stats
    # Cache the statistics so a later --plot-only run can skip the dataset pass.
    # The context manager guarantees the file is flushed and closed.
    with open(os.path.join(output_path, "result.pkl"), "wb") as f:
        pickle.dump(result, f)
if args.plot_only:
    # Reuse the statistics cached by a previous full run.
    # NOTE: unpickling can execute arbitrary code -- only load a result.pkl
    # that this script wrote itself.
    with open(os.path.join(output_path, "result.pkl"), "rb") as f:
        result = pickle.load(f)

import matplotlib.pyplot as plt

# Axes of the histogram grids: one row per mediator mass, one column per r_inv.
mediator_masses = sorted(result)
# Only this dark-quark mass is plotted.
dark_masses = [20]
# Union of the r_inv values found under any (mediator mass, dark mass) pair.
r_invs = sorted({
    rinv
    for by_mdark in result.values()
    for by_rinv in by_mdark.values()
    for rinv in by_rinv
})
def plot_distribution(result, key_name):
    """Save a grid of histograms of one statistic as <output_path>/<key_name>.pdf.

    The grid has one row per entry of the module-level `mediator_masses` and one
    column per entry of `r_invs`; the dark-quark mass is fixed to `dark_masses[0]`.

    Args:
        result: Nested dict result[mMed][mDark][rinv][key_name] -> list of values.
        key_name: Name of the statistic to plot; also used as figure title and
            as the output file stem.
    """
    print("---> key:", key_name)
    n_rows, n_cols = len(mediator_masses), len(r_invs)
    # squeeze=False keeps `ax` two-dimensional even with a single row or column,
    # so ax[i, j] below is always valid.
    fig, ax = plt.subplots(
        n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
    )
    mDark = dark_masses[0]
    for i, mMed in enumerate(mediator_masses):
        for j, rinv in enumerate(r_invs):
            panel = ax[i, j]
            entry = result.get(mMed, {}).get(mDark, {}).get(rinv)
            if entry is None:
                # r_invs is a union over all mediator masses, so a particular
                # combination may not exist; leave its panel blank.
                panel.set_axis_off()
                continue
            data = entry[key_name]
            if key_name in ("fatjet_pt", "genjet_pt"):
                # default=None avoids a ValueError when no jets were recorded.
                print(f"Min {key_name}:", min(data, default=None))
            if key_name in ("n_pfcands", "genjet_pt"):
                number_of_zeros = len(data) - np.count_nonzero(data)
                print(f"Number of zeros in {mMed} {rinv}: {number_of_zeros}")
            panel.hist(data, bins=50)
            # Raw f-string: "\S" is not a valid escape sequence in a normal string.
            panel.set_title(
                rf"$m_{{Z'}}$={mMed},$r_{{inv}}$={rinv} ($\Sigma$={int(sum(data))})"
            )
            if key_name == "pt":
                panel.set_yscale("log")
    # big title
    fig.suptitle(key_name)
    fig.tight_layout()
    fig.savefig(os.path.join(output_path, f"{key_name}.pdf"))
    # Release the figure; otherwise every call keeps its figure alive in pyplot.
    plt.close(fig)
# Produce one PDF per collected statistic, in this order.
for _stat_name in (
    "total_visible_E",
    "visible_mass",
    "n_fatjets",
    "n_pfcands",
    "pt",
    "fatjet_pt",
    "genjet_pt",
):
    plot_distribution(result, _stat_name)