# (Web-page residue "Spaces: / Sleeping / Sleeping" stood here; it is not part of the script.)
import argparse
import os
import pickle
import traceback
from pathlib import Path
from time import time

import matplotlib.pyplot as plt
import torch

from src.plotting.histograms import (
    confusion_matrix_plot,
    per_pt_score_histogram,
    plot_roc_curve,
    score_histogram,
)
from src.utils.paths import get_path
from src.utils.utils import CPU_Unpickler
# Command-line interface: a single required option naming the results
# directory whose eval_*.pkl files should be plotted.
parser = argparse.ArgumentParser(
    description="Produce evaluation plots for every eval_*.pkl file in a results directory."
)
parser.add_argument(
    "--input",
    type=str,
    required=True,
    help="Results directory to scan for eval_*.pkl files (resolved via get_path).",
)
args = parser.parse_args()

# get_path resolves the user-supplied location against the "results" root;
# the exact resolution rules live in src.utils.paths.
input_dir = get_path(args.input, "results")

# Example inputs:
#   rinv=0.7: /work/gkrzmanc/jetclustering/results/train/Test_betaPt_BC_rinv07_2025_01_03_15_38_58
#   L-GATr:   /work/gkrzmanc/jetclustering/results/train/Test_LGATr_all_datasets_2025_01_08_19_27_54
def plot_score_histograms(result, eval_path):
    """Save the binary-classifier score plots for one evaluation result.

    Three PDFs are written into ``eval_path``:
    ``binary_classifier_scores.pdf``, ``binary_classifier_scores_per_pt.pdf``
    and ``roc_curve.pdf``.

    Args:
        result: Mapping providing the keys ``"pt"``, ``"GT_cluster"`` and
            ``"pred"``; ``"pred"`` must support 2-D indexing.
        eval_path: Directory the PDFs are written to (must already exist).
    """
    pt = result["pt"]
    # Entries with a non-negative GT_cluster are treated as the positive class.
    y_true = result["GT_cluster"] >= 0
    # The last column of the prediction array is used as the classifier score.
    y_pred = result["pred"][:, -1]

    def _out(name):
        # Build the output path for one plot inside eval_path.
        return os.path.join(eval_path, name)

    # NOTE(review): the objects returned by the helpers are saved but never
    # closed here; if they are pyplot-managed figures they accumulate across
    # evaluation files — confirm against src.plotting.histograms.
    score_histogram(y_true, y_pred, sz=5).savefig(_out("binary_classifier_scores.pdf"))
    per_pt_score_histogram(y_true, y_pred, pt).savefig(_out("binary_classifier_scores_per_pt.pdf"))
    plot_roc_curve(y_true, y_pred).savefig(_out("roc_curve.pdf"))
import numpy as np


def plot_four_momentum_spectrum(result, eval_path):
    """Histogram the squared invariant mass of the predicted four-vectors.

    The quantity plotted is ``c0**2 - c1**2 - c2**2 - c3**2`` over the first
    four columns of ``result["pred"]``. The histogram is written to
    ``mass_squared.pdf`` inside ``eval_path`` with a logarithmic count axis.

    Args:
        result: Mapping providing the key ``"pred"``, indexable as a 2-D
            array with at least four columns.
        eval_path: Directory the PDF is written to (must already exist).
    """
    four_vec = result["pred"][:, :4]
    # presumably the columns are ordered (E, px, py, pz) — TODO confirm against the model output
    mass_squared = (
        four_vec[:, 0] ** 2
        - four_vec[:, 1] ** 2
        - four_vec[:, 2] ** 2
        - four_vec[:, 3] ** 2
    )
    # 200 edges on [-25, 25]; values outside this window are not counted by hist.
    bins = np.linspace(-25, 25, 200)

    fig, ax = plt.subplots()
    try:
        # To overlay signal and background separately, mask mass_squared with
        # (result["GT_cluster"] >= 0) and its negation and add one hist call each.
        ax.hist(mass_squared, bins=bins, histtype="step", label="All")
        ax.set_xlabel("m^2")
        ax.set_yscale("log")
        ax.set_ylabel("count")
        ax.legend()
        fig.savefig(os.path.join(eval_path, "mass_squared.pdf"))
    finally:
        # pyplot keeps every figure it creates alive until closed; release it
        # so figures do not pile up when many evaluation files are processed.
        plt.close(fig)
def plot_cm(result, eval_path):
    """Save a row of three confusion matrices for one evaluation result.

    Panels, left to right: the classifier score cut at 0.5, the
    ``radius_cluster_GenJets`` labels, and the ``radius_cluster_FatJets``
    labels, each compared against the same ground truth. The figure is
    written to ``confusion_matrix.pdf`` inside ``eval_path``.

    Args:
        result: Mapping providing the keys ``"GT_cluster"``, ``"pred"``,
            ``"radius_cluster_FatJets"`` and ``"radius_cluster_GenJets"``.
        eval_path: Directory the PDF is written to (must already exist).
    """
    # Entries with a non-negative GT_cluster are treated as the positive class.
    y_true = result["GT_cluster"] >= 0
    # NOTE(review): the score is read from column 3 here, while
    # plot_score_histograms reads column -1; these agree only when "pred" has
    # exactly four columns — confirm which one is intended.
    y_pred = result["pred"][:, 3]

    sz = 5
    fig, ax = plt.subplots(1, 3, figsize=(3 * sz / 2, sz / 2))
    try:
        confusion_matrix_plot(y_true, y_pred > 0.5, ax[0])
        ax[0].set_title("Classifier (cut at 0.5)")
        confusion_matrix_plot(y_true, result["radius_cluster_FatJets"], ax[2])
        ax[2].set_title("FatJets")
        confusion_matrix_plot(y_true, result["radius_cluster_GenJets"], ax[1])
        ax[1].set_title("GenJets")
        fig.tight_layout()
        fig.savefig(os.path.join(eval_path, "confusion_matrix.pdf"))
    finally:
        # Release the pyplot-managed figure so repeated calls do not leak memory.
        plt.close(fig)
def plotting_blueprint(result, eval_path):
    """Template for a plotting job: receives the unpickled result and the output directory."""


# Jobs executed for every evaluation file; each is called as job(result, eval_path).
# plotting_jobs = [plot_score_histograms, plot_cm]
plotting_jobs = [plot_four_momentum_spectrum]

for file in os.listdir(input_dir):
    print("File:", file)
    filename = get_path(os.path.join(input_dir, file), "results")
    # Only eval_*.pkl files hold evaluation results; skip everything else.
    if not (file.startswith("eval_") and file.endswith(".pkl")):
        continue
    print("Plotting file", filename)

    # NOTE: unpickling can execute arbitrary code — only point --input at
    # trusted result directories. The context manager closes the handle
    # as soon as the payload has been read.
    with open(filename, "rb") as handle:
        result = CPU_Unpickler(handle).load()

    # "eval_<tag>[_...].pkl" -> "full_eval_<tag>": only the text between the
    # first and second underscore (up to the first dot) names the directory.
    tag = file.split("_")[1].split(".")[0]
    eval_path = os.path.join(os.path.dirname(filename), "full_eval_" + tag)
    print(result.keys())
    Path(eval_path).mkdir(parents=True, exist_ok=True)

    for job in plotting_jobs:
        t0 = time()
        print("Starting plotting job", job.__name__)
        try:
            job(result, eval_path)
        except Exception as e:
            # Best effort: report the failure with its traceback and carry on
            # with the remaining jobs and files.
            print(f"Error in {job.__name__}: {e}")
            traceback.print_exc()
        print(f"{job.__name__} took {time()-t0:.2f}s")