import math
import os

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import cm
from sklearn.metrics import confusion_matrix

from src.layers.object_cond import calc_eta_phi
from src.plotting.histograms import score_histogram, confusion_matrix_plot
from src.plotting.plot_coordinates import plot_coordinates

# Colors for PFCands, indexed by the jet they are assigned to. Unassigned PFCands
# (negative jet index) are drawn gray; jet indices past the end of this list are drawn black.
_JET_COLORS = ["red", "green", "blue", "purple", "orange", "yellow", "black", "pink", "cyan", "brown"]

# Palette for ground-truth cluster labels. Negative labels are drawn gray; labels past the
# end of the palette wrap around instead of raising.
_GT_COLORS = ['#1f78b4', '#b3df8a', '#33a02c', '#fb9a99', '#e31a1c', '#fdbe6f', '#ff7f00',
              '#cab2d6', '#6a3d9a', '#ffff99', '#b15928']


def _jet_color(jet_idx):
    """Return the color of a PFCand assigned to jet ``jet_idx`` (negative means unassigned)."""
    if jet_idx < 0:
        return "gray"
    if jet_idx < len(_JET_COLORS):
        return _JET_COLORS[jet_idx]
    return "black"


def _gt_color(label):
    """Return the color of a PFCand with ground-truth cluster ``label`` (negative means noise)."""
    label = int(label)  # labels may arrive as floats from ``Tensor.tolist()``
    if label < 0:
        return "gray"
    return _GT_COLORS[label % len(_GT_COLORS)]


def _add_circles(ax, etas, phis, radius, color, **kwargs):
    """Draw an unfilled circle of ``radius`` around every (eta, phi) center onto ``ax``."""
    for eta_c, phi_c in zip(etas, phis):
        ax.add_artist(plt.Circle((eta_c, phi_c), radius, color=color, fill=False, **kwargs))


def _add_colorbar(ax, cmap, label):
    """Attach a colorbar for ``cmap`` to ``ax`` and give it ``label``."""
    cbar = plt.colorbar(mappable=cm.ScalarMappable(cmap=cmap), ax=ax)
    cbar.set_label(label)
    return cbar


def _scatter_comparison_panel(ax, event, colors, is_special, special_size, special_colors):
    """Scatter PFCands, matrix-element gen particles and gen jets of ``event`` onto one panel.

    :param colors: numpy array with one color per PFCand (used for non-special PFCands).
    :param is_special: torch bool mask selecting the "special" PFCands (drawn as down-triangles).
    :param special_size: marker-size multiplier applied to the pt of special PFCands.
    :param special_colors: color(s) for the special PFCands.
    """
    pfcands = event.pfcands
    eta, phi, pt = pfcands.eta, pfcands.phi, pfcands.pt
    regular = ~is_special
    # The color array is numpy, so index it with a numpy mask rather than a torch tensor.
    regular_np = regular.cpu().numpy()
    ax.scatter(eta[is_special], phi[is_special], s=pt[is_special] * special_size, c=special_colors, marker="v")
    ax.scatter(eta[regular], phi[regular], s=pt[regular], c=colors[regular_np])
    dq = event.matrix_element_gen_particles
    ax.scatter(dq.eta, dq.phi, s=dq.pt, c="red", marker="^", alpha=1.0)
    genjets = event.genjets
    ax.scatter(genjets.eta, genjets.phi, marker="*", s=genjets.pt, c="blue", alpha=1.0)
    ax.set_xlabel(r"$\eta$")
    ax.set_ylabel(r"$\phi$")


