File size: 4,829 Bytes
e75a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
from tqdm import tqdm
import argparse
import numpy as np
import pandas as pd
import pickle
from src.dataset.get_dataset import get_iter
from src.utils.paths import get_path
from pathlib import Path
import torch
# This script iterates over every sub-dataset found under the input directory, collects per-event
# statistics, stores them in a pickle file and plots their distributions.

# ---- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser()
# Dataset name; resolved below against the "preprocessed_data" location.
parser.add_argument(
    "--input",
    type=str,
    required=True,
)
# Maximum number of events read per sub-dataset; -1 disables the cap.
parser.add_argument(
    "--dataset-cap",
    type=int,
    default=-1,
)
# Name used for the results location; an empty string means "same as --input".
parser.add_argument(
    "--output",
    type=str,
    default="",
)
# When set, no statistics are computed and a previously saved result.pkl is plotted.
parser.add_argument(
    "--plot-only",
    action="store_true",
)

# Plotted statistics: total visible energy, visible mass, number of fat jets,
# number of pfcands (special_pfcands are currently not counted), and pt spectra.

args = parser.parse_args()

# Directory that holds one entry per sub-dataset.
path = get_path(args.input, "preprocessed_data")

# Fall back to the input name when no explicit output name was given.
args.output = args.output or args.input

# Everything this script writes goes into <results>/dataset_stats.
results_root = get_path(args.output, "results")
output_path = os.path.join(results_root, "dataset_stats")
Path(output_path).mkdir(parents=True, exist_ok=True)

def get_properties(name):
    """Extract the sample parameters encoded in a sub-dataset name.

    The name is split on "_" and fields 1, 2 and 3 are each expected to have
    the form ``<label>-<value>``.

    Args:
        name: Sub-dataset (directory) name.

    Returns:
        Tuple ``(mMed, mDark, rinv)``: mediator mass (int), dark quark mass
        (int) and r_inv (float).

    Raises:
        IndexError: If the name has fewer fields than expected.
        ValueError: If a value cannot be converted to a number.
    """
    parts = name.split("_")
    mMed = int(parts[1].split("-")[1])
    mDark = int(parts[2].split("-")[1])
    rinv = float(parts[3].split("-")[1])
    return mMed, mDark, rinv


# Single location of the cached statistics, shared by the write and read paths.
result_file = os.path.join(output_path, "result.pkl")

if not args.plot_only:
    # Collect per-event statistics for every sub-dataset found in `path`.
    stats = {}
    for subdataset in os.listdir(path):
        print("-----", subdataset, "-----")
        current_path = os.path.join(path, subdataset)
        dataset = get_iter(current_path)
        sub_stats = {"total_visible_E": [], "visible_mass": [], "n_fatjets": [], "n_pfcands": [], "pt": [],
                     "fatjet_pt": [], "genjet_pt": []}
        stats[subdataset] = sub_stats
        n = 0
        for data in tqdm(dataset):
            n += 1
            # Stop once the optional per-dataset event cap is exceeded (-1 = no cap).
            if args.dataset_cap != -1 and n > args.dataset_cap:
                break
            pfcands = data.pfcands
            # special_pfcands are currently not included in the sums/counts below.
            energy_sum = torch.sum(pfcands.E)
            # NOTE(review): torch.sum(pfcands.p) adds the per-candidate `p` values
            # before squaring. If `p` holds momentum magnitudes this is not the
            # magnitude of the summed momentum vector -- confirm this is the
            # intended definition of "visible mass".
            visible_mass = torch.sqrt(energy_sum ** 2 - torch.sum(pfcands.p) ** 2)
            # Store plain Python numbers rather than 0-d tensors so the pickle
            # stays small and the lists behave like ordinary numeric data.
            sub_stats["total_visible_E"].append(energy_sum.item())
            sub_stats["visible_mass"].append(visible_mass.item())
            sub_stats["n_fatjets"].append(len(data.fatjets))
            sub_stats["n_pfcands"].append(len(pfcands))
            sub_stats["pt"] += pfcands.pt.tolist()
            sub_stats["fatjet_pt"] += data.fatjets.pt.tolist()
            sub_stats["genjet_pt"] += data.genjets.pt.tolist()

    # Re-key the statistics as result[mMed][mDark][rinv] -> stats dict.
    result = {}
    for key, values in stats.items():
        mMed, mDark, rinv = get_properties(key)
        result.setdefault(mMed, {}).setdefault(mDark, {})[rinv] = values

    # Use a context manager so the file is flushed and closed deterministically.
    with open(result_file, "wb") as f:
        pickle.dump(result, f)
