# NOTE: removed web-page chrome ("Spaces: / Sleeping / Sleeping") that was
# captured during extraction; kept only this comment so the module parses.
import json

import dgl
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from sklearn.decomposition import PCA
from sklearn.metrics import roc_curve, roc_auc_score
from torch_scatter import scatter_max
def log_wandb_init(args, data_config=None):
    """Log information about the run in the config section of wandb.

    Currently wandb is only initialized in training mode.

    Args:
        args: parsed args from training; must expose ``regression_mode``
            and ``num_epochs``.
        data_config: optional object exposing ``graph_config`` and
            ``custom_model_kwargs`` attributes. Defaults to None.
            (The old default ``{}`` was a mutable default *and* a dict,
            which has no ``graph_config`` attribute, so calling without
            a data_config always raised AttributeError.)
    """
    if args.regression_mode:
        wandb.config.regression_mode = True
    else:
        wandb.config.classification_mode = True
    wandb.config.num_epochs = args.num_epochs
    wandb.config.args = vars(args)
    # Only record the data-config details when one was actually supplied.
    if data_config is not None:
        wandb.config.graph_config = data_config.graph_config
        wandb.config.custom_model_kwargs = data_config.custom_model_kwargs
def log_confussion_matrix_wandb(y_true, y_score, epoch):
    """Log a confusion matrix to the wandb.ai dashboard.

    Args:
        y_true: ground-truth labels, shape (B,).
        y_score: class probabilities, shape (B, num_classes).
        epoch: training epoch (currently unused; kept so a slider over
            epochs can be built in wandb later).
    """
    # wandb.plot.confusion_matrix derives predictions from the
    # probabilities itself, so the explicit argmax/threshold that used
    # to be computed here was dead code and has been removed.
    cm = wandb.plot.confusion_matrix(y_score, y_true=y_true)
    wandb.log({"confussion matrix": cm})
    # we could also log multiple cm during training but no sliding for now.
def log_roc_curves(y_true, y_score, epoch):
    """Log one-vs-one ROC curves, AUCs and working-point efficiencies.

    Five classes are assumed: G(0), ud/Q(1), S(2), C(3), B(4). For each
    tagged flavor (b, c, s, g) a ROC is built against every other
    relevant class and logged to wandb.

    Args:
        y_true: integer class labels, shape (B,).
        y_score: per-class logits/scores, shape (B, num_classes).
        epoch: training epoch (currently unused).

    Note:
        This fixes a bug in the original s-tagging section, where the
        AUC list omitted the s-vs-c entry while four names were logged,
        mislabeling the "s_c"/"s_b" AUC values.
    """
    # b tagging (b/g, b/ud, b/c)
    _log_flavor_rocs(4, "b", [(0, "g"), (1, "ud"), (3, "c")], y_true, y_score)
    # c tagging (c/g, c/ud, c/b)
    _log_flavor_rocs(3, "c", [(0, "g"), (1, "ud"), (4, "b")], y_true, y_score)
    # s tagging (s/g, s/ud, s/c, s/b)
    _log_flavor_rocs(
        2, "s", [(0, "g"), (1, "ud"), (3, "c"), (4, "b")], y_true, y_score
    )
    # g tagging (g/ud, g/s, g/c, g/b)
    _log_flavor_rocs(
        0, "g", [(1, "ud"), (2, "s"), (3, "c"), (4, "b")], y_true, y_score
    )


