# jetclustering/src/plotting/plot_event.py
# Last commit by gregorkrzmanc (e75a247), 14.4 kB.
import matplotlib.pyplot as plt
import torch
import os
import numpy as np
import matplotlib.colors as mcolors
from matplotlib import cm
from sklearn.metrics import confusion_matrix
from src.plotting.histograms import score_histogram, confusion_matrix_plot
from src.plotting.plot_coordinates import plot_coordinates
from src.layers.object_cond import calc_eta_phi
def plot_event_comparison(event, ax=None, special_pfcands_size=1, special_pfcands_color="gray"):
    """Plot an event's PFCands in the (eta, phi) plane, once with jets and once with fat jets.

    Left panel (``ax[0]``): PFCands colored by ``event.pfcands.pf_cand_jet_idx`` with an
    R=0.5 circle around every entry of ``event.jets``. Right panel (``ax[1]``), drawn only
    when ``event.fatjets`` is not None: the same PFCands with an R=0.8 circle around every
    fat jet. Both panels overlay ``event.matrix_element_gen_particles`` (red triangles) and
    ``event.genjets`` (blue stars). Marker area is proportional to pt.

    :param event: event exposing ``pfcands`` (with ``eta``, ``phi``, ``pt``, ``pid``,
        ``pf_cand_jet_idx``), ``jets``, ``fatjets``, ``genjets`` and
        ``matrix_element_gen_particles`` (each with ``eta``/``phi``, plus ``pt`` where drawn).
    :param ax: optional sequence of at least two matplotlib axes. A new 1x2 figure is
        created when omitted.
    :param special_pfcands_size: pt scale factor for the "special" PFCands (|pid| < 4)
        on both panels.
    :param special_pfcands_color: color of the special PFCands on the left panel.
    :return: the matplotlib figure holding the axes.
    """
    make_fig = ax is None
    if make_fig:
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    else:
        # Reuse the caller's figure so there is always something to return.
        fig = ax[0].figure

    eta_dq = event.matrix_element_gen_particles.eta
    phi_dq = event.matrix_element_gen_particles.phi
    pt_dq = event.matrix_element_gen_particles.pt
    eta = event.pfcands.eta
    phi = event.pfcands.phi
    pt = event.pfcands.pt
    print("N jets:", len(event.jets))
    genjet_eta = event.genjets.eta
    genjet_phi = event.genjets.phi
    genjet_pt = event.genjets.pt

    # Color each PFCand by the index of the jet it is assigned to. Index -1 wraps to the
    # last palette entry ("gray") for PFCands outside any jet. Indices past the end of the
    # palette fall back to "black" instead of raising IndexError.
    colorlist = ["red", "green", "blue", "purple", "orange", "yellow", "black", "pink", "cyan", "brown", "black", "black", "black", "gray"]
    mapping = event.pfcands.pf_cand_jet_idx.int().tolist()
    colors = np.array([colorlist[idx] if idx < len(colorlist) else "black" for idx in mapping])
    # NOTE(review): "special" PFCands are selected as |pid| < 4; confirm which particle
    # categories this pid encoding covers.
    is_special = event.pfcands.pid.abs() < 4

    # Left panel: PFCands with jets.
    ax[0].scatter(eta[is_special], phi[is_special], s=pt[is_special] * special_pfcands_size,
                  c=special_pfcands_color, marker="v")
    ax[0].scatter(eta[~is_special], phi[~is_special], s=pt[~is_special], c=colors[~is_special])
    ax[0].scatter(eta_dq, phi_dq, s=pt_dq, c="red", marker="^", alpha=1.0)
    ax[0].scatter(genjet_eta, genjet_phi, marker="*", s=genjet_pt, c="blue", alpha=1.0)
    # Display the jets as circles with R=0.5.
    for jet_eta, jet_phi in zip(event.jets.eta, event.jets.phi):
        ax[0].add_artist(plt.Circle((jet_eta, jet_phi), 0.5, color="red", fill=False))
    ax[0].set_xlabel(r"$\eta$")
    ax[0].set_ylabel(r"$\phi$")
    ax[0].set_title("PFCands with Jets")

    # Right panel: PFCands with fat jets (same per-PFCand colors as the left panel; here
    # the special PFCands keep their jet color too).
    if event.fatjets is not None:
        ax[1].scatter(eta[is_special], phi[is_special], s=pt[is_special] * special_pfcands_size,
                      c=colors[is_special], marker="v")
        ax[1].scatter(eta[~is_special], phi[~is_special], s=pt[~is_special], c=colors[~is_special])
        ax[1].scatter(eta_dq, phi_dq, s=pt_dq, c="red", marker="^", alpha=1.0)
        ax[1].scatter(genjet_eta, genjet_phi, marker="*", s=genjet_pt, c="blue", alpha=1.0)
        ax[1].set_xlabel(r"$\eta$")
        ax[1].set_ylabel(r"$\phi$")
        ax[1].set_title("PFCands with FatJets")
        # Plot the fat jets as circles with R=0.8 around their centers.
        fatjet_R = 0.8
        for fj_eta, fj_phi in zip(event.fatjets.eta, event.fatjets.phi):
            ax[1].add_artist(plt.Circle((fj_eta, fj_phi), fatjet_R, color="red", fill=False))

    # Even aspect ratio.
    ax[1].set_aspect("equal")
    ax[0].set_aspect("equal")
    if make_fig:
        fig.tight_layout()
    return fig
