Spaces:
Sleeping
Sleeping
| import pickle | |
| import torch | |
| import os | |
| import matplotlib.pyplot as plt | |
| from src.utils.paths import get_path | |
| from src.utils.utils import CPU_Unpickler | |
| from pathlib import Path | |
| from src.plotting.histograms import score_histogram, per_pt_score_histogram, plot_roc_curve, confusion_matrix_plot | |
| import argparse | |
# Command-line interface: a single required option naming the results
# directory to scan. The value is resolved through get_path(..., "results").
parser = argparse.ArgumentParser(
    description="Produce evaluation plots for every eval_*.pkl file found in a results directory.",
)
parser.add_argument(
    "--input",
    type=str,
    required=True,
    help="Directory containing eval_*.pkl files; resolved with get_path(<input>, 'results').",
)
args = parser.parse_args()
input_dir = get_path(args.input, "results")
| # for rinv=0.7, see /work/gkrzmanc/jetclustering/results/train/Test_betaPt_BC_rinv07_2025_01_03_15_38_58 | |
| # for L-GATr: /work/gkrzmanc/jetclustering/results/train/Test_LGATr_all_datasets_2025_01_08_19_27_54 | |
def plot_score_histograms(result, eval_path):
    """Save the binary-classifier score plots for one evaluation result.

    Reads ``result["pt"]``, ``result["GT_cluster"]`` and ``result["pred"]``.
    The truth label is ``GT_cluster >= 0`` and the score is the last column
    of ``pred``. Three PDFs are written into ``eval_path``: the overall score
    histogram, the per-pt score histogram and the ROC curve.
    """
    transverse_momentum = result["pt"]
    is_signal = result["GT_cluster"] >= 0
    scores = result["pred"][:, -1]

    def _target(pdf_name):
        # All outputs of this job live directly inside eval_path.
        return os.path.join(eval_path, pdf_name)

    overall_fig = score_histogram(is_signal, scores, sz=5)
    overall_fig.savefig(_target("binary_classifier_scores.pdf"))

    per_pt_fig = per_pt_score_histogram(is_signal, scores, transverse_momentum)
    per_pt_fig.savefig(_target("binary_classifier_scores_per_pt.pdf"))

    roc_fig = plot_roc_curve(is_signal, scores)
    roc_fig.savefig(_target("roc_curve.pdf"))
| import numpy as np | |
def plot_four_momentum_spectrum(result, eval_path):
    """Histogram the squared mass built from the first four prediction columns.

    The quantity plotted is ``c0**2 - c1**2 - c2**2 - c3**2`` where ``c0..c3``
    are columns 0-3 of ``result["pred"]`` (presumably (E, px, py, pz) -- TODO
    confirm against the model's output layout). Only the inclusive spectrum
    is drawn; no signal/background split is made.

    Args:
        result: mapping with a 2-D ``"pred"`` entry that has at least four columns.
        eval_path: directory that receives ``mass_squared.pdf``.
    """
    four_vector = result["pred"][:, :4]
    mass_squared = (
        four_vector[:, 0] ** 2
        - four_vector[:, 1] ** 2
        - four_vector[:, 2] ** 2
        - four_vector[:, 3] ** 2
    )
    fig, ax = plt.subplots()
    try:
        # Fixed binning: entries outside [-25, 25] do not appear in the plot.
        bins = np.linspace(-25, 25, 200)
        ax.hist(mass_squared, bins=bins, histtype="step", label="All")
        ax.set_xlabel("m^2")
        ax.set_yscale("log")
        ax.set_ylabel("count")
        ax.legend()
        fig.savefig(os.path.join(eval_path, "mass_squared.pdf"))
    finally:
        # This job runs once per eval file; release the figure so pyplot does
        # not keep every one of them alive.
        plt.close(fig)
def plot_cm(result, eval_path):
    """Save a row of three confusion matrices against the ``GT_cluster >= 0`` label.

    Panels, left to right: the classifier score cut at 0.5, the
    ``radius_cluster_GenJets`` assignment and the ``radius_cluster_FatJets``
    assignment. The figure is written to ``confusion_matrix.pdf`` inside
    ``eval_path``.

    Args:
        result: mapping providing ``"GT_cluster"``, ``"pred"``,
            ``"radius_cluster_GenJets"`` and ``"radius_cluster_FatJets"``.
        eval_path: output directory.
    """
    y_true = result["GT_cluster"] >= 0
    # Read the score from the last column, as plot_score_histograms does.
    # For four-column predictions this is the same as column 3; for wider
    # predictions column 3 belongs to the four-vector used by
    # plot_four_momentum_spectrum and is not the score.
    y_pred = result["pred"][:, -1]
    sz = 5
    fig, ax = plt.subplots(1, 3, figsize=(3 * sz / 2, sz / 2))
    try:
        panels = (
            (ax[0], y_pred > 0.5, "Classifier (cut at 0.5)"),
            (ax[1], result["radius_cluster_GenJets"], "GenJets"),
            (ax[2], result["radius_cluster_FatJets"], "FatJets"),
        )
        for axis, prediction, title in panels:
            confusion_matrix_plot(y_true, prediction, axis)
            axis.set_title(title)
        fig.tight_layout()
        fig.savefig(os.path.join(eval_path, "confusion_matrix.pdf"))
    finally:
        # Release the figure; this job is invoked once per eval file.
        plt.close(fig)
import traceback
from time import time


def plotting_blueprint(result, eval_path):
    """No-op placeholder with the (result, eval_path) signature the jobs below are called with."""
    pass


# Jobs run for every eval file, each invoked as job(result, eval_path).
# plot_score_histograms and plot_cm are defined in this file but not enabled.
plotting_jobs = [plot_four_momentum_spectrum]

for file in os.listdir(input_dir):
    print("File:", file)
    filename = get_path(os.path.join(input_dir, file), "results")
    if not (file.startswith("eval_") and file.endswith(".pkl")):
        continue
    print("Plotting file", filename)
    # NOTE(review): unpickling can execute arbitrary code -- only point
    # --input at result directories you trust.
    with open(filename, "rb") as pickle_file:
        result = CPU_Unpickler(pickle_file).load()
    # "eval_<tag>.pkl" -> "<dir>/full_eval_<tag>"; only the text between the
    # first and second underscore (up to the first dot) is used as the tag.
    eval_path = os.path.join(
        os.path.dirname(filename),
        "full_eval_" + file.split("_")[1].split(".")[0],
    )
    print(result.keys())
    Path(eval_path).mkdir(parents=True, exist_ok=True)
    for job in plotting_jobs:
        t0 = time()
        print("Starting plotting job", job.__name__)
        try:
            job(result, eval_path)
        except Exception as e:
            # Best effort: report the failure with its traceback and carry on
            # with the remaining jobs and files.
            print(f"Error in {job.__name__}: {e}")
            traceback.print_exc()
        print(f"{job.__name__} took {time()-t0:.2f}s")