def _log_flavor_rocs(positive, pos_name, negatives, y_true, y_score):
    """Log one-vs-one ROCs of class ``positive`` against each ``negatives``
    entry (a list of ``(class_index, name)`` pairs). When any pairing lacks
    examples of one of the two classes in the batch, nothing is logged and
    a diagnostic line is printed instead.
    """
    rocs = [
        create_binary_rocs(positive, neg_idx, y_true, y_score)
        for neg_idx, _ in negatives
    ]
    if not all(len(roc) > 0 for roc in rocs):
        # create_binary_rocs returns [] when a class is absent from the batch
        print(
            "all batch from the same class in " + pos_name,
            *[len(roc) for roc in rocs],
        )
        return
    for roc, (_, neg_name) in zip(rocs, negatives):
        calculate_and_log_tpr_1_10_percent(roc[0], roc[1], pos_name, neg_name)
    columns = [pos_name + " vs " + neg_name for _, neg_name in negatives]
    # NOTE: axes are deliberately flipped — tpr on x, fpr on y
    xs = [roc[1] for roc in rocs]
    ys = [roc[0] for roc in rocs]
    auc_ = [roc[2] for roc in rocs]
    wandb_log_multiline_rocs(
        xs, ys, "roc " + pos_name, pos_name + " tagging", columns
    )
    wandb_log_auc(auc_, [pos_name + "_" + neg_name for _, neg_name in negatives])
def log_histograms(y_true_wandb, scores_wandb, counts_particles, epoch):
    """Log a histogram of constituent counts for misclassified jets.

    Args:
        y_true_wandb: ground-truth labels, shape (B,).
        scores_wandb: per-class scores, shape (B, num_classes).
        counts_particles: number of constituents per jet, shape (B,).
        epoch: training epoch (currently unused).
    """
    print("logging hist func")
    y_pred = np.argmax(scores_wandb, axis=1)
    errors_class_examples = y_true_wandb != y_pred
    errors_number_count = counts_particles[errors_class_examples]
    # NOTE: wandb.Table objects for correct/error counts used to be built
    # here to feed a wandb.plot.histogram call that was commented out;
    # that dead table construction has been removed.
    wandb.log({"hist_errors_count": wandb.Histogram(errors_number_count)})
def wandb_log_auc(auc_, names):
    """Log AUC values under ``auc/<name>`` keys.

    Args:
        auc_: sequence of AUC scores.
        names: sequence of metric names, parallel to ``auc_``.

    Returns:
        The ``auc_`` sequence unchanged (for chaining).
    """
    for auc_value, name in zip(auc_, names):
        # logging 1-auc because we are looking at roc with flipped axis
        wandb.log({"auc/" + name: 1 - auc_value})
    return auc_
def wandb_log_multiline_rocs(xs, ys, title_log, title_plot, columns):
    """Log several ROC curves as one wandb line-series plot.

    The y values (mistag rates) are put on a log10 scale, with a small
    epsilon added to avoid taking the log of zero.

    Args:
        xs: list of x arrays, one per curve.
        ys: list of y arrays, one per curve (log-scaled before logging).
        title_log: key under which the plot is logged.
        title_plot: title shown on the plot itself.
        columns: legend label for each curve.
    """
    log_scaled = []
    for curve in ys:
        log_scaled.append(np.log10(curve + 1e-8))
    series = wandb.plot.line_series(
        xs=xs,
        ys=log_scaled,
        keys=columns,
        title=title_plot,
        xname="jet tagging efficiency",
    )
    wandb.log({title_log: series})
def find_nearest(a, a0):
    """Return the index of the element of array ``a`` closest to scalar ``a0``."""
    distances = np.abs(a - a0)
    return distances.argmin()