def plot_event(event, colors="gray", custom_coords=None, ax=None, jets=True, pfcands="pfcands"):
    """Scatter one event's particles in the (eta, phi) plane and overlay jets.

    :param event: event object. ``getattr(event, pfcands)`` must expose ``eta``, ``phi``
        and ``pt``. ``matrix_element_gen_particles`` (red triangles), ``model_jets`` (blue
        circles, annotated with ``sigmoid(obj_score)`` when available) and
        ``fastjet_jets[0.8]`` (green circles) are drawn when present and not None.
    :param colors: color spec passed as ``c=`` to ``ax.scatter`` for the particles.
    :param custom_coords: optional ``(eta, phi)`` pair plotted instead of the particles'
        own coordinates. The particles' pt still sets the marker size.
    :param ax: matplotlib axes to draw onto. A new 5x5 figure is created when omitted.
    :param jets: whether to draw the ``model_jets`` and ``fastjet_jets`` circles.
    :param pfcands: name of the event attribute holding the particles to plot.
    :return: the figure that ``ax`` belongs to.
    """
    make_fig = ax is None
    if make_fig:
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    else:
        fig = ax.figure
    particles = getattr(event, pfcands)
    eta, phi, pt = particles.eta, particles.phi, particles.pt
    # Test against None rather than truthiness so a stacked (2, N) tensor is accepted too.
    if custom_coords is not None:
        eta, phi = custom_coords[0], custom_coords[1]
    ax.scatter(eta, phi, s=pt, c=colors, alpha=0.7)

    gen = getattr(event, "matrix_element_gen_particles", None)
    if gen is not None:
        ax.scatter(gen.eta, gen.phi, s=gen.pt, c="red", marker="^", alpha=0.5)  # Dark quarks

    if jets:
        model_jets = getattr(event, "model_jets", None)
        if model_jets is not None:
            obj_score = getattr(model_jets, "obj_score", None)
            for i, (jet_eta, jet_phi) in enumerate(zip(model_jets.eta, model_jets.phi)):
                ax.add_artist(plt.Circle((jet_eta, jet_phi), 0.77, color="blue", fill=False, alpha=.7))
                # Annotate the jet with its object score.
                if obj_score is not None:
                    score = round(torch.sigmoid(obj_score[i]).item(), 2)
                    ax.text(jet_eta + 0.2, jet_phi - 0.2, "o.s.=" + str(score),
                            color="gray", fontsize=10, alpha=0.5)
        fastjet_jets = getattr(event, "fastjet_jets", None)
        if fastjet_jets is not None:
            fj_r = 0.8
            fj = fastjet_jets[fj_r]
            # NOTE(review): drawn with radius 0.74 (model jets use 0.77), presumably to keep
            # overlapping circles distinguishable; confirm.
            for jet_eta, jet_phi in zip(fj.eta, fj.phi):
                ax.add_artist(plt.Circle((jet_eta, jet_phi), 0.74, color="green", fill=False, alpha=.7))

    ax.set_xlabel(r"$\eta$")
    ax.set_ylabel(r"$\phi$")
    ax.set_aspect("equal")
    if make_fig:
        fig.tight_layout()
    return fig
