# jetclustering/src/utils/logger_wandb.py
import wandb
import numpy as np
import torch
from sklearn.metrics import roc_curve, roc_auc_score
import dgl
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from torch_scatter import scatter_max
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize


def log_wandb_init(args, data_config=None):
    """Log information about the run in the config section of wandb.

    wandb is currently only initialized in training mode, so this assumes
    wandb.init() has already been called.

    Args:
        args: parsed args from training
        data_config: object exposing `graph_config` and `custom_model_kwargs`
            attributes (a plain dict would fail the attribute access below)
    """
if args.regression_mode:
wandb.config.regression_mode = True
else:
wandb.config.classification_mode = True
wandb.config.num_epochs = args.num_epochs
wandb.config.args = vars(args)
wandb.config.graph_config = data_config.graph_config
wandb.config.custom_model_kwargs = data_config.custom_model_kwargs
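

# Hedged usage sketch (illustrative only, not called anywhere in this module):
# shows the call order log_wandb_init assumes. The SimpleNamespace fields and
# the project name below are assumptions, not part of the original training code.
def _example_log_wandb_init():
    from types import SimpleNamespace

    args = SimpleNamespace(regression_mode=False, num_epochs=10)
    data_config = SimpleNamespace(graph_config={}, custom_model_kwargs={})
    wandb.init(project="jetclustering-demo", mode="disabled")  # disabled: no network I/O
    log_wandb_init(args, data_config)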


def log_confussion_matrix_wandb(y_true, y_score, epoch):
    """Log a confusion matrix to the wandb.ai dashboard.

    Args:
        y_true: labels (B,)
        y_score: probabilities (B, num_classes)
        epoch: training epoch, so that a slider could eventually be built in wandb
    """
    if y_score.ndim == 1:
        y_pred = (y_score > 0.5).astype(int)
    else:
        y_pred = y_score.argmax(1)
    # pass the hard predictions so the binary (ndim == 1) branch is actually used
    cm = wandb.plot.confusion_matrix(y_true=y_true, preds=y_pred)
    wandb.log({"confusion matrix": cm})
    # we could also log multiple confusion matrices during training, but there is no slider for now
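

# Minimal sketch of calling the logger on toy data (assumes numpy inputs, as in
# the evaluation loop; shapes and the row-normalization are illustrative):
def _example_confusion_matrix():
    rng = np.random.default_rng(0)
    y_true = rng.integers(0, 5, size=64)
    y_score = rng.random((64, 5))
    y_score /= y_score.sum(axis=1, keepdims=True)  # rows sum to 1, like softmax output
    wandb.init(mode="disabled")
    log_confussion_matrix_wandb(y_true, y_score, epoch=0)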


def log_roc_curves(y_true, y_score, epoch):
    # 5 classes: G(0), ud/Q(1), S(2), C(3), B(4)
    # b tagging (b/g, b/ud, b/c)
_bg = create_binary_rocs(4, 0, y_true, y_score)
_bud = create_binary_rocs(4, 1, y_true, y_score)
_bc = create_binary_rocs(4, 3, y_true, y_score)
    if len(_bg) > 0 and len(_bud) > 0 and len(_bc) > 0:
        # empty lists mean one of the two classes was absent from the batch
calculate_and_log_tpr_1_10_percent(_bg[0], _bg[1], "b", "g")
calculate_and_log_tpr_1_10_percent(_bud[0], _bud[1], "b", "ud")
calculate_and_log_tpr_1_10_percent(_bc[0], _bc[1], "b", "c")
columns = ["b vs g", "b vs ud", "b vs c"]
xs = [_bg[1], _bud[1], _bc[1]]
ys = [_bg[0], _bud[0], _bc[0]]
auc_ = [_bg[2], _bud[2], _bc[2]]
title_log = "roc b"
title_plot = "b tagging"
wandb_log_multiline_rocs(xs, ys, title_log, title_plot, columns)
wandb_log_auc(auc_, ["b_g", "b_ud", "b_c"])
    else:
        print("batch missing a class for b ROCs:", len(_bg), len(_bud), len(_bc))
# c tagging (c/g, c/ud, c/b)
_cg = create_binary_rocs(3, 0, y_true, y_score)
_cud = create_binary_rocs(3, 1, y_true, y_score)
_cb = create_binary_rocs(3, 4, y_true, y_score)
if len(_cg) > 0 and len(_cud) > 0 and len(_cb) > 0:
calculate_and_log_tpr_1_10_percent(_cg[0], _cg[1], "c", "g")
calculate_and_log_tpr_1_10_percent(_cud[0], _cud[1], "c", "ud")
calculate_and_log_tpr_1_10_percent(_cb[0], _cb[1], "c", "b")
columns = ["c vs g", "c vs ud", "c vs b"]
xs = [_cg[1], _cud[1], _cb[1]]
ys = [_cg[0], _cud[0], _cb[0]]
auc_ = [_cg[2], _cud[2], _cb[2]]
title_log = "roc c"
title_plot = "c tagging"
wandb_log_multiline_rocs(xs, ys, title_log, title_plot, columns)
wandb_log_auc(auc_, ["c_g", "c_ud", "c_b"])
    else:
        print("batch missing a class for c ROCs:", len(_cg), len(_cud), len(_cb))
# s tagging (s/g, s/ud, s/c, s/b)
_sg = create_binary_rocs(2, 0, y_true, y_score)
_sud = create_binary_rocs(2, 1, y_true, y_score)
_sc = create_binary_rocs(2, 3, y_true, y_score)
_sb = create_binary_rocs(2, 4, y_true, y_score)
if len(_sg) > 0 and len(_sud) > 0 and len(_sc) > 0 and len(_sb) > 0:
calculate_and_log_tpr_1_10_percent(_sg[0], _sg[1], "s", "g")
calculate_and_log_tpr_1_10_percent(_sud[0], _sud[1], "s", "ud")
calculate_and_log_tpr_1_10_percent(_sc[0], _sc[1], "s", "c")
calculate_and_log_tpr_1_10_percent(_sb[0], _sb[1], "s", "b")
columns = ["s vs g", "s vs ud", "s vs c", "s vs b"]
xs = [_sg[1], _sud[1], _sc[1], _sb[1]]
ys = [_sg[0], _sud[0], _sc[0], _sb[0]]
        auc_ = [_sg[2], _sud[2], _sc[2], _sb[2]]
title_log = "roc s"
title_plot = "s tagging"
wandb_log_multiline_rocs(xs, ys, title_log, title_plot, columns)
wandb_log_auc(auc_, ["s_g", "s_ud", "s_c", "s_b"])
    else:
        print(
            "batch missing a class for s ROCs:",
            len(_sg),
            len(_sud),
            len(_sc),
            len(_sb),
        )
# g tagging (g/ud, g/s, g/c, g/b)
_gud = create_binary_rocs(0, 1, y_true, y_score)
_gs = create_binary_rocs(0, 2, y_true, y_score)
_gc = create_binary_rocs(0, 3, y_true, y_score)
_gb = create_binary_rocs(0, 4, y_true, y_score)
if len(_gud) > 0 and len(_gs) > 0 and len(_gc) > 0 and len(_gb) > 0:
calculate_and_log_tpr_1_10_percent(_gud[0], _gud[1], "g", "ud")
calculate_and_log_tpr_1_10_percent(_gs[0], _gs[1], "g", "s")
calculate_and_log_tpr_1_10_percent(_gc[0], _gc[1], "g", "c")
calculate_and_log_tpr_1_10_percent(_gb[0], _gb[1], "g", "b")
columns = ["g vs ud", "g vs s", "g vs c", "g vs b"]
xs = [_gud[1], _gs[1], _gc[1], _gb[1]]
ys = [_gud[0], _gs[0], _gc[0], _gb[0]]
auc_ = [_gud[2], _gs[2], _gc[2], _gb[2]]
title_log = "roc g"
title_plot = "g tagging"
wandb_log_multiline_rocs(xs, ys, title_log, title_plot, columns)
wandb_log_auc(auc_, ["g_ud", "g_s", "g_c", "g_b"])
    else:
        print(
            "batch missing a class for g ROCs:",
            len(_gud),
            len(_gs),
            len(_gc),
            len(_gb),
        )
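

# Hedged end-to-end sketch on random logits (shapes and seed are illustrative;
# create_binary_rocs softmaxes pairs of logit columns internally):
def _example_log_roc_curves():
    rng = np.random.default_rng(0)
    y_true = rng.integers(0, 5, size=256)  # labels over G/ud/S/C/B
    y_score = rng.normal(size=(256, 5))  # raw multi-class logits
    wandb.init(mode="disabled")
    log_roc_curves(y_true, y_score, epoch=0)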
# def tagging_at_xpercent_misstag():


def log_histograms(y_true_wandb, scores_wandb, counts_particles, epoch):
    y_pred = np.argmax(scores_wandb, axis=1)
    errors_class_examples = y_true_wandb != y_pred
    correct_class_examples = y_true_wandb == y_pred
    errors_number_count = counts_particles[errors_class_examples]
    correct_number_count = counts_particles[correct_class_examples]
data_correct = [
[i, correct_number_count[i]] for i in range(0, len(correct_number_count))
]
data_errors = [
[i, errors_number_count[i]] for i in range(0, len(errors_number_count))
]
    # tables kept for the (currently disabled) wandb.plot.histogram variant below
    table_correct = wandb.Table(
        data=data_correct, columns=["IDs", "correct_number_count"]
    )
    table_errors = wandb.Table(data=data_errors, columns=["IDs", "errors_number_count"])
    wandb.log({"hist_errors_count": wandb.Histogram(errors_number_count)})
    wandb.log({"hist_correct_count": wandb.Histogram(correct_number_count)})
    # wandb.log({'hist_errors_count': wandb.plot.histogram(table_errors, "errors_number_count",
    #                                                      title="Histogram errors number const")})


def wandb_log_auc(auc_, names):
    for i in range(0, len(auc_)):
        name = "auc/" + names[i]
        # log 1 - auc because the ROC is plotted with flipped axes (mis-tag rate vs. efficiency)
        wandb.log({name: 1 - auc_[i]})
    return auc_


def wandb_log_multiline_rocs(xs, ys, title_log, title_plot, columns):
    # ys are mis-tag rates; plot them in log10, with an epsilon to avoid log(0)
    ys_log = [np.log10(j + 1e-8) for j in ys]
wandb.log(
{
title_log: wandb.plot.line_series(
xs=xs,
ys=ys_log,
keys=columns,
title=title_plot,
xname="jet tagging efficiency",
)
}
)
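
# Shape sketch: as called from log_roc_curves, xs are tpr arrays and ys are fpr
# arrays, so each curve is log10(mis-tag rate) vs. tagging efficiency, e.g.
#   wandb_log_multiline_rocs([tpr_bg], [fpr_bg], "roc b", "b tagging", ["b vs g"])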


def find_nearest(a, a0):
    """Index of the element in array `a` closest to the scalar value `a0`."""
    idx = np.abs(a - a0).argmin()
    return idx
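
# e.g. find_nearest(np.array([0.0, 0.05, 0.12, 0.4]), 0.1) == 2, since
# |0.12 - 0.1| = 0.02 is the smallest absolute difference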


def create_binary_rocs(positive, negative, y_true, y_score):
    """Binary ROC between two classes of a multi-class (logit) score.

    Keeps only samples of the two classes and renormalizes their logits with a
    pairwise softmax, P(pos) = exp(z_pos) / (exp(z_pos) + exp(z_neg)).
    Returns [fpr, tpr, auc] or [] if either class is absent from the batch.
    """
    mask_positive = y_true == positive
    mask_negative = y_true == negative
    number_positive = len(y_true[mask_positive])
    number_negative = len(y_true[mask_negative])
    if number_positive > 0 and number_negative > 0:
y_true_positive = torch.reshape(torch.ones([number_positive]), (-1,))
y_true_negative = torch.reshape(torch.zeros([number_negative]), (-1,))
y_true_ = torch.cat((y_true_positive, y_true_negative), dim=0)
y_score_positive = torch.tensor(y_score[mask_positive])
y_score_negative = torch.tensor(y_score[mask_negative])
indices = torch.tensor([negative, positive])
y_score_positive_ = torch.index_select(y_score_positive, 1, indices)
y_score_negative_ = torch.index_select(y_score_negative, 1, indices)
y_scores_pos_prob = torch.exp(y_score_positive_) / torch.sum(
torch.exp(y_score_positive_), keepdim=True, dim=1
)
y_scores_neg_prob = torch.exp(y_score_negative_) / torch.sum(
torch.exp(y_score_negative_), keepdim=True, dim=1
)
y_prob_positiveclass = torch.reshape(y_scores_pos_prob[:, 1], (-1,))
y_prob_positiveclass_neg = torch.reshape(y_scores_neg_prob[:, 1], (-1,))
y_prob_positive = torch.cat(
(y_prob_positiveclass, y_prob_positiveclass_neg), dim=0
)
fpr, tpr, thrs = roc_curve(
y_true_.numpy(), y_prob_positive.numpy(), pos_label=1
)
auc_score = roc_auc_score(y_true_.numpy(), y_prob_positive.numpy())
return [fpr, tpr, auc_score]
else:
return []
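

# Worked sketch: b (class 4) vs g (class 0) on toy logits; shapes, seed and
# values are illustrative:
def _example_create_binary_rocs():
    y_true = np.array([4, 4, 0, 0])
    y_score = np.random.default_rng(0).normal(size=(4, 5))  # (B, num_classes) logits
    fpr, tpr, auc_score = create_binary_rocs(4, 0, y_true, y_score)
    return fpr, tpr, auc_score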


def calculate_and_log_tpr_1_10_percent(fpr, tpr, name_pos, name_neg):
    # tagging efficiency (tpr) at the 10% and 1% mis-tag rate (fpr) working points
idx_10_percent = find_nearest(fpr, 0.1)
idx_1_percent = find_nearest(fpr, 0.01)
tpr_10_percent = tpr[idx_10_percent]
tpr_1_percent = tpr[idx_1_percent]
name_10 = "te/" + name_pos + "_vs_" + name_neg + "_10%"
name_1 = "te/" + name_pos + "_vs_" + name_neg + "_1%"
wandb.log({name_10: tpr_10_percent, name_1: tpr_1_percent})
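
# Working-point sketch, e.g. for name_pos="b", name_neg="g": with
# fpr = [0.0, 0.008, 0.1, 0.5] and matching tpr = [0.0, 0.3, 0.7, 1.0], this
# logs te/b_vs_g_1% = 0.3 (nearest fpr is 0.008) and te/b_vs_g_10% = 0.7 (exact)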


def plot_clust(g, q, xj, title_prefix="", y=None, radius=None, betas=None, loss_e_frac=None):
    """Debug plots for up to 12 events of a batched DGL graph, one row of panels per event."""
    graph_list = dgl.unbatch(g)
    node_counter = 0
    particle_counter = 0
    fig, ax = plt.subplots(12, 10, figsize=(33, 40))
for i in range(0, min(12, len(graph_list))):
graph_eval = graph_list[i]
# print([g.num_nodes() for g in graph_list])
non = graph_eval.number_of_nodes()
assert non == graph_eval.ndata["h"].shape[0]
n_part = graph_eval.ndata["particle_number"].max().long().item()
particle_number = graph_eval.ndata["particle_number"]
# if particle_number.max() > 1:
# print("skipping one, only plotting events with 2 particles")
# continue
q_graph = q[node_counter : node_counter + non].flatten()
        if betas is not None:
beta_graph = betas[node_counter : node_counter + non].flatten()
        hit_type = torch.argmax(graph_eval.ndata["hit_type"], dim=1).view(-1)  # currently unused
part_num = graph_eval.ndata["particle_number"].view(-1).to(torch.long)
q_alpha, index_alpha = scatter_max(
q_graph.cpu().view(-1), part_num.cpu() - 1
)
# print(part_num.unique())
xj_graph = xj[node_counter : node_counter + non, :].detach().cpu()
if len(index_alpha) == 1:
index_alpha = index_alpha.item()
clr = graph_eval.ndata["particle_number"]
ax[i, 2].set_title("x and y of hits")
xhits, yhits = (
graph_eval.ndata["h"][:, 0].detach().cpu(),
graph_eval.ndata["h"][:, 1].detach().cpu(),
)
hittype = torch.argmax(graph_eval.ndata["h"][:, [3, 4, 5, 6]], dim=1).view(
-1
)
clr_energy = torch.log10(graph_eval.ndata["h"][:, 7].detach().cpu())
ax[i, 2].scatter(xhits, yhits, c=clr.tolist(), alpha=0.2)
ax[i, 3].scatter(xhits, yhits, c=clr_energy.tolist(), alpha=0.2)
ax[i, 3].set_title("x and y of hits colored by log10 energy")
ax[i, 4].scatter(xhits, yhits, c=hittype.tolist(), alpha=0.2)
ax[i, 4].set_title("x and y of hits colored by hit type (ecal/hcal)")
        if betas is not None:
            ax[i, 5].scatter(xhits, yhits, c=beta_graph.detach().cpu(), alpha=0.2)
            ax[i, 5].set_title("hits colored by beta")
fig.colorbar(
ScalarMappable(norm=Normalize(vmin=0, vmax=1)), ax=ax[i, 5]
).set_label("beta")
ax[i, 6].hist(beta_graph.detach().cpu(), bins=100, range=(0, 1))
ax[i, 6].set_title("beta distr.")
fig.colorbar(
ScalarMappable(norm=Normalize(vmin=0.5, vmax=1)), ax=ax[i, 7]
).set_label("beta > 0.5")
no_objects = len(np.unique(part_num.cpu()))
ax[i, 7].scatter(
xj_graph[:, 0][beta_graph.detach().cpu() > 0.5],
xj_graph[:, 1][beta_graph.detach().cpu() > 0.5],
c=beta_graph[beta_graph.detach().cpu() > 0.5].detach().cpu(),
alpha=0.2
)
            # mark the no_objects highest-beta hits (candidate condensation points)
            index_highest = np.argsort(beta_graph.detach().cpu().numpy())[-no_objects:]
ax[i, 7].scatter(
xj_graph[:, 0][index_highest],
xj_graph[:, 1][index_highest],
marker="*",
c="red"
)
ax[i, 7].set_title("hits with beta > 0.5")
ax[i, 8].set_title("hits of particles that have a low loss_e_frac")
if loss_e_frac is not None:
if not isinstance(loss_e_frac, torch.Tensor):
loss_e_frac = torch.cat(loss_e_frac)
loss_e_frac_batch = loss_e_frac[particle_counter : particle_counter + n_part]
particle_counter += n_part
            low_filter = torch.nonzero(loss_e_frac_batch < 0.05).flatten()
            if not len(low_filter):
                node_counter += non  # keep the node offset in sync before skipping this event
                continue
            ax[i, 8].set_title(f"loss_e_frac={loss_e_frac_batch[low_filter[0]].item():.4f}")
particle_number_low = part_num[low_filter[0]]
# filter to particle numbers contained in particle_number_low
low_filter = torch.nonzero(part_num == particle_number_low).flatten().detach().cpu()
ax[i, 8].scatter(
xj_graph[:, 0],
xj_graph[:, 1],
c="gray",
alpha=0.2
)
ax[i, 9].scatter(
xj_graph[:, 0],
xj_graph[:, 2],
c="gray",
alpha=0.2
)
ax[i, 8].set_xlabel("X")
ax[i, 8].set_ylabel("Y")
ax[i, 9].set_xlabel("X")
ax[i, 9].set_ylabel("Z")
ax[i, 8].scatter(
xj_graph[:, 0][low_filter],
xj_graph[:, 1][low_filter],
c="blue",
alpha=0.2
)
ax[i, 9].scatter(
xj_graph[:, 0][low_filter],
xj_graph[:, 2][low_filter],
c="blue",
alpha=0.2
)
            # boolean masks: ia1 marks the condensation (alpha) hits, ia2 the hits of the
            # selected low-loss particle; integer 0/1 tensors would be treated as gather
            # indices (rows 0 and 1) instead of masks
            ia1 = torch.zeros(xj_graph.shape[0], dtype=torch.bool)
            ia2 = torch.zeros_like(ia1)
            ia1[index_alpha] = True
            ia2[low_filter] = True
ax[i, 8].scatter(
xj_graph[ia1, 0],
xj_graph[ia1, 1],
marker="*",
c="r",
alpha=1.0,
)
            ax[i, 8].scatter(
                xj_graph[ia1 & ia2, 0],
                xj_graph[ia1 & ia2, 1],
                marker="*",
                c="g",
                alpha=1.0,
            )
ax[i, 9].scatter(
xj_graph[ia1, 0],
xj_graph[ia1, 2],
marker="*",
c="r",
alpha=1.0,
)
            ax[i, 9].scatter(
                xj_graph[ia1 & ia2, 0],
                xj_graph[ia1 & ia2, 2],
                marker="*",
                c="g",
                alpha=1.0,
            )
ax[i, 0].set_title(
title_prefix
+ " "
+ str(np.unique(part_num.cpu()))
+ " "
+ str(len(np.unique(part_num.cpu())))
)
ax[i, 1].set_title("PCA of node features")
ax[i, 0].scatter(xj_graph[:, 0], xj_graph[:, 1], c=clr.tolist(), alpha=0.2)
if non > 1:
PCA_2d_node_feats = PCA(n_components=2).fit_transform(
graph_eval.ndata["h"].detach().cpu().numpy()
)
ax[i, 1].scatter(
PCA_2d_node_feats[:, 0],
PCA_2d_node_feats[:, 1],
c=clr.tolist(),
alpha=0.2,
)
ax[i, 0].scatter(
xj_graph[index_alpha, 0],
xj_graph[index_alpha, 1],
marker="*",
c="r",
alpha=1.0,
)
        pos = graph_eval.ndata["pos_hits_norm"]  # currently unused
node_counter += non
return fig, ax
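

# Hedged usage sketch: plot_clust returns the figure for the caller to log; the
# ndata fields it reads ("h", "particle_number", "hit_type", "pos_hits_norm")
# must be present on the batched graph. The wandb key "clustering" is illustrative.
def _example_plot_clust(g, q, xj, betas=None):
    fig, ax = plot_clust(g, q, xj, title_prefix="val", betas=betas)
    wandb.log({"clustering": wandb.Image(fig)})
    plt.close(fig)  # release the figure memory after logging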