def create_binary_rocs(positive, negative, y_true, y_score):
    """Build a one-vs-one ROC between two classes.

    The two relevant logit columns are renormalized with a softmax so the
    positive-class probability is conditioned on the jet being one of the
    two classes.

    Args:
        positive: class index treated as the positive label.
        negative: class index treated as the negative label.
        y_true: integer class labels, shape (B,), indexable by boolean mask.
        y_score: per-class logits, shape (B, num_classes).

    Returns:
        ``[fpr, tpr, auc_score]`` for the binary problem, or an empty
        list when the batch lacks examples of either class.
    """
    mask_positive = y_true == positive
    mask_negative = y_true == negative
    number_positive = len(y_true[mask_positive])
    number_negative = len(y_true[mask_negative])
    # Guard clause: a ROC needs at least one example of each class.
    if number_positive == 0 or number_negative == 0:
        return []
    # Binary targets: 1 for the positive class, 0 for the negative one.
    y_true_ = torch.cat(
        (torch.ones(number_positive), torch.zeros(number_negative)), dim=0
    )
    # Column order [negative, positive] so index 1 is the positive class.
    indices = torch.tensor([negative, positive])
    scores_pos = torch.index_select(
        torch.tensor(y_score[mask_positive]), 1, indices
    )
    scores_neg = torch.index_select(
        torch.tensor(y_score[mask_negative]), 1, indices
    )
    # torch.softmax is numerically stable; the previous manual
    # exp/sum(exp) version could overflow for large logits.
    prob_pos = torch.softmax(scores_pos, dim=1)[:, 1]
    prob_neg = torch.softmax(scores_neg, dim=1)[:, 1]
    y_prob_positive = torch.cat((prob_pos, prob_neg), dim=0)
    fpr, tpr, thrs = roc_curve(
        y_true_.numpy(), y_prob_positive.numpy(), pos_label=1
    )
    auc_score = roc_auc_score(y_true_.numpy(), y_prob_positive.numpy())
    return [fpr, tpr, auc_score]
def calculate_and_log_tpr_1_10_percent(fpr, tpr, name_pos, name_neg):
    """Log the true-positive rate at the 10% and 1% false-positive-rate
    working points under ``te/<pos>_vs_<neg>_{10%,1%}`` keys.

    Args:
        fpr: false-positive rates from a ROC curve.
        tpr: true-positive rates from the same ROC curve.
        name_pos: name of the positive class (e.g. "b").
        name_neg: name of the negative class (e.g. "g").
    """
    base = "te/" + name_pos + "_vs_" + name_neg
    working_points = {
        base + "_10%": tpr[find_nearest(fpr, 0.1)],
        base + "_1%": tpr[find_nearest(fpr, 0.01)],
    }
    wandb.log(working_points)
