# Source: jetclustering/src/utils/nn/tools_condensation.py
# File-viewer header: gregorkrzmanc, revision e75a247, 24.8 kB
import numpy as np
import awkward as ak
import tqdm
import time
import torch
from collections import defaultdict, Counter
from src.utils.metrics import evaluate_metrics
from src.data.tools import _concat
from src.logger.logger import _logger
from torch_scatter import scatter_sum, scatter_max
import wandb
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from pathlib import Path
from src.layers.object_cond import calc_eta_phi
import os
import pickle
from src.dataset.functions_data import get_batch, get_corrected_batch
from src.plotting.plot_event import plot_batch_eval_OC, get_labels_jets
from src.jetfinder.clustering import get_clustering_labels
from src.evaluation.clustering_metrics import compute_f1_score_from_result
from src.utils.train_utils import get_target_obj_score, plot_obj_score_debug # for debugging only!
from src.layers.object_cond import loss_func_aug
def _unwrapped_state_dict(module):
    """Return ``module``'s state dict, looking through a (Distributed)DataParallel wrapper."""
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    if isinstance(module, parallel_wrappers):
        return module.module.state_dict()
    return module.state_dict()


def train_epoch(
    args,
    model,
    loss_func,
    gt_func,
    opt,
    scheduler,
    train_loader,
    dev,
    epoch,
    grad_scaler=None,
    local_rank=0,
    current_step=0,
    val_loader=None,
    batch_config=None,
    val_dataset=None,
    obj_score_model=None,
    opt_obj_score=None,
    sched_obj_score=None,
    train_loader_aug=None,  # if it's not None, it will also use the augmented events for the IRC safety loss term
):
    """Run one pass over ``train_loader``, logging to wandb and validating periodically.

    The function works in one of two modes, selected by ``obj_score_model``:

    * ``obj_score_model is None``: ``model`` is trained with ``loss_func``. When
      ``train_loader_aug`` is given, ``100 * loss_func_aug(...)`` is added to the
      loss and logged as ``loss_IRC``.
    * ``obj_score_model is not None``: ``model`` is only run under
      ``torch.no_grad()`` and ``obj_score_model`` is trained instead, with a
      class-weighted BCE-with-logits loss on the clusters returned by
      ``get_clustering_labels`` (only clusters with summed pt >= 100 are used).

    Every ``args.validation_steps`` steps (on ``local_rank == 0`` only) a checkpoint
    is written to ``args.run_path``, ``evaluate`` is run on ``val_loader`` and the
    F1 score computed from its result is logged as ``val_f1_score``.

    Args:
        args: Run configuration; the attributes read here are ``min_cluster_size``,
            ``min_samples``, ``epsilon``, ``global_features_obj_score``,
            ``objectness_score_gt_mode``, ``validation_steps``, ``run_path`` and
            ``num_steps``.
        model: The clustering model.
        loss_func: Called as ``loss_func(batch, y_pred, y)``; returns ``(loss, loss_dict)``.
        gt_func: Called on an event batch to build the ground truth ``y``.
        opt: Optimizer stepped when training ``model``.
        scheduler: Scheduler whose state is stored in the ``model`` checkpoint.
        train_loader: Iterable of event batches.
        dev: Device the batch and targets are moved to.
        epoch: Epoch number, used in checkpoint file names and passed to ``evaluate``.
        grad_scaler: Optional AMP gradient scaler; autocast is enabled iff it is given.
        local_rank: Only rank 0 checkpoints and validates.
        current_step: Global step count to continue from.
        val_loader: Loader passed to ``evaluate``.
        batch_config: Passed through to ``get_batch`` and ``evaluate``.
        val_dataset: Passed to ``compute_f1_score_from_result``.
        obj_score_model: Optional objectness-score model (see above).
        opt_obj_score: Optimizer stepped when training ``obj_score_model``.
        sched_obj_score: Optional scheduler stored in the objectness checkpoint.
        train_loader_aug: Optional loader of augmented events, consumed in lockstep
            with ``train_loader``.

    Returns:
        The string ``"quit_training"`` once ``args.num_steps`` (if not ``-1``) has
        been reached, otherwise the updated global step count.
    """
    # Only the module that is actually optimized is switched to train mode.
    trained_module = model if obj_score_model is None else obj_score_model
    trained_module.train()
    step_count = current_step
    if train_loader_aug is not None:
        train_loader_aug = iter(train_loader_aug)
    for event_batch in tqdm.tqdm(train_loader):
        time_preprocess_start = time.time()
        y = gt_func(event_batch)
        batch, y = get_batch(event_batch, batch_config, y)
        if train_loader_aug is not None:
            event_batch_aug = next(train_loader_aug)
            max_mapping = event_batch_aug.pfcands.original_particle_mapping.max()
            n_pfcands = len(event_batch.pfcands)
            # The message previously read a non-existent attribute
            # (`pfcands_original_particle_mapping`), which would have masked the
            # assertion with an AttributeError.
            assert max_mapping < n_pfcands, (
                f"The original particle mapping out of bounds: {max_mapping} >= {n_pfcands}"
            )
            if len(batch.dropped_batches):
                print("Dropped batches:", batch.dropped_batches, " - skipping this iteration")
                # Quicker this than to implement all the indexing complications from dropped batches
                continue
            y_aug = gt_func(event_batch_aug)
            batch_aug, y_aug = get_batch(event_batch_aug, batch_config, y_aug)
        time_preprocess_end = time.time()
        step_count += 1
        y = y.to(dev)
        opt.zero_grad()
        if obj_score_model is not None:
            opt_obj_score.zero_grad()
        # NOTE(review): anomaly detection is a debugging aid and slows training;
        # it is kept because the original enables it on every step.
        torch.autograd.set_detect_anomaly(True)
        # NOTE(review): the source this was restored from had lost its indentation;
        # the extent of this autocast block (through the objectness loss) is a best
        # guess - confirm against the upstream file.
        with torch.cuda.amp.autocast(enabled=grad_scaler is not None):
            batch.to(dev)
            if train_loader_aug is not None:
                batch_aug.to(dev)
            model_forward_time_start = time.time()
            if obj_score_model is not None:
                with torch.no_grad():
                    y_pred = model(batch)  # Only train the objectness score model
            else:
                y_pred = model(batch)
            if train_loader_aug is not None:
                y_pred_aug = model(batch_aug)
            model_forward_time_end = time.time()
            loss, loss_dict = loss_func(batch, y_pred, y)
            if train_loader_aug is not None:
                loss_aug = loss_func_aug(y_pred, y_pred_aug, batch, batch_aug, event_batch, event_batch_aug)
                loss += loss_aug * 100.0
                loss_dict["loss_IRC"] = loss_aug
            loss_time_end = time.time()
            wandb.log({
                "time_preprocess": time_preprocess_end - time_preprocess_start,
                "time_model_forward": model_forward_time_end - model_forward_time_start,
                "time_loss": loss_time_end - model_forward_time_end,
            }, step=step_count)
            if obj_score_model is not None:
                # Compute the objectness score
                coords = y_pred[:, 1:4]
                # TODO: update this to match the model architecture, as it's written here it's only suitable for L-GATr
                _, clusters, event_idx_clusters = get_clustering_labels(
                    coords.detach().cpu().numpy(),
                    batch.batch_idx.detach().cpu().numpy(),
                    min_cluster_size=args.min_cluster_size,
                    min_samples=args.min_samples,
                    epsilon=args.epsilon,
                    return_labels_event_idx=True,
                )
                input_pxyz = event_batch.pfcands.pxyz[batch.filter.cpu()]
                # Labels are shifted by +1 and the first row is dropped, so whatever
                # carries label -1 does not contribute to any cluster sum.
                clusters_pxyz = scatter_sum(input_pxyz, torch.tensor(clusters) + 1, dim=0)[1:]
                clusters_eta, clusters_phi = calc_eta_phi(clusters_pxyz, return_stacked=False)
                clusters_pt = torch.norm(clusters_pxyz[:, :2], dim=-1)
                pt_mask = clusters_pt >= 100  # Don't train on the clusters that eventually get cut off
                batch_corr = get_corrected_batch(batch, clusters, test=False)
                if not args.global_features_obj_score:
                    objectness_score = obj_score_model(batch_corr)[pt_mask].flatten()
                else:
                    objectness_score = obj_score_model(batch_corr, batch, clusters)[pt_mask].flatten()
                target_obj_score = get_target_obj_score(
                    clusters_eta[pt_mask], clusters_phi[pt_mask], clusters_pt[pt_mask],
                    torch.tensor(event_idx_clusters)[pt_mask], y.dq_eta, y.dq_phi,
                    y.dq_coords_batch_idx, gt_mode=args.objectness_score_gt_mode,
                )
                n_positive, n_negative = target_obj_score.sum(), (1 - target_obj_score).sum()
                # Set per-element weights for the loss according to the class imbalance;
                # a class that is absent from the batch gets weight 0.
                n_all = n_positive + n_negative
                pos_weight = n_all / n_positive if n_positive > 0 else 0
                neg_weight = n_all / n_negative if n_negative > 0 else 0
                weights = torch.where(target_obj_score == 1, pos_weight, neg_weight)
                print("N positive:", n_positive.item(), "N negative:", n_negative.item())
                print("First 20 predictions:", objectness_score[:20], "First 20 targets:", target_obj_score[:20])
                objectness_score = objectness_score.clamp(min=-10, max=10)
                target_obj_score = target_obj_score.to(objectness_score.device)
                weights = weights.to(objectness_score.device)
                loss_obj_score = torch.nn.BCEWithLogitsLoss(weight=weights)(objectness_score, target_obj_score)
                # In this mode the clustering loss is only logged; the objectness
                # loss is the one that is back-propagated.
                loss = loss_obj_score
                loss_dict["loss_obj_score"] = loss_obj_score
        active_opt = opt if obj_score_model is None else opt_obj_score
        if grad_scaler is None:
            loss.backward()
            active_opt.step()
        else:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(active_opt)
            grad_scaler.update()
        loss = loss.item()
        wandb.log({key: value.detach().cpu().item() for key, value in loss_dict.items()}, step=step_count)
        wandb.log({"loss": loss}, step=step_count)
        del loss_dict
        del loss
        if (local_rank == 0) and (step_count % args.validation_steps) == 0:
            dirname = args.run_path
            if obj_score_model is None:
                state_dict = {
                    "model": _unwrapped_state_dict(model),
                    "optimizer": opt.state_dict(),
                    "scheduler": scheduler.state_dict(),
                }
                path = os.path.join(dirname, "step_%d_epoch_%d.ckpt" % (step_count, epoch))
            else:
                sched_sd = {}
                if sched_obj_score is not None:
                    sched_sd = sched_obj_score.state_dict()
                state_dict = {
                    # The wrapper check used to be made on `model` instead of
                    # `obj_score_model`, i.e. on the wrong module.
                    "model": _unwrapped_state_dict(obj_score_model),
                    "optimizer": opt_obj_score.state_dict(),
                    "scheduler": sched_sd,
                }
                path = os.path.join(dirname, "OS_step_%d_epoch_%d.ckpt" % (step_count, epoch))
            torch.save(state_dict, path)
            res = evaluate(
                model,
                val_loader,
                dev,
                epoch,
                step_count,
                loss_func=loss_func,
                gt_func=gt_func,
                local_rank=local_rank,
                args=args,
                batch_config=batch_config,
                predict=False,
                model_obj_score=obj_score_model,
            )
            # evaluate() calls model.eval(); switch the trained module back so the
            # remaining steps of this epoch do not silently run in eval mode.
            trained_module.train()
            if obj_score_model is not None:
                res, res_obj_score, res_obj_score1 = res
                # TODO: use the obj score here for quick evaluation
            f1 = compute_f1_score_from_result(res, val_dataset)
            wandb.log({"val_f1_score": f1}, step=step_count)
        if args.num_steps != -1 and step_count >= args.num_steps:
            print("Quitting training as the required number of steps has been reached.")
            return "quit_training"
    return step_count