def get_idx_for_event(obj, i):
    """Return the ``(start, end)`` slice bounds of event ``i`` in a flattened batch.

    Entries ``i`` and ``i + 1`` of ``obj.batch_number`` bound the rows belonging to
    event ``i``, i.e. the event occupies ``[start:end]``.
    """
    boundaries = obj.batch_number
    start = boundaries[i]
    end = boundaries[i + 1]
    return start, end
def get_labels_jets(b, pfcands, jets, R=0.8):
    """Flag every particle that lies within ``R`` of a jet from its own event.

    :param b: batch of events. Only ``len(b)`` is used, as the number of events.
    :param pfcands: flattened particles of all events, with ``eta``, ``phi`` and per-event
        slice bounds in ``batch_number`` (as read by ``get_idx_for_event``).
    :param jets: flattened jets of all events, laid out like ``pfcands``.
    :param R: matching radius in the (eta, phi) plane (default 0.8).
    :return: float tensor of length ``len(pfcands)``. It is 1.0 where the closest jet of the
        particle's event is within ``R`` and 0.0 otherwise, including for every particle
        of an event that has no jets.
    """
    # -1 means "not matched". Starting from -1 (not 0) keeps particles of jet-less events,
    # which are skipped below, from being counted as matched.
    labels = torch.full((len(pfcands),), -1, dtype=torch.long)
    for i in range(len(b)):
        s, e = get_idx_for_event(jets, i)
        if s == e:
            continue  # no jets in this event
        jet_eta = jets.eta[s:e]
        jet_phi = jets.phi[s:e]
        s, e = get_idx_for_event(pfcands, i)
        pf_eta = pfcands.eta[s:e]
        pf_phi = pfcands.phi[s:e]
        # Pairwise separations, shape (n_particles, n_jets). phi is an angle, so the
        # difference is wrapped into [-pi, pi) before combining it with the eta difference.
        deta = pf_eta[:, None] - jet_eta[None, :]
        dphi = pf_phi[:, None] - jet_phi[None, :]
        dphi = torch.remainder(dphi + np.pi, 2 * np.pi) - np.pi
        dist = torch.sqrt(deta ** 2 + dphi ** 2)
        closest_dist, closest_idx = dist.min(dim=1)
        closest_idx[closest_dist > R] = -1
        labels[s:e] = closest_idx
    return (labels >= 0).float()
# Palette for ground-truth object labels: label k uses entry k, wrapping around.
_GT_PALETTE = ['#1f78b4', '#b3df8a', '#33a02c', '#fb9a99', '#e31a1c', '#fdbe6f', '#ff7f00', '#cab2d6', '#6a3d9a', '#ffff99', '#b15928']


def _gt_label_color(label):
    """Return the plot color for GT label ``label``: gray if negative, else a palette entry."""
    label = int(label)
    if label < 0:
        return "gray"
    return _GT_PALETTE[label % len(_GT_PALETTE)]