def plot_clust(g, q, xj, title_prefix="", y=None, radius=None, betas=None, loss_e_frac=None):
    """Debug-plot clustering diagnostics for up to 12 graphs of a batch.

    Builds a 12x10 matplotlib grid where each row visualizes one graph:
    clustering-space coordinates colored by particle, PCA of node
    features, hit x/y positions colored by particle / log-energy /
    hit type, beta distributions (when ``betas`` is given) and, when
    ``loss_e_frac`` is given, hits of particles with a low energy-
    fraction loss.

    Args:
        g: batched DGL graph with ``h``, ``particle_number``,
            ``hit_type`` and ``pos_hits_norm`` node data.
        q: per-node charge values stacked over the whole batch.
        xj: per-node clustering-space coordinates stacked over the batch.
        title_prefix: text prepended to each row's first panel title.
        y: unused here (kept for call-site compatibility).
        radius: unused here (kept for call-site compatibility).
        betas: optional per-node beta values stacked over the batch.
        loss_e_frac: optional per-particle energy-fraction loss
            (a tensor, or a list of tensors that gets concatenated).

    Returns:
        ``(fig, ax)``: the matplotlib figure and its 12x10 axes grid.
    """
    graph_list = dgl.unbatch(g)
    node_counter = 0  # running offset of the current graph's nodes in the batch
    particle_counter = 0  # running offset of the current graph's particles
    fig, ax = plt.subplots(12, 10, figsize=(33, 40))
    for i in range(0, min(12, len(graph_list))):
        graph_eval = graph_list[i]
        # print([g.num_nodes() for g in graph_list])
        non = graph_eval.number_of_nodes()
        assert non == graph_eval.ndata["h"].shape[0]
        n_part = graph_eval.ndata["particle_number"].max().long().item()
        particle_number = graph_eval.ndata["particle_number"]
        # if particle_number.max() > 1:
        #     print("skipping one, only plotting events with 2 particles")
        #     continue
        # slice this graph's nodes out of the batch-stacked tensors
        q_graph = q[node_counter : node_counter + non].flatten()
        if betas != None:
            beta_graph = betas[node_counter : node_counter + non].flatten()
        hit_type = torch.argmax(graph_eval.ndata["hit_type"], dim=1).view(-1)
        part_num = graph_eval.ndata["particle_number"].view(-1).to(torch.long)
        # condensation point per particle: node with the highest q
        # (particle numbers appear to be 1-based, hence the -1 — TODO confirm)
        q_alpha, index_alpha = scatter_max(
            q_graph.cpu().view(-1), part_num.cpu() - 1
        )
        # print(part_num.unique())
        xj_graph = xj[node_counter : node_counter + non, :].detach().cpu()
        if len(index_alpha) == 1:
            index_alpha = index_alpha.item()
        clr = graph_eval.ndata["particle_number"]
        ax[i, 2].set_title("x and y of hits")
        # columns 0/1 of "h" are used as hit x/y coordinates here
        xhits, yhits = (
            graph_eval.ndata["h"][:, 0].detach().cpu(),
            graph_eval.ndata["h"][:, 1].detach().cpu(),
        )
        hittype = torch.argmax(graph_eval.ndata["h"][:, [3, 4, 5, 6]], dim=1).view(
            -1
        )
        clr_energy = torch.log10(graph_eval.ndata["h"][:, 7].detach().cpu())
        ax[i, 2].scatter(xhits, yhits, c=clr.tolist(), alpha=0.2)
        ax[i, 3].scatter(xhits, yhits, c=clr_energy.tolist(), alpha=0.2)
        ax[i, 3].set_title("x and y of hits colored by log10 energy")
        ax[i, 4].scatter(xhits, yhits, c=hittype.tolist(), alpha=0.2)
        ax[i, 4].set_title("x and y of hits colored by hit type (ecal/hcal)")
        # beta panels (columns 5-7) are only drawn when betas are supplied
        if betas != None:
            ax[i, 5].scatter(xhits, yhits, c=beta_graph.detach().cpu(), alpha=0.2)
            ax[i, 5].set_title("hits coloored by beta")
            fig.colorbar(
                ScalarMappable(norm=Normalize(vmin=0, vmax=1)), ax=ax[i, 5]
            ).set_label("beta")
            ax[i, 6].hist(beta_graph.detach().cpu(), bins=100, range=(0, 1))
            ax[i, 6].set_title("beta distr.")
            fig.colorbar(
                ScalarMappable(norm=Normalize(vmin=0.5, vmax=1)), ax=ax[i, 7]
            ).set_label("beta > 0.5")
            no_objects = len(np.unique(part_num.cpu()))
            ax[i, 7].scatter(
                xj_graph[:, 0][beta_graph.detach().cpu() > 0.5],
                xj_graph[:, 1][beta_graph.detach().cpu() > 0.5],
                c=beta_graph[beta_graph.detach().cpu() > 0.5].detach().cpu(),
                alpha=0.2
            )
            # plot no_objects highest betas
            index_highest = np.argsort(beta_graph.detach().cpu())[-no_objects:]
            ax[i, 7].scatter(
                xj_graph[:, 0][index_highest],
                xj_graph[:, 1][index_highest],
                marker="*",
                c="red"
            )
            ax[i, 7].set_title("hits with beta > 0.5")
        ax[i, 8].set_title("hits of particles that have a low loss_e_frac")
        if loss_e_frac is not None:
            if not isinstance(loss_e_frac, torch.Tensor):
                loss_e_frac = torch.cat(loss_e_frac)
            loss_e_frac_batch = loss_e_frac[particle_counter : particle_counter + n_part]
            particle_counter += n_part
            low_filter = torch.nonzero(loss_e_frac_batch < 0.05).flatten()
            # NOTE(review): this `continue` skips the `node_counter += non`
            # at the bottom of the loop, which would misalign the node
            # slices of all subsequent graphs — verify intent.
            if not len(low_filter):
                continue
            ax[i, 8].set_title(loss_e_frac_batch[low_filter[0]])
            # only the FIRST low-loss particle is visualized below
            particle_number_low = part_num[low_filter[0]]
            # filter to particle numbers contained in particle_number_low
            low_filter = torch.nonzero(part_num == particle_number_low).flatten().detach().cpu()
            ax[i, 8].scatter(
                xj_graph[:, 0],
                xj_graph[:, 1],
                c="gray",
                alpha=0.2
            )
            ax[i, 9].scatter(
                xj_graph[:, 0],
                xj_graph[:, 2],
                c="gray",
                alpha=0.2
            )
            ax[i, 8].set_xlabel("X")
            ax[i, 8].set_ylabel("Y")
            ax[i, 9].set_xlabel("X")
            ax[i, 9].set_ylabel("Z")
            # highlight the selected particle's hits in blue (X-Y and X-Z views)
            ax[i, 8].scatter(
                xj_graph[:, 0][low_filter],
                xj_graph[:, 1][low_filter],
                c="blue",
                alpha=0.2
            )
            ax[i, 9].scatter(
                xj_graph[:, 0][low_filter],
                xj_graph[:, 2][low_filter],
                c="blue",
                alpha=0.2
            )
            # NOTE(review): ia1/ia2 are 0/1 *long* tensors used directly as
            # indices below; integer-tensor indexing selects rows 0 and 1
            # repeatedly rather than acting as a boolean mask — this looks
            # like it was meant to be `.bool()` masking. Verify intent.
            ia1 = torch.zeros(xj_graph.shape[0]).long()
            ia2 = torch.zeros_like(ia1)
            ia1[index_alpha] = 1.
            ia2[low_filter] = 1.
            ax[i, 8].scatter(
                xj_graph[ia1, 0],
                xj_graph[ia1, 1],
                marker="*",
                c="r",
                alpha=1.0,
            )
            ax[i, 8].scatter(
                xj_graph[ia1*ia2, 0],
                xj_graph[ia1*ia2, 1],
                marker="*",
                c="g",
                alpha=1.0,
            )
            ax[i, 9].scatter(
                xj_graph[ia1, 0],
                xj_graph[ia1, 2],
                marker="*",
                c="r",
                alpha=1.0,
            )
            ax[i, 9].scatter(
                xj_graph[ia1 * ia2, 0],
                xj_graph[ia1 * ia2, 2],
                marker="*",
                c="g",
                alpha=1.0,
            )
        ax[i, 0].set_title(
            title_prefix
            + " "
            + str(np.unique(part_num.cpu()))
            + " "
            + str(len(np.unique(part_num.cpu())))
        )
        ax[i, 1].set_title("PCA of node features")
        ax[i, 0].scatter(xj_graph[:, 0], xj_graph[:, 1], c=clr.tolist(), alpha=0.2)
        # PCA needs at least 2 samples
        if non > 1:
            PCA_2d_node_feats = PCA(n_components=2).fit_transform(
                graph_eval.ndata["h"].detach().cpu().numpy()
            )
            ax[i, 1].scatter(
                PCA_2d_node_feats[:, 0],
                PCA_2d_node_feats[:, 1],
                c=clr.tolist(),
                alpha=0.2,
            )
        # mark the condensation points with red stars in clustering space
        ax[i, 0].scatter(
            xj_graph[index_alpha, 0],
            xj_graph[index_alpha, 1],
            marker="*",
            c="r",
            alpha=1.0,
        )
        pos = graph_eval.ndata["pos_hits_norm"]
        node_counter += non
    return fig, ax