# jetclustering / notebooks / comparison_clustering_parton_and_gen_level.py
# (repository page header: gregorkrzmanc, revision e75a247)
import pickle
import torch
import os
import matplotlib.pyplot as plt
from src.utils.paths import get_path
from src.utils.utils import CPU_Unpickler
from pathlib import Path
import fastjet
from src.dataset.dataset import EventDataset
import numpy as np
from src.plotting.plot_coordinates import plot_coordinates
def _load_eval(pickle_path):
    """Load one evaluation pickle and the dataset it refers to.

    Args:
        pickle_path: Path of the ``eval_*.pkl`` file; it is resolved through
            ``get_path(..., "results")`` before being opened.

    Returns:
        A ``(resolved_path, result, dataset)`` tuple, where ``result`` is the
        unpickled object and ``dataset`` is the ``EventDataset`` opened from
        ``result["filename"]`` with ``mmap=True``.
    """
    resolved_path = get_path(pickle_path, "results")
    # Use a context manager so the file handle is closed as soon as
    # unpickling finishes instead of being leaked for the rest of the run.
    # NOTE: unpickling can execute arbitrary code; only use trusted result files.
    with open(resolved_path, "rb") as handle:
        result = CPU_Unpickler(handle).load()
    dataset = EventDataset.from_directory(result["filename"], mmap=True)
    return resolved_path, result, dataset


# Evaluation outputs for the parton-level, gen-level and PF-candidate-level runs.
filename_parton_level, result_parton_level, dataset_parton_level = _load_eval(
    "/work/gkrzmanc/jetclustering/results/train/Eval_no_pid_eval_1_2025_03_05_14_41_16/eval_1.pkl"
)
filename_gen_level, result_gen_level, dataset_gen_level = _load_eval(
    "/work/gkrzmanc/jetclustering/results/train/Eval_no_pid_eval_1_2025_03_05_14_40_30/eval_1.pkl"
)
filename_pfcands_level, result_pfcands_level, dataset_pfcands_level = _load_eval(
    "/work/gkrzmanc/jetclustering/results/train/Eval_no_pid_eval_1_2025_03_05_14_41_38/eval_1.pkl"
)

# Value matched against result["event_idx"] to pick the event that gets plotted.
EVENT_ID = 15
def plot_result(result, dataset_path, save_dir, event_id=None):
    """Plot one event's predicted coordinates and save the figure as HTML.

    Rows of ``result`` whose ``event_idx`` equals the selected event are kept.
    Columns 1-3 of ``result["pred"]`` are used as coordinates and passed,
    together with ``result["pt"]`` and ``result["GT_cluster"]``, to
    ``plot_coordinates``; the returned figure is written with ``write_html``
    (the original note describes it as a plotly 3D plot coloured by
    ``GT_cluster``).

    Args:
        result: Mapping providing the "event_idx", "pred", "pt" and
            "GT_cluster" entries.
        dataset_path: Directory that contains ``clustering_hdbscan_4_05_1.pkl``.
        save_dir: Passed directly to ``write_html``; despite the name, the
            callers in this file pass the path of the ``.html`` output file.
        event_id: Event index to plot. Defaults to the module-level
            ``EVENT_ID`` when omitted, which is the original behaviour.
    """
    if event_id is None:
        event_id = EVENT_ID
    filt = result["event_idx"] == event_id
    # The raw columns are plotted; scaling them to unit norm is intentionally
    # not applied here.
    norm_coords = result["pred"][filt, 1:4]
    pt = torch.tensor(result["pt"][filt])
    clusters_file = get_path(
        os.path.join(dataset_path, "clustering_hdbscan_4_05_1.pkl"), "results"
    )
    # The model clustering is loaded but not used for colouring yet (the plot
    # is coloured by GT_cluster). The load is kept so that a missing clustering
    # file still fails here as before; the handle is now closed deterministically.
    with open(clusters_file, "rb") as handle:
        model_clusters = CPU_Unpickler(handle).load()
    plot_coordinates(norm_coords, pt, result["GT_cluster"][filt]).write_html(save_dir)
    # NOTE(review): the source indentation was lost; this separator is assumed
    # to be printed once per call — confirm.
    print("-----")
# One entry per input level: (evaluation result, evaluation directory that
# holds the clustering pickle, output HTML path template filled with EVENT_ID).
_plot_jobs = (
    (
        result_parton_level,
        "/work/gkrzmanc/jetclustering/results/train/Eval_no_pid_eval_1_2025_03_05_14_41_16",
        "/work/gkrzmanc/jetclustering/results/GT_color_parton_level_{}.html",
    ),
    (
        result_gen_level,
        "/work/gkrzmanc/jetclustering/results/train/Eval_no_pid_eval_1_2025_03_05_14_40_30",
        "/work/gkrzmanc/jetclustering/results/GT_color_gen_level_{}.html",
    ),
    (
        result_pfcands_level,
        "/work/gkrzmanc/jetclustering/results/train/Eval_no_pid_eval_1_2025_03_05_14_41_38",
        "/work/gkrzmanc/jetclustering/results/GT_color_pfcands_level_{}.html",
    ),
)
# Same three calls as before, in the same order: parton, gen, PF candidates.
for _level_result, _eval_dir, _html_template in _plot_jobs:
    plot_result(_level_result, _eval_dir, _html_template.format(EVENT_ID))