else:
    # Plot-only mode: reuse statistics written by a previous run.
    # NOTE: pickle must only be used on trusted files; result.pkl is expected
    # to have been produced by this script.
    with open(result_file, "rb") as f:
        result = pickle.load(f)

import matplotlib.pyplot as plt

# Grid definition for the histogram figures: one row per mediator mass,
# one column per r_inv value; only the listed dark quark mass is plotted.
mediator_masses = sorted(result)
dark_masses = [20]
r_invs = sorted({
    rinv
    for by_dark in result.values()
    for by_rinv in by_dark.values()
    for rinv in by_rinv
})

def plot_distribution(result, key_name):
    """Plot a grid of histograms for one statistic and save it as a PDF.

    One subplot is drawn per (mediator mass, r_inv) pair: rows follow the
    module-level ``mediator_masses`` and columns follow ``r_invs``. Only the
    first entry of ``dark_masses`` is used. The figure is written to
    ``<output_path>/<key_name>.pdf``.

    Args:
        result: Nested mapping ``result[mMed][mDark][rinv][key_name]`` giving
            a sequence of numbers for each sample.
        key_name: Name of the statistic to plot; also used as the figure
            title and as the output file name.

    Raises:
        KeyError: If a sample exists but has no entry for ``key_name``.
    """
    print("---> key:", key_name)
    # squeeze=False keeps `ax` two-dimensional even with a single mediator
    # mass or a single r_inv, so ax[i, j] is always valid.
    fig, ax = plt.subplots(
        len(mediator_masses),
        len(r_invs),
        figsize=(3 * len(r_invs), 3 * len(mediator_masses)),
        squeeze=False,
    )
    mDark = dark_masses[0]
    for i, mMed in enumerate(mediator_masses):
        for j, rinv in enumerate(r_invs):
            # r_invs is collected across all samples, so a particular
            # (mMed, mDark, rinv) combination may not exist: leave that cell empty.
            sample = result.get(mMed, {}).get(mDark, {}).get(rinv)
            if sample is None:
                print(f"No sample for mMed={mMed}, mDark={mDark}, rinv={rinv}; skipping")
                ax[i, j].set_axis_off()
                continue
            data = sample[key_name]
            if key_name == "n_pfcands":
                number_of_zeros = len(data) - np.count_nonzero(data)
                print(f"Number of zeros in {mMed} {rinv}: {number_of_zeros}")
            if key_name == "fatjet_pt":
                # default=None avoids a ValueError when the sample has no entries.
                print("Min fatjet_pt:", min(data, default=None))
            if key_name == "genjet_pt":
                print("Min genjet_pt:", min(data, default=None))
                number_of_zeros = len(data) - np.count_nonzero(data)
                print(f"Number of zeros in {mMed} {rinv}: {number_of_zeros}")
            ax[i, j].hist(data, bins=50)
            # "\\Sigma" produces the same text as before without relying on an
            # invalid "\S" escape sequence in a non-raw string.
            ax[i, j].set_title(f"$m_{{Z'}}$={mMed},$r_{{inv}}$={rinv} ($\\Sigma$={int(sum(data))})")
            if key_name == "pt":
                ax[i, j].set_yscale("log")
    # big title
    fig.suptitle(key_name)
    fig.tight_layout()
    fig.savefig(os.path.join(output_path, f"{key_name}.pdf"))
    # Release the figure; several are created per run and none is shown.
    plt.close(fig)
# Produce one figure per collected statistic, in a fixed order.
for stat_name in (
    "total_visible_E",
    "visible_mass",
    "n_fatjets",
    "n_pfcands",
    "pt",
    "fatjet_pt",
    "genjet_pt",
):
    plot_distribution(result, stat_name)