def plot_event_comparison(event, ax=None, special_pfcands_size=1, special_pfcands_color="gray"):
    """Plot an event's PFCands in the (eta, phi) plane, once with jets and once with fat jets.

    Marker size is proportional to pt. PFCands are colored by ``pf_cand_jet_idx`` (gray when
    unassigned), PFCands with ``|pid| < 4`` are drawn as down-triangles, matrix-element gen
    particles as red up-triangles and gen jets as blue stars.
    Left panel: jets drawn as R=0.5 circles; special PFCands use ``special_pfcands_color``.
    Right panel (only when ``event.fatjets`` is not None): fat jets drawn as R=0.8 circles;
    special PFCands use their jet color.

    :param event: event exposing ``pfcands``, ``jets``, ``genjets``, ``fatjets`` and
        ``matrix_element_gen_particles``.
    :param ax: optional pair of matplotlib Axes; a new 1x2 figure is created when None.
    :param special_pfcands_size: marker-size multiplier for special PFCands.
    :param special_pfcands_color: color of special PFCands on the left panel.
    :return: the figure containing the axes.
    """
    make_fig = ax is None
    if make_fig:
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    else:
        # Previously ``fig`` was unbound here and the function crashed when ``ax`` was given.
        fig = ax[0].figure
    print("N jets:", len(event.jets))

    mapping = event.pfcands.pf_cand_jet_idx.int().tolist()
    colors = np.array([_jet_color(jet_idx) for jet_idx in mapping])
    # NOTE(review): the original comment calls these electrons, muons and photons; |pid| < 4
    # presumably refers to a project-specific pid encoding rather than PDG ids — confirm.
    is_special = event.pfcands.pid.abs() < 4

    _scatter_comparison_panel(ax[0], event, colors, is_special, special_pfcands_size, special_pfcands_color)
    _add_circles(ax[0], event.jets.eta, event.jets.phi, 0.5, "red")
    ax[0].set_title("PFCands with Jets")

    if event.fatjets is not None:
        special_colors = colors[is_special.cpu().numpy()]
        # NOTE(review): PFCands are still colored by their (AK4) ``pf_cand_jet_idx`` here.
        _scatter_comparison_panel(ax[1], event, colors, is_special, special_pfcands_size, special_colors)
        ax[1].set_title("PFCands with FatJets")
        _add_circles(ax[1], event.fatjets.eta, event.fatjets.phi, 0.8, "red")

    # even aspect ratio
    ax[1].set_aspect("equal")
    ax[0].set_aspect("equal")
    if make_fig:
        fig.tight_layout()
    return fig


def plot_event(event, colors="gray", custom_coords=None, ax=None, jets=True, pfcands="pfcands"):
    """Scatter one event's particles in the (eta, phi) plane (marker size proportional to pt).

    Matrix-element gen particles, when present on ``event``, are drawn as red up-triangles.

    :param event: event object; ``getattr(event, pfcands)`` must expose ``eta``, ``phi``, ``pt``.
    :param colors: particle color(s), anything accepted by ``Axes.scatter(c=...)``.
    :param custom_coords: optional ``(eta, phi)`` pair replacing the particles' own eta/phi
        (e.g. derived from model output); marker sizes still use the particles' pt.
    :param ax: matplotlib Axes to draw onto; a new 5x5 figure is created when None.
    :param jets: if True, draw ``event.model_jets`` (blue circles, annotated with the sigmoid
        of ``obj_score`` when available) and the R=0.8 entry of ``event.fastjet_jets`` (green
        circles), whenever those attributes exist and are not None.
    :param pfcands: name of the particle-collection attribute on ``event``.
    :return: the figure containing ``ax``.
    """
    make_fig = ax is None
    if make_fig:
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    else:
        fig = ax.figure
    cands = getattr(event, pfcands)
    eta, phi, pt = cands.eta, cands.phi, cands.pt
    if custom_coords is not None:
        eta, phi = custom_coords[0], custom_coords[1]
    ax.scatter(eta, phi, s=pt, c=colors, alpha=0.7)

    dark_quarks = getattr(event, "matrix_element_gen_particles", None)
    if dark_quarks is not None:
        ax.scatter(dark_quarks.eta, dark_quarks.phi, s=dark_quarks.pt, c="red", marker="^", alpha=0.5)

    if jets:
        # NOTE(review): circle radii 0.77 / 0.74 differ from the 0.8 clustering radius,
        # presumably so overlapping circles stay distinguishable — confirm.
        model_jets = getattr(event, "model_jets", None)
        if model_jets is not None:
            obj_score = getattr(model_jets, "obj_score", None)
            for k, (jet_eta, jet_phi) in enumerate(zip(model_jets.eta, model_jets.phi)):
                ax.add_artist(plt.Circle((jet_eta, jet_phi), 0.77, color="blue", fill=False, alpha=.7))
                if obj_score is not None:
                    score = round(torch.sigmoid(obj_score[k]).item(), 2)
                    ax.text(jet_eta + 0.2, jet_phi - 0.2, "o.s.=" + str(score),
                            color="gray", fontsize=10, alpha=0.5)
        fastjet_jets = getattr(event, "fastjet_jets", None)
        if fastjet_jets is not None:
            fj_jets = fastjet_jets[0.8]  # collection keyed by clustering radius
            _add_circles(ax, fj_jets.eta, fj_jets.phi, 0.74, "green", alpha=.7)

    ax.set_xlabel(r"$\eta$")
    ax.set_ylabel(r"$\phi$")
    ax.set_aspect("equal")
    if make_fig:
        fig.tight_layout()
    return fig