def plot_batch_eval_OC(event_batch, y_true, y_pred, batch_idx, filename, args, batch, dropped_batches):
    """Plot up to five events of a batch with object-condensation (OC) GT labels and betas.

    Each row is one event. The columns are:
    0. input coordinates colored by beta;
    1. input coordinates colored by GT label;
    2. model coordinates colored by beta;
    3. model coordinates colored by GT label.
    With ``args.beta_type == "pt+bc"``, columns 4 and 5 color the input and model
    coordinates by ``y_pred[:, 3]`` (classifier score). For every plotted event,
    ``plot_coordinates`` also writes HTML views of the input and model coordinates into
    the directory of ``filename``.

    :param event_batch: batch exposing ``n_events`` and indexing ``event_batch[i]``.
    :param y_true: per-particle GT labels for the whole batch. Negative labels are drawn gray.
    :param y_pred: per-particle model outputs, same length as ``y_true``. Columns ``:3``
        (``1:4`` when there are 5 columns) are the model coordinates. Column 3 is used as
        beta (``"default"``) or classifier score (``"pt+bc"``).
    :param batch_idx: per-particle event index, compared against the event number.
    :param filename: output path of the figure.
    :param args: namespace whose ``beta_type`` is ``"default"``, ``"pt"`` or ``"pt+bc"``.
    :param batch: batch number, used only in the HTML file names.
    :param dropped_batches: only events whose index is in this collection are plotted.
    :raises AssertionError: if ``y_true`` and ``y_pred`` differ in length.
    :raises ValueError: if an event is plotted with an unknown ``args.beta_type``.
    """
    max_events = 5
    sz = 10
    assert len(y_true) == len(y_pred), f"y_true: {len(y_true)}, y_pred: {len(y_pred)}"
    with_bc = args.beta_type == "pt+bc"
    n_columns = 6 if with_bc else 4
    fig, ax = plt.subplots(max_events, n_columns, figsize=(n_columns * sz, sz * max_events))
    outdir = os.path.dirname(filename)
    if with_bc:
        # Colormap from light gray to dark green for the classifier score.
        bc_cmap = mcolors.LinearSegmentedColormap.from_list(
            "lightgray_to_darkgreen", [(0.9, 0.9, 0.9), (0.0, 0.5, 0.0)]
        )

    print("N events:", event_batch.n_events)
    for i in range(min(event_batch.n_events, max_events)):
        # NOTE(review): only events listed in ``dropped_batches`` are plotted; confirm this
        # is intended and not an inverted check.
        if i not in dropped_batches:
            continue
        event = event_batch[i]
        filt = batch_idx == i
        y_true_event = y_true[filt]
        y_pred_event = y_pred[filt]
        if args.beta_type == "default":
            # y_pred_event already holds only this event's rows; do not re-apply ``filt``.
            betas = y_pred_event[:, 3]
        elif args.beta_type in ("pt", "pt+bc"):
            # NOTE(review): pt is not normalized to [0, 1], so the brg colormap saturates
            # for pt > 1; confirm this is intended.
            betas = event.pfcands.pt
        else:
            raise ValueError(f"Unknown beta_type: {args.beta_type}")
        # NOTE(review): with 5 output columns, column 3 is both part of p_xyz and the
        # beta/classifier column; confirm the intended layout.
        if y_pred_event.shape[1] == 5:
            p_xyz = y_pred_event[:, 1:4]
        else:
            p_xyz = y_pred_event[:, :3]

        plot_coordinates(event.pfcands.pxyz, pt=event.pfcands.pt, tidx=y_true_event,
                         outdir=outdir,
                         filename=f"input_coords_batch_{batch}_event_{i}.html")
        plot_coordinates(p_xyz, pt=event.pfcands.pt, tidx=y_true_event,
                         outdir=outdir,
                         filename=f"model_coords_batch_{batch}_event_{i}.html")

        eta, phi = calc_eta_phi(p_xyz, return_stacked=False)
        gt_colors = [_gt_label_color(label) for label in y_true_event.tolist()]
        beta_colors = plt.cm.brg(betas)

        plot_event(event, colors=beta_colors, ax=ax[i, 0])
        ax[i, 0].set_title(r"input coords, $\beta$ colors")
        cbar = plt.colorbar(mappable=cm.ScalarMappable(cmap=plt.cm.brg), ax=ax[i, 0])
        cbar.set_label(r"$\beta$")

        plot_event(event, colors=gt_colors, ax=ax[i, 1])
        ax[i, 1].set_title("input coords, GT colors")

        plot_event(event, custom_coords=[eta, phi], colors=beta_colors, ax=ax[i, 2], jets=False)
        ax[i, 2].set_title(r"model coords, $\beta$ colors")
        cbar = plt.colorbar(mappable=cm.ScalarMappable(cmap=plt.cm.brg), ax=ax[i, 2])
        cbar.set_label(r"$\beta$")

        plot_event(event, custom_coords=[eta, phi], colors=gt_colors, ax=ax[i, 3], jets=False)
        ax[i, 3].set_title("model coords, GT colors")

        if with_bc:
            bc_colors = bc_cmap(y_pred_event[:, 3])
            plot_event(event, custom_coords=[eta, phi], colors=bc_colors, ax=ax[i, 5], jets=False)
            ax[i, 5].set_title(r"model coords, BC label colors")
            cbar = plt.colorbar(mappable=cm.ScalarMappable(cmap=bc_cmap), ax=ax[i, 5])
            cbar.set_label("Classifier score")
            plot_event(event, colors=bc_colors, ax=ax[i, 4], jets=False)
            ax[i, 4].set_title(r"input coords, BC label colors")
            cbar = plt.colorbar(mappable=cm.ScalarMappable(cmap=bc_cmap), ax=ax[i, 4])
            cbar.set_label("Classifier score")

    print("Saving eval figure to", filename)
    fig.tight_layout()
    fig.savefig(filename)
    fig.clear()
    plt.close(fig)