jetclustering / src /analyze_results.py
gregorkrzmanc's picture
.
e75a247
raw
history blame
3.68 kB
import pickle
import torch
import os
import matplotlib.pyplot as plt
from src.utils.paths import get_path
from src.utils.utils import CPU_Unpickler
from pathlib import Path
from src.plotting.histograms import score_histogram, per_pt_score_histogram, plot_roc_curve, confusion_matrix_plot
import argparse
# Command-line interface: one required option that names the results
# directory whose eval_*.pkl files are turned into plots further below.
parser = argparse.ArgumentParser(
    description="Create evaluation plots for the eval_*.pkl files found in a results directory.",
)
parser.add_argument(
    "--input",
    type=str,
    required=True,
    help="Results directory to analyze (resolved with get_path(..., 'results')).",
)
args = parser.parse_args()
# The user-supplied location is resolved through the project's path helper.
input_dir = get_path(args.input, "results")
# for rinv=0.7, see /work/gkrzmanc/jetclustering/results/train/Test_betaPt_BC_rinv07_2025_01_03_15_38_58
# for L-GATr: /work/gkrzmanc/jetclustering/results/train/Test_LGATr_all_datasets_2025_01_08_19_27_54
def plot_score_histograms(result, eval_path):
    """Write the classifier-score plots for one evaluation result.

    Three PDF files are produced in ``eval_path``:
    ``binary_classifier_scores.pdf``, ``binary_classifier_scores_per_pt.pdf``
    and ``roc_curve.pdf``.

    Args:
        result: Mapping providing the ``"pt"``, ``"GT_cluster"`` and
            ``"pred"`` entries.
        eval_path: Directory that receives the PDF files.
    """
    transverse_momentum = result["pt"]
    # Entries with a non-negative ground-truth cluster index are the positives.
    is_positive = result["GT_cluster"] >= 0
    # The score is taken from the last column of the prediction array.
    scores = result["pred"][:, -1]

    def _target(pdf_name):
        # Full output path of one PDF inside the evaluation directory.
        return os.path.join(eval_path, pdf_name)

    overall_fig = score_histogram(is_positive, scores, sz=5)
    overall_fig.savefig(_target("binary_classifier_scores.pdf"))

    per_pt_fig = per_pt_score_histogram(is_positive, scores, transverse_momentum)
    per_pt_fig.savefig(_target("binary_classifier_scores_per_pt.pdf"))

    roc_fig = plot_roc_curve(is_positive, scores)
    roc_fig.savefig(_target("roc_curve.pdf"))
import numpy as np
def plot_four_momentum_spectrum(result, eval_path):
    """Histogram the m^2 computed from the first four prediction columns.

    For every row of ``result["pred"]`` the quantity
    ``c0**2 - c1**2 - c2**2 - c3**2`` is evaluated (``c0..c3`` being the first
    four columns) and drawn as a single step histogram with a logarithmic
    count axis. The plot is saved as ``mass_squared.pdf`` in ``eval_path``.

    Args:
        result: Mapping providing the ``"pred"`` entry, indexable as a 2-D
            array with at least four columns.
        eval_path: Directory that receives the PDF file.
    """
    # NOTE(review): columns 0..3 are presumably (E, px, py, pz) - confirm
    # against the model output layout.
    four_vectors = result["pred"][:, :4]
    mass_squared = (
        four_vectors[:, 0] ** 2
        - four_vectors[:, 1] ** 2
        - four_vectors[:, 2] ** 2
        - four_vectors[:, 3] ** 2
    )

    fig, ax = plt.subplots()
    try:
        # Fixed binning: 200 edges spanning [-25, 25].
        bins = np.linspace(-25, 25, 200)
        # Only the inclusive spectrum is drawn (no signal/background split).
        ax.hist(mass_squared, bins=bins, histtype="step", label="All")
        ax.set_xlabel("m^2")
        ax.set_yscale("log")
        ax.set_ylabel("count")
        ax.legend()
        fig.savefig(os.path.join(eval_path, "mass_squared.pdf"))
    finally:
        # pyplot keeps every figure alive until it is closed; this function is
        # called once per evaluation file, so release the figure here.
        plt.close(fig)
def plot_cm(result, eval_path):
    """Draw three confusion matrices side by side and save them as one PDF.

    Panels, left to right: the classifier score thresholded at 0.5, the
    ``"radius_cluster_GenJets"`` entry and the ``"radius_cluster_FatJets"``
    entry, each compared against the ground-truth labels. The figure is saved
    as ``confusion_matrix.pdf`` in ``eval_path``.

    Args:
        result: Mapping providing the ``"GT_cluster"``, ``"pred"``,
            ``"radius_cluster_FatJets"`` and ``"radius_cluster_GenJets"``
            entries.
        eval_path: Directory that receives the PDF file.
    """
    # Entries with a non-negative ground-truth cluster index are the positives.
    y_true = result["GT_cluster"] >= 0
    # Read the classifier score from the last column, the same column that
    # plot_score_histograms uses; a fixed index of 3 only matches it when the
    # prediction array has exactly four columns.
    scores = result["pred"][:, -1]

    sz = 5
    fig, ax = plt.subplots(1, 3, figsize=(3 * sz / 2, sz / 2))
    try:
        confusion_matrix_plot(y_true, scores > 0.5, ax[0])
        ax[0].set_title("Classifier (cut at 0.5)")
        confusion_matrix_plot(y_true, result["radius_cluster_FatJets"], ax[2])
        ax[2].set_title("FatJets")
        confusion_matrix_plot(y_true, result["radius_cluster_GenJets"], ax[1])
        ax[1].set_title("GenJets")
        fig.tight_layout()
        fig.savefig(os.path.join(eval_path, "confusion_matrix.pdf"))
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # figures do not pile up when many evaluation files are processed.
        plt.close(fig)
import traceback
from time import perf_counter


def plotting_blueprint(result, eval_path):
    """Placeholder with the (result, eval_path) signature plotting jobs are called with."""
    pass


# Jobs executed for every evaluation file; each is called as job(result, eval_path).
#plotting_jobs = [plot_score_histograms, plot_cm]
plotting_jobs = [plot_four_momentum_spectrum]

# Walk the input directory and run the plotting jobs on every eval_*.pkl file.
for file in os.listdir(input_dir):
    print("File:", file)
    filename = get_path(os.path.join(input_dir, file), "results")
    if not (file.startswith("eval_") and file.endswith(".pkl")):
        continue

    print("Plotting file", filename)
    # NOTE: unpickling can execute arbitrary code - only point this script at
    # result files produced by a trusted pipeline.
    # The context manager closes the file handle once the result is loaded.
    with open(filename, "rb") as handle:
        result = CPU_Unpickler(handle).load()

    # Output directory "full_eval_<tag>" next to the pickle, where <tag> is the
    # text between the first and second underscore of the file name, cut at
    # the first dot (e.g. "eval_3.pkl" -> "full_eval_3").
    tag = file.split("_")[1].split(".")[0]
    eval_path = os.path.join(os.path.dirname(filename), "full_eval_" + tag)
    print(result.keys())
    Path(eval_path).mkdir(parents=True, exist_ok=True)

    for job in plotting_jobs:
        # perf_counter is a monotonic clock intended for measuring durations.
        t0 = perf_counter()
        print("Starting plotting job", job.__name__)
        try:
            job(result, eval_path)
        except Exception as e:
            # Best effort: a failing job is reported with its traceback and
            # must not stop the remaining jobs or files.
            print(f"Error in {job.__name__}: {e}")
            traceback.print_exc()
        print(f"{job.__name__} took {perf_counter()-t0:.2f}s")