def get_idx_for_event(obj, i):
    """Return the ``[start, end)`` bounds of event ``i`` in a flat collection via ``obj.batch_number``."""
    return obj.batch_number[i], obj.batch_number[i + 1]


def get_labels_jets(b, pfcands, jets, R=0.8):
    """Label each PFCand 1.0 if it lies within ``R`` (in eta-phi) of a jet of its event, else 0.0.

    :param b: batch of events; only ``len(b)`` (the number of events) is used.
    :param pfcands: flat PFCand collection with ``eta``, ``phi`` and ``batch_number`` offsets.
    :param jets: flat jet collection with ``eta``, ``phi`` and ``batch_number`` offsets.
    :param R: matching radius.
    :return: float tensor with one 0/1 entry per PFCand.
    """
    # Start at -1 so PFCands of events without any jet are negatives (previously 0 -> positive).
    labels = torch.full((len(pfcands),), -1, dtype=torch.long)
    for i in range(len(b)):
        jet_start, jet_end = get_idx_for_event(jets, i)
        if jet_start == jet_end:
            continue
        pf_start, pf_end = get_idx_for_event(pfcands, i)
        # Pairwise (n_pfcands, n_jets) deltas; delta-phi is wrapped into [-pi, pi).
        d_eta = pfcands.eta[pf_start:pf_end].unsqueeze(1) - jets.eta[jet_start:jet_end].unsqueeze(0)
        d_phi = pfcands.phi[pf_start:pf_end].unsqueeze(1) - jets.phi[jet_start:jet_end].unsqueeze(0)
        d_phi = torch.remainder(d_phi + math.pi, 2 * math.pi) - math.pi
        dist = torch.sqrt(d_eta ** 2 + d_phi ** 2)
        closest_dist, closest_idx = dist.min(dim=1)
        closest_idx[closest_dist > R] = -1
        labels[pf_start:pf_end] = closest_idx
    return (labels >= 0).float()


