import argparse
import os
import traceback
from pathlib import Path
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch

from src.plotting.histograms import score_histogram, per_pt_score_histogram, plot_roc_curve, confusion_matrix_plot
from src.utils.paths import get_path
from src.utils.utils import CPU_Unpickler

# Reference result directories:
#   rinv=0.7: /work/gkrzmanc/jetclustering/results/train/Test_betaPt_BC_rinv07_2025_01_03_15_38_58
#   L-GATr:   /work/gkrzmanc/jetclustering/results/train/Test_LGATr_all_datasets_2025_01_08_19_27_54


def save_and_close(fig, path):
    """Save `fig` to `path` and close it so figures do not pile up while looping over files."""
    fig.savefig(path)
    plt.close(fig)


def plot_score_histograms(result, eval_path):
    """Binary-classifier score histograms (overall and per pT) and the ROC curve."""
    pt = result["pt"]
    y_true = result["GT_cluster"] >= 0  # signal = assigned to a ground-truth cluster (index >= 0)
    y_pred = result["pred"][:, -1]  # the last output column is the classifier score
    save_and_close(score_histogram(y_true, y_pred, sz=5), os.path.join(eval_path, "binary_classifier_scores.pdf"))
    save_and_close(per_pt_score_histogram(y_true, y_pred, pt), os.path.join(eval_path, "binary_classifier_scores_per_pt.pdf"))
    save_and_close(plot_roc_curve(y_true, y_pred), os.path.join(eval_path, "roc_curve.pdf"))


def plot_four_momentum_spectrum(result, eval_path):
    """Histogram of m^2 = p0^2 - p1^2 - p2^2 - p3^2 for the predicted four-momenta (first four columns of `pred`)."""
    y_true = result["GT_cluster"] >= 0
    p = result["pred"][:, :4]  # column 0 is treated as the energy (time) component
    mass_squared = p[:, 0] ** 2 - p[:, 1] ** 2 - p[:, 2] ** 2 - p[:, 3] ** 2
    signal_masses = mass_squared[y_true]
    bkg_masses = mass_squared[~y_true]
    fig, ax = plt.subplots()
    bins = np.linspace(-25, 25, 200)  # entries outside [-25, 25] are not shown
    # Signal/background split, currently disabled:
    # ax.hist(signal_masses, bins=bins, histtype="step", label="Signal")
    # ax.hist(bkg_masses, bins=bins, histtype="step", label="Background")
    ax.hist(mass_squared, bins=bins, histtype="step", label="All")
    ax.set_xlabel("$m^2$")
    ax.set_yscale("log")
    ax.set_ylabel("count")
    ax.legend()
    save_and_close(fig, os.path.join(eval_path, "mass_squared.pdf"))


def plot_cm(result, eval_path):
    """Confusion matrices: classifier cut at 0.5, GenJets and FatJets radius clustering."""
    y_true = result["GT_cluster"] >= 0
    y_pred = result["pred"][:, -1]  # classifier score, same column as in plot_score_histograms
    sz = 5
    fig, ax = plt.subplots(1, 3, figsize=(3 * sz / 2, sz / 2))
    confusion_matrix_plot(y_true, y_pred > 0.5, ax[0])
    ax[0].set_title("Classifier (cut at 0.5)")
    confusion_matrix_plot(y_true, result["radius_cluster_GenJets"], ax[1])
    ax[1].set_title("GenJets")
    confusion_matrix_plot(y_true, result["radius_cluster_FatJets"], ax[2])
    ax[2].set_title("FatJets")
    fig.tight_layout()
    save_and_close(fig, os.path.join(eval_path, "confusion_matrix.pdf"))


def plotting_blueprint(result, eval_path):
    """Template for a new plotting job: read from `result` and write figures into `eval_path`."""
    pass


def main():
    parser = argparse.ArgumentParser(description="Produce evaluation plots for every eval_*.pkl file in a results directory.")
    parser.add_argument("--input", type=str, required=True,
                        help="Directory containing eval_*.pkl files, resolved via get_path(..., 'results').")
    args = parser.parse_args()
    input_dir = get_path(args.input, "results")

    # plotting_jobs = [plot_score_histograms, plot_cm]
    plotting_jobs = [plot_four_momentum_spectrum]

    for file in os.listdir(input_dir):
        print("File:", file)
        if not (file.startswith("eval_") and file.endswith(".pkl")):
            continue
        filename = get_path(os.path.join(input_dir, file), "results")
        print("Plotting file", filename)
        with open(filename, "rb") as f:
            result = CPU_Unpickler(f).load()
        # e.g. eval_3.pkl -> <same directory>/full_eval_3
        eval_path = os.path.join(os.path.dirname(filename), "full_eval_" + file.split("_")[1].split(".")[0])

        print(result.keys())
        Path(eval_path).mkdir(parents=True, exist_ok=True)

        for job in plotting_jobs:
            t0 = time()
            print("Starting plotting job", job.__name__)
            try:
                job(result, eval_path)
            except Exception as e:
                # Report the failure with its traceback, then continue with the next job
                print(f"Error in {job.__name__}: {e}")
                traceback.print_exc()
            print(f"{job.__name__} took {time() - t0:.2f}s")


if __name__ == "__main__":
    main()