def evaluate(
    model,
    eval_loader,
    dev,
    epoch,
    step,
    loss_func,
    gt_func,
    local_rank=0,
    args=None,
    batch_config=None,
    predict=False,
    model_obj_score=None,  # if not None, it will compute the objectness score of each cluster using the proposed method
):
    """Run ``model`` over ``eval_loader`` and collect per-particle predictions.

    ``model`` is put in eval mode and left that way. For every batch the model
    output, the input kinematics and the cluster labels from
    ``get_clustering_labels`` are stored. When ``predict`` is false, ``loss_func``
    is evaluated as well and the averaged losses are logged to wandb (on
    ``local_rank == 0``). When ``model_obj_score`` is given, its scores and the
    matching targets are computed per cluster and a class-weighted
    BCE-with-logits loss is accumulated under ``val_loss_obj_score``.

    Args:
        model: The clustering model.
        eval_loader: Iterable of event batches; each must expose ``n_events``.
        dev: Device the batch and targets are moved to.
        epoch: Used for the name of the plot folder.
        step: Used for the plot folder name and as the wandb step.
        loss_func: Called as ``loss_func(batch, y_pred, y)`` when ``predict`` is false.
        gt_func: Called on an event batch to build the ground truth ``y``.
        local_rank: Only rank 0 logs the validation losses.
        args: Run configuration; the attributes read here are ``beta_type``,
            ``parton_level``, ``gen_level``, ``run_path``, ``loss``,
            ``train_objectness_score``, ``min_cluster_size``, ``min_samples``,
            ``epsilon``, ``global_features_obj_score`` and ``objectness_score_gt_mode``.
        batch_config: Passed through to ``get_batch``.
        predict: If true, skip the loss computation, plot folder and wandb logging.
        model_obj_score: Optional objectness-score model.

    Returns:
        A dict of concatenated tensors (``event_idx``, ``GT_cluster``, ``pred``,
        ``eta``, ``phi``, ``pt``, ``mass``, ``AK8_cluster``, ``model_cluster`` and,
        with an objectness model, ``event_clusters_idx``). With an objectness
        model the return value is the tuple
        ``(predictions, objectness_scores, objectness_targets)``.
    """
    model.eval()
    # NOTE(review): `model_obj_score` is not switched to eval mode here - confirm
    # whether the caller is expected to do that.
    count = 0
    start_time = time.time()
    total_loss = 0
    total_loss_dict = {}
    plot_batches = [0, 1]
    n_batches = 0
    # Predictions are collected on the validation set as well, not only when
    # `predict` is set.
    predictions = {
        "event_idx": [],
        "GT_cluster": [],
        "pred": [],
        "eta": [],
        "phi": [],
        "pt": [],
        "mass": [],
        "AK8_cluster": [],
        "model_cluster": [],
    }
    if model_obj_score is not None:
        obj_score_predictions = []
        obj_score_targets = []
        predictions["event_clusters_idx"] = []
    if args.beta_type != "pt+bc":
        # "BC_score" is never added above, so a plain `del` raised KeyError here.
        predictions.pop("BC_score", None)
    last_event_idx = 0
    with torch.no_grad():
        with tqdm.tqdm(eval_loader) as tq:
            for event_batch in tq:
                count += event_batch.n_events  # number of samples
                y = gt_func(event_batch)
                batch, y = get_batch(event_batch, batch_config, y, test=predict)
                pfcands = event_batch.pfcands
                if args.parton_level:
                    pfcands = event_batch.final_parton_level_particles
                elif args.gen_level:
                    pfcands = event_batch.final_gen_particles
                y = y.to(dev)
                batch = batch.to(dev)
                y_pred = model(batch)
                if not predict:
                    loss, loss_dict = loss_func(batch, y_pred, y)
                    loss = loss.item()
                    total_loss += loss
                    for key in loss_dict:
                        if key not in total_loss_dict:
                            total_loss_dict[key] = 0
                        total_loss_dict[key] += loss_dict[key].item()
                    del loss_dict
                if n_batches in plot_batches and not predict:  # don't plot these for prediction - they are useful in training
                    plot_folder = os.path.join(args.run_path, "eval_plots", "epoch_" + str(epoch) + "_step_" + str(step))
                    Path(plot_folder).mkdir(parents=True, exist_ok=True)
                    # `label_true` is only consumed by the plotting call below,
                    # which is currently disabled.
                    if args.loss == "quark_distance":
                        label_true = y.labels_no_renumber.detach().cpu()
                    elif args.train_objectness_score:
                        label_true = y.labels.detach().cpu()
                    else:
                        label_true = y.detach().cpu()
                    #plot_batch_eval_OC(event_batch, label_true,
                    #                   y_pred.detach().cpu(), batch.batch_idx.detach().cpu(),
                    #                   os.path.join(plot_folder, "batch_" + str(n_batches) + ".pdf"),
                    #                   args=args, batch=n_batches, dropped_batches=batch.dropped_batches)
                n_batches += 1
                if not predict:
                    tq.set_postfix(
                        {
                            "Loss": "%.5f" % loss,
                            "AvgLoss": "%.5f" % (total_loss / n_batches),
                        }
                    )
                # Offset the per-batch event index by the number of events seen so
                # far, so event ids stay unique after concatenation.
                event_idx = batch.batch_idx + last_event_idx
                predictions["event_idx"].append(event_idx)
                if model_obj_score is None:
                    predictions["GT_cluster"].append(y.detach().cpu())
                else:
                    predictions["GT_cluster"].append(y.labels.detach().cpu())
                predictions["pred"].append(y_pred.detach().cpu())
                predictions["eta"].append(pfcands.eta.detach().cpu())
                predictions["phi"].append(pfcands.phi.detach().cpu())
                predictions["pt"].append(pfcands.pt.detach().cpu())
                predictions["AK8_cluster"].append(event_batch.pfcands.pf_cand_jet_idx.detach().cpu())
                predictions["mass"].append(pfcands.mass.detach().cpu())
                # A 4-column output has the clustering coordinates in columns 0-2,
                # any other width in columns 1-3.
                if predictions["pred"][-1].shape[1] == 4:
                    coords = predictions["pred"][-1][:, :3]
                else:
                    coords = predictions["pred"][-1][:, 1:4]
                clustering_labels = torch.tensor(
                    get_clustering_labels(
                        coords.detach().cpu().numpy(),
                        event_idx.detach().cpu().numpy(),
                        min_cluster_size=args.min_cluster_size,
                        min_samples=args.min_samples,
                        epsilon=args.epsilon,
                        return_labels_event_idx=False,
                    )
                )
                if model_obj_score is not None:
                    _, clusters, event_idx_clusters = get_clustering_labels(
                        coords.detach().cpu().numpy(),
                        batch.batch_idx.detach().cpu().numpy(),
                        min_cluster_size=args.min_cluster_size,
                        min_samples=args.min_samples,
                        epsilon=args.epsilon,
                        return_labels_event_idx=True,
                    )
                    assert len(event_idx_clusters) == clusters.max() + 1
                    batch_corr = get_corrected_batch(batch, clusters, test=predict)
                    input_pxyz = pfcands.pxyz[batch.filter.cpu()]
                    # Labels are shifted by +1 and the first row is dropped, so
                    # whatever carries label -1 does not enter any cluster sum.
                    clusters_pxyz = scatter_sum(input_pxyz, torch.tensor(clusters) + 1, dim=0)[1:]
                    clusters_eta, clusters_phi = calc_eta_phi(clusters_pxyz, return_stacked=False)
                    clusters_pt = torch.norm(clusters_pxyz[:, :2], dim=-1)
                    pt_mask = clusters_pt >= 100  # Only clusters above the pt cut get a target
                    if not args.global_features_obj_score:
                        objectness_score = model_obj_score(batch_corr)
                    else:
                        objectness_score = model_obj_score(batch_corr, batch, clusters)
                    # NOTE(review): the stored scores cover all clusters while the
                    # stored targets below only cover clusters passing `pt_mask`,
                    # so the two returned tensors can differ in length - confirm
                    # what the caller expects before changing this.
                    obj_score_predictions.append(objectness_score.detach().cpu())
                    target_obj_score = get_target_obj_score(
                        clusters_eta[pt_mask], clusters_phi[pt_mask], clusters_pt[pt_mask],
                        torch.tensor(event_idx_clusters)[pt_mask], y.dq_eta,
                        y.dq_phi, y.dq_coords_batch_idx, gt_mode=args.objectness_score_gt_mode,
                    )
                    n_positive, n_negative = target_obj_score.sum(), (1 - target_obj_score.float()).sum()
                    # Set per-element weights for the loss according to the class
                    # imbalance; a class absent from the batch gets weight 0.
                    n_all = n_positive + n_negative
                    pos_weight = n_all / n_positive if n_positive > 0 else 0
                    neg_weight = n_all / n_negative if n_negative > 0 else 0
                    weights = torch.where(target_obj_score == 1, pos_weight, neg_weight)
                    print("N positive (eval):", n_positive.item(), "N negative (eval):", n_negative.item())
                    # The labels used to say "First 10" although 20 entries are printed.
                    print("First 20 predictions (eval):", objectness_score[:20], "First 20 targets (eval):",
                          target_obj_score[:20])
                    objectness_score = objectness_score.clamp(min=-10, max=10)
                    target_obj_score = target_obj_score.to(objectness_score.device)
                    weights = weights.to(objectness_score.device)
                    pt_mask = pt_mask.to(objectness_score.device)
                    loss_obj_score = torch.nn.BCEWithLogitsLoss(weight=weights)(
                        objectness_score.flatten()[pt_mask], target_obj_score.flatten()
                    ).cpu().item()
                    obj_score_targets.append(target_obj_score)
                    # NOTE(review): this key receives another "val_" prefix when it
                    # is logged below, i.e. it shows up as "val_val_loss_obj_score".
                    k = "val_loss_obj_score"
                    if k not in total_loss_dict:
                        total_loss_dict[k] = 0
                    total_loss_dict[k] += loss_obj_score
                    predictions["event_clusters_idx"].append(torch.tensor(event_idx_clusters) + last_event_idx)
                # `clustering_labels` is already a tensor; re-wrapping it in
                # torch.tensor() only produced a copy and a warning.
                predictions["model_cluster"].append(clustering_labels)
                last_event_idx = count
                if event_idx.max().item() + 1 != last_event_idx:
                    print(f"event_idx.max() = {event_idx.max().item()}, last_event_idx = {last_event_idx} - the eval would have failed here before the update")
    if local_rank == 0 and not predict:
        wandb.log({"val_loss": total_loss / n_batches}, step=step)
        wandb.log({"val_" + key: value / n_batches for key, value in total_loss_dict.items()}, step=step)
    time_diff = time.time() - start_time
    _logger.info(
        "Evaluated on %d samples in total (avg. speed %.1f samples/s)"
        % (count, count / time_diff)
    )
    predictions = {key: torch.cat(chunks, dim=0) for key, chunks in predictions.items()}
    if model_obj_score is not None:
        return predictions, torch.cat(obj_score_predictions), torch.cat(obj_score_targets)
    return predictions