def plot_batch_eval_OC(event_batch, y_true, y_pred, batch_idx, filename, args, batch, dropped_batches):
    """Save a grid of per-event diagnostic plots for an object-condensation evaluation batch.

    One row per event (at most 5). Columns: 0 input coords / beta colors, 1 input coords /
    GT colors, 2 model coords / beta colors, 3 model coords / GT colors, and for
    ``args.beta_type == "pt+bc"`` also 4 input coords / classifier colors and 5 model coords /
    classifier colors. HTML coordinate plots are written next to ``filename`` via
    ``plot_coordinates``.

    :param event_batch: batch of events supporting ``n_events`` and ``event_batch[i]``.
    :param y_true: ground-truth cluster label per PFCand of the batch (negative = noise).
    :param y_pred: model output per PFCand; columns 0-2 (1-3 for 5-column outputs) are the
        clustering coordinates, column 3 is used as beta ("default") or classifier score ("pt+bc").
    :param batch_idx: event index of each PFCand, aligned with ``y_true`` / ``y_pred``.
    :param filename: output path of the figure; its directory also receives the HTML plots.
    :param args: namespace whose ``beta_type`` is one of "default", "pt", "pt+bc".
    :param batch: batch number, used only in the HTML file names.
    :param dropped_batches: container of event indices; only events contained in it are plotted.
    :raises ValueError: if an event is plotted and ``args.beta_type`` is unsupported.
    """
    max_events = 5
    sz = 10
    assert len(y_true) == len(y_pred), f"y_true: {len(y_true)}, y_pred: {len(y_pred)}"
    with_bc = args.beta_type == "pt+bc"
    n_columns = 6 if with_bc else 4
    outdir = os.path.dirname(filename)
    # Light gray -> dark green colormap for binary-classifier scores.
    bc_cmap = mcolors.LinearSegmentedColormap.from_list(
        "lightgray_to_darkgreen", [(0.9, 0.9, 0.9), (0.0, 0.5, 0.0)])
    fig, ax = plt.subplots(max_events, n_columns, figsize=(n_columns * sz, sz * max_events))
    try:
        print("N events:", event_batch.n_events)
        for i in range(min(event_batch.n_events, max_events)):
            # NOTE(review): only events listed in ``dropped_batches`` are plotted; the name
            # suggests they might be the ones to skip instead — confirm against the caller.
            if i not in dropped_batches:
                continue
            event = event_batch[i]
            filt = batch_idx == i
            y_true_event = y_true[filt]
            y_pred_event = y_pred[filt]
            if args.beta_type == "default":
                # ``y_pred_event`` is already restricted to this event: do not re-apply ``filt``.
                betas = y_pred_event[:, 3]
            elif args.beta_type in ("pt", "pt+bc"):
                # NOTE(review): raw pT is fed to the brg colormap, which expects [0, 1];
                # values above 1 all end up at the top color.
                betas = event.pfcands.pt
            else:
                raise ValueError(f"Unsupported beta_type: {args.beta_type!r}")
            p_xyz = y_pred_event[:, :3]
            if y_pred_event.shape[1] == 5:
                # 5-column output: column 0 is skipped and columns 1-3 are the coordinates.
                p_xyz = y_pred_event[:, 1:4]

            plot_coordinates(event.pfcands.pxyz, pt=event.pfcands.pt, tidx=y_true_event, outdir=outdir,
                             filename="input_coords_batch_" + str(batch) + "_event_" + str(i) + ".html")
            plot_coordinates(p_xyz, pt=event.pfcands.pt, tidx=y_true_event, outdir=outdir,
                             filename="model_coords_batch_" + str(batch) + "_event_" + str(i) + ".html")

            gt_colors = [_gt_color(label) for label in y_true_event.tolist()]
            beta_colors = plt.cm.brg(betas)
            eta, phi = calc_eta_phi(p_xyz, return_stacked=False)

            plot_event(event, colors=beta_colors, ax=ax[i, 0])
            ax[i, 0].set_title(r"input coords, $\beta$ colors")
            _add_colorbar(ax[i, 0], plt.cm.brg, r"$\beta$")

            plot_event(event, colors=gt_colors, ax=ax[i, 1])
            ax[i, 1].set_title("input coords, GT colors")

            plot_event(event, custom_coords=[eta, phi], colors=beta_colors, ax=ax[i, 2], jets=False)
            ax[i, 2].set_title(r"model coords, $\beta$ colors")
            _add_colorbar(ax[i, 2], plt.cm.brg, r"$\beta$")

            plot_event(event, custom_coords=[eta, phi], colors=gt_colors, ax=ax[i, 3], jets=False)
            ax[i, 3].set_title("model coords, GT colors")

            if with_bc:
                bc_colors = bc_cmap(y_pred_event[:, 3])
                plot_event(event, colors=bc_colors, ax=ax[i, 4], jets=False)
                ax[i, 4].set_title(r"input coords, BC label colors")
                _add_colorbar(ax[i, 4], bc_cmap, "Classifier score")
                plot_event(event, custom_coords=[eta, phi], colors=bc_colors, ax=ax[i, 5], jets=False)
                ax[i, 5].set_title(r"model coords, BC label colors")
                _add_colorbar(ax[i, 5], bc_cmap, "Classifier score")

        print("Saving eval figure to", filename)
        fig.tight_layout()
        fig.savefig(filename)
    finally:
        # Release the figure even when plotting fails, so repeated eval calls don't leak figures.
        fig.clear()
        plt.close(fig)