import numpy as np
import awkward as ak
import tqdm
import time
import torch
from collections import defaultdict, Counter
from src.utils.metrics import evaluate_metrics
from src.data.tools import _concat
from src.logger.logger import _logger
from torch_scatter import scatter_sum, scatter_max
import wandb
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from pathlib import Path
from src.layers.object_cond import calc_eta_phi
import os
import pickle
from src.dataset.functions_data import get_batch, get_corrected_batch
from src.plotting.plot_event import plot_batch_eval_OC, get_labels_jets
from src.jetfinder.clustering import get_clustering_labels
from src.evaluation.clustering_metrics import compute_f1_score_from_result
from src.utils.train_utils import get_target_obj_score, plot_obj_score_debug  # for debugging only!
from src.layers.object_cond import loss_func_aug

# Clusters whose summed transverse momentum is below this value are excluded
# from the objectness-score loss ("clusters that eventually get cut off").
_MIN_CLUSTER_PT = 100

# Multiplier applied to the augmented-event (IRC safety) loss term before it is
# added to the main loss.
_IRC_LOSS_WEIGHT = 100.0


def _unwrapped_state_dict(module):
    """Return ``module.state_dict()``, looking through a DataParallel/DDP wrapper."""
    if isinstance(
        module,
        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
    ):
        module = module.module
    return module.state_dict()


def train_epoch(
    args,
    model,
    loss_func,
    gt_func,
    opt,
    scheduler,
    train_loader,
    dev,
    epoch,
    grad_scaler=None,
    local_rank=0,
    current_step=0,
    val_loader=None,
    batch_config=None,
    val_dataset=None,
    obj_score_model=None,
    opt_obj_score=None,
    sched_obj_score=None,
    train_loader_aug=None,  # if it's not None, it will also use the augmented events for the IRC safety loss term
):
    """Run one pass over ``train_loader``, with periodic checkpointing and validation.

    Two modes are supported:

    * ``obj_score_model is None``: ``model`` is trained with ``loss_func``
      (plus ``loss_func_aug`` when ``train_loader_aug`` is given) and stepped
      with ``opt``.
    * ``obj_score_model is not None``: ``model`` is only run under
      ``torch.no_grad()``; its outputs are clustered and ``obj_score_model``
      is trained with a class-balanced ``BCEWithLogitsLoss`` and stepped with
      ``opt_obj_score``.

    Every ``args.validation_steps`` steps (on ``local_rank == 0``) a checkpoint
    is written to ``args.run_path`` and ``evaluate`` is run on ``val_loader``.

    Args:
        args: Run configuration; the attributes read here are
            ``validation_steps``, ``run_path``, ``num_steps`` and, in
            objectness-score mode, ``min_cluster_size``, ``min_samples``,
            ``epsilon``, ``global_features_obj_score`` and
            ``objectness_score_gt_mode``.
        model: The clustering model.
        loss_func: Called as ``loss_func(batch, y_pred, y)``; returns
            ``(loss, loss_dict)``.
        gt_func: Called on each event batch to build the ground truth.
        opt: Optimizer for ``model``.
        scheduler: Scheduler whose state is stored in the checkpoint.
        train_loader: Iterable of event batches.
        dev: Device the batches and targets are moved to.
        epoch: Epoch index, used in checkpoint file names.
        grad_scaler: Optional AMP gradient scaler; autocast is enabled only
            when this is not ``None``.
        local_rank: Only rank 0 checkpoints and validates.
        current_step: Global step count at the start of this epoch.
        val_loader: Loader passed to ``evaluate``.
        batch_config: Forwarded to ``get_batch``.
        val_dataset: Forwarded to ``compute_f1_score_from_result``.
        obj_score_model: Optional objectness-score model (see above).
        opt_obj_score: Optimizer for ``obj_score_model``.
        sched_obj_score: Optional scheduler for ``obj_score_model``.
        train_loader_aug: Optional loader of augmented events, consumed in
            lockstep with ``train_loader``.

    Returns:
        The updated global step count, or the string ``"quit_training"`` once
        ``args.num_steps`` (if not ``-1``) has been reached.
    """
    if obj_score_model is None:
        model.train()
    else:
        obj_score_model.train()
    step_count = current_step
    if train_loader_aug is not None:
        train_loader_aug = iter(train_loader_aug)

    for event_batch in tqdm.tqdm(train_loader):
        time_preprocess_start = time.time()
        y = gt_func(event_batch)
        batch, y = get_batch(event_batch, batch_config, y)
        if train_loader_aug is not None:
            event_batch_aug = next(train_loader_aug)
            assert event_batch_aug.pfcands.original_particle_mapping.max() < len(event_batch.pfcands), (
                f"The original particle mapping out of bounds: "
                f"{event_batch_aug.pfcands.original_particle_mapping.max()} >= {len(event_batch.pfcands)}"
            )
            if len(batch.dropped_batches):
                print("Dropped batches:", batch.dropped_batches, " - skipping this iteration")
                # Quicker this than to implement all the indexing complications from dropped batches
                continue
            y_aug = gt_func(event_batch_aug)
            batch_aug, y_aug = get_batch(event_batch_aug, batch_config, y_aug)
        time_preprocess_end = time.time()

        step_count += 1
        y = y.to(dev)
        opt.zero_grad()
        if obj_score_model is not None:
            opt_obj_score.zero_grad()
        # NOTE(review): anomaly detection is re-enabled on every step; it is a
        # debugging aid and slows training down - confirm it is still wanted.
        torch.autograd.set_detect_anomaly(True)
        with torch.cuda.amp.autocast(enabled=grad_scaler is not None):
            batch.to(dev)
            if train_loader_aug is not None:
                batch_aug.to(dev)
            model_forward_time_start = time.time()
            if obj_score_model is not None:
                with torch.no_grad():
                    y_pred = model(batch)  # Only train the objectness score model
            else:
                y_pred = model(batch)
            if train_loader_aug is not None:
                y_pred_aug = model(batch_aug)
            model_forward_time_end = time.time()

            loss, loss_dict = loss_func(batch, y_pred, y)
            if train_loader_aug is not None:
                loss_aug = loss_func_aug(y_pred, y_pred_aug, batch, batch_aug, event_batch, event_batch_aug)
                loss += loss_aug * _IRC_LOSS_WEIGHT
                loss_dict["loss_IRC"] = loss_aug
            loss_time_end = time.time()
            wandb.log(
                {
                    "time_preprocess": time_preprocess_end - time_preprocess_start,
                    "time_model_forward": model_forward_time_end - model_forward_time_start,
                    "time_loss": loss_time_end - model_forward_time_end,
                },
                step=step_count,
            )

            if obj_score_model is not None:
                # Compute the objectness score
                coords = y_pred[:, 1:4]  # TODO: update this to match the model architecture, as it's written here it's only suitable for L-GATr
                _, clusters, event_idx_clusters = get_clustering_labels(
                    coords.detach().cpu().numpy(),
                    batch.batch_idx.detach().cpu().numpy(),
                    min_cluster_size=args.min_cluster_size,
                    min_samples=args.min_samples,
                    epsilon=args.epsilon,
                    return_labels_event_idx=True,
                )
                input_pxyz = event_batch.pfcands.pxyz[batch.filter.cpu()]
                # Cluster labels are shifted by +1 and row 0 is dropped afterwards,
                # so whatever carries label -1 does not contribute to any cluster.
                clusters_pxyz = scatter_sum(input_pxyz, torch.tensor(clusters) + 1, dim=0)[1:]
                clusters_eta, clusters_phi = calc_eta_phi(clusters_pxyz, return_stacked=False)
                clusters_pt = torch.norm(clusters_pxyz[:, :2], dim=-1)
                pt_mask = clusters_pt >= _MIN_CLUSTER_PT  # Don't train on the clusters that eventually get cut off
                batch_corr = get_corrected_batch(batch, clusters, test=False)
                if not args.global_features_obj_score:
                    objectness_score = obj_score_model(batch_corr)[pt_mask].flatten()
                else:
                    objectness_score = obj_score_model(batch_corr, batch, clusters)[pt_mask].flatten()
                target_obj_score = get_target_obj_score(
                    clusters_eta[pt_mask],
                    clusters_phi[pt_mask],
                    clusters_pt[pt_mask],
                    torch.tensor(event_idx_clusters)[pt_mask],
                    y.dq_eta,
                    y.dq_phi,
                    y.dq_coords_batch_idx,
                    gt_mode=args.objectness_score_gt_mode,
                )
                # Set per-element weights for the loss according to the class imbalance.
                n_positive, n_negative = target_obj_score.sum(), (1 - target_obj_score).sum()
                n_all = n_positive + n_negative
                pos_weight = n_all / n_positive if n_positive > 0 else 0
                neg_weight = n_all / n_negative if n_negative > 0 else 0
                weights = torch.where(target_obj_score == 1, pos_weight, neg_weight)
                print("N positive:", n_positive.item(), "N negative:", n_negative.item())
                print("First 20 predictions:", objectness_score[:20], "First 20 targets:", target_obj_score[:20])
                objectness_score = objectness_score.clamp(min=-10, max=10)
                target_obj_score = target_obj_score.to(objectness_score.device)
                weights = weights.to(objectness_score.device)
                loss_obj_score = torch.nn.BCEWithLogitsLoss(weight=weights)(objectness_score, target_obj_score)
                # In this mode only the objectness loss is back-propagated.
                loss = loss_obj_score
                loss_dict["loss_obj_score"] = loss_obj_score

        # Step whichever model is being trained in this mode.
        active_opt = opt if obj_score_model is None else opt_obj_score
        if grad_scaler is None:
            loss.backward()
            active_opt.step()
        else:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(active_opt)
            grad_scaler.update()

        loss = loss.item()
        wandb.log({key: value.detach().cpu().item() for key, value in loss_dict.items()}, step=step_count)
        wandb.log({"loss": loss}, step=step_count)
        del loss_dict
        del loss

        if (local_rank == 0) and (step_count % args.validation_steps) == 0:
            dirname = args.run_path
            if obj_score_model is None:
                state_dict = {
                    "model": _unwrapped_state_dict(model),
                    "optimizer": opt.state_dict(),
                    "scheduler": scheduler.state_dict(),
                }
                path = os.path.join(dirname, "step_%d_epoch_%d.ckpt" % (step_count, epoch))
            else:
                # The wrapper check must look at obj_score_model itself, not at `model`.
                sched_sd = {}
                if sched_obj_score is not None:
                    sched_sd = sched_obj_score.state_dict()
                state_dict = {
                    "model": _unwrapped_state_dict(obj_score_model),
                    "optimizer": opt_obj_score.state_dict(),
                    "scheduler": sched_sd,
                }
                path = os.path.join(dirname, "OS_step_%d_epoch_%d.ckpt" % (step_count, epoch))
            torch.save(state_dict, path)

            res = evaluate(
                model,
                val_loader,
                dev,
                epoch,
                step_count,
                loss_func=loss_func,
                gt_func=gt_func,
                local_rank=local_rank,
                args=args,
                batch_config=batch_config,
                predict=False,
                model_obj_score=obj_score_model,
            )
            if obj_score_model is not None:
                res, res_obj_score, res_obj_score1 = res
            # TODO: use the obj score here for quick evaluation
            f1 = compute_f1_score_from_result(res, val_dataset)
            wandb.log({"val_f1_score": f1}, step=step_count)
            # evaluate() puts `model` into eval mode; switch it back before
            # training continues, otherwise the rest of the epoch runs in eval mode.
            if obj_score_model is None:
                model.train()

        if args.num_steps != -1 and step_count >= args.num_steps:
            print("Quitting training as the required number of steps has been reached.")
            return "quit_training"

    return step_count


def evaluate(
    model,
    eval_loader,
    dev,
    epoch,
    step,
    loss_func,
    gt_func,
    local_rank=0,
    args=None,
    batch_config=None,
    predict=False,
    model_obj_score=None,  # if not None, it will compute the objectness score of each cluster using the proposed method
):
    """Run ``model`` over ``eval_loader`` and collect per-particle predictions.

    ``model`` is put into eval mode and left there; callers that keep training
    afterwards must switch it back. When ``predict`` is false the loss is also
    computed and, on ``local_rank == 0``, logged to wandb under ``val_*`` keys.

    Args:
        model: The clustering model.
        eval_loader: Iterable of event batches (each exposing ``n_events``).
        dev: Device the batches and targets are moved to.
        epoch: Epoch index, used in the plot folder name.
        step: Global step, used for wandb logging and the plot folder name.
        loss_func: Called as ``loss_func(batch, y_pred, y)`` when not predicting.
        gt_func: Called on each event batch to build the ground truth.
        local_rank: Only rank 0 logs to wandb.
        args: Run configuration. Although it defaults to ``None`` it is
            dereferenced unconditionally, so it must be supplied.
        batch_config: Forwarded to ``get_batch``.
        predict: If true, skip loss computation and wandb logging.
        model_obj_score: Optional objectness-score model evaluated on the
            clusters found in each batch.

    Returns:
        A dict mapping each prediction key to a tensor concatenated over all
        batches. If ``model_obj_score`` is given, a tuple
        ``(predictions, objectness_scores, objectness_targets)`` instead.
    """
    model.eval()
    count = 0
    start_time = time.time()
    total_loss = 0
    total_loss_dict = {}
    plot_batches = [0, 1]
    n_batches = 0
    # Predictions are collected on the validation set as well as in predict mode.
    predictions = {
        "event_idx": [],
        "GT_cluster": [],
        "pred": [],
        "eta": [],
        "phi": [],
        "pt": [],
        "mass": [],
        "AK8_cluster": [],
        "model_cluster": [],
    }
    if model_obj_score is not None:
        obj_score_predictions = []
        obj_score_targets = []
        predictions["event_clusters_idx"] = []
    if args.beta_type != "pt+bc":
        # "BC_score" is not among the keys created above, so a plain `del`
        # raises KeyError; pop() tolerates the missing key.
        predictions.pop("BC_score", None)
    last_event_idx = 0

    with torch.no_grad():
        with tqdm.tqdm(eval_loader) as tq:
            for event_batch in tq:
                count += event_batch.n_events  # number of samples
                y = gt_func(event_batch)
                batch, y = get_batch(event_batch, batch_config, y, test=predict)
                pfcands = event_batch.pfcands
                if args.parton_level:
                    pfcands = event_batch.final_parton_level_particles
                elif args.gen_level:
                    pfcands = event_batch.final_gen_particles
                y = y.to(dev)
                batch = batch.to(dev)
                y_pred = model(batch)

                if not predict:
                    loss, loss_dict = loss_func(batch, y_pred, y)
                    loss = loss.item()
                    total_loss += loss
                    for key in loss_dict:
                        if key not in total_loss_dict:
                            total_loss_dict[key] = 0
                        total_loss_dict[key] += loss_dict[key].item()
                    del loss_dict

                if n_batches in plot_batches and not predict:
                    # don't plot these for prediction - they are useful in training
                    plot_folder = os.path.join(
                        args.run_path, "eval_plots", "epoch_" + str(epoch) + "_step_" + str(step)
                    )
                    Path(plot_folder).mkdir(parents=True, exist_ok=True)
                    # The plotting call that consumed `label_true` is currently disabled.
                    if args.loss == "quark_distance":
                        label_true = y.labels_no_renumber.detach().cpu()
                    elif args.train_objectness_score:
                        label_true = y.labels.detach().cpu()
                    else:
                        label_true = y.detach().cpu()

                n_batches += 1
                if not predict:
                    tq.set_postfix(
                        {
                            "Loss": "%.5f" % loss,
                            "AvgLoss": "%.5f" % (total_loss / n_batches),
                        }
                    )

                # Offset the per-batch event index so it is unique across batches.
                event_idx = batch.batch_idx + last_event_idx
                predictions["event_idx"].append(event_idx)
                if model_obj_score is None:
                    predictions["GT_cluster"].append(y.detach().cpu())
                else:
                    predictions["GT_cluster"].append(y.labels.detach().cpu())
                predictions["pred"].append(y_pred.detach().cpu())
                predictions["eta"].append(pfcands.eta.detach().cpu())
                predictions["phi"].append(pfcands.phi.detach().cpu())
                predictions["pt"].append(pfcands.pt.detach().cpu())
                predictions["AK8_cluster"].append(event_batch.pfcands.pf_cand_jet_idx.detach().cpu())
                predictions["mass"].append(pfcands.mass.detach().cpu())

                # The clustering coordinates sit in different columns depending
                # on the width of the model output.
                latest_pred = predictions["pred"][-1]
                if latest_pred.shape[1] == 4:
                    coords = latest_pred[:, :3]
                else:
                    coords = latest_pred[:, 1:4]
                clustering_labels = torch.tensor(
                    get_clustering_labels(
                        coords.detach().cpu().numpy(),
                        event_idx.detach().cpu().numpy(),
                        min_cluster_size=args.min_cluster_size,
                        min_samples=args.min_samples,
                        epsilon=args.epsilon,
                        return_labels_event_idx=False,
                    )
                )

                if model_obj_score is not None:
                    _, clusters, event_idx_clusters = get_clustering_labels(
                        coords.detach().cpu().numpy(),
                        batch.batch_idx.detach().cpu().numpy(),
                        min_cluster_size=args.min_cluster_size,
                        min_samples=args.min_samples,
                        epsilon=args.epsilon,
                        return_labels_event_idx=True,
                    )
                    assert len(event_idx_clusters) == clusters.max() + 1
                    batch_corr = get_corrected_batch(batch, clusters, test=predict)
                    input_pxyz = pfcands.pxyz[batch.filter.cpu()]
                    # Labels are shifted by +1 and row 0 is dropped afterwards.
                    clusters_pxyz = scatter_sum(input_pxyz, torch.tensor(clusters) + 1, dim=0)[1:]
                    clusters_eta, clusters_phi = calc_eta_phi(clusters_pxyz, return_stacked=False)
                    clusters_pt = torch.norm(clusters_pxyz[:, :2], dim=-1)
                    pt_mask = clusters_pt >= _MIN_CLUSTER_PT  # Don't train on the clusters that eventually get cut off
                    if not args.global_features_obj_score:
                        objectness_score = model_obj_score(batch_corr)
                    else:
                        objectness_score = model_obj_score(batch_corr, batch, clusters)
                    # NOTE(review): scores are stored for all clusters while the
                    # targets below only cover clusters passing the pt cut, so the
                    # two returned tensors can differ in length - confirm intended.
                    obj_score_predictions.append(objectness_score.detach().cpu())
                    target_obj_score = get_target_obj_score(
                        clusters_eta[pt_mask],
                        clusters_phi[pt_mask],
                        clusters_pt[pt_mask],
                        torch.tensor(event_idx_clusters)[pt_mask],
                        y.dq_eta,
                        y.dq_phi,
                        y.dq_coords_batch_idx,
                        gt_mode=args.objectness_score_gt_mode,
                    )
                    # Set per-element weights for the loss according to the class imbalance.
                    n_positive, n_negative = target_obj_score.sum(), (1 - target_obj_score.float()).sum()
                    n_all = n_positive + n_negative
                    pos_weight = n_all / n_positive if n_positive > 0 else 0
                    neg_weight = n_all / n_negative if n_negative > 0 else 0
                    weights = torch.where(target_obj_score == 1, pos_weight, neg_weight)
                    print("N positive (eval):", n_positive.item(), "N negative (eval):", n_negative.item())
                    print(
                        "First 20 predictions (eval):",
                        objectness_score[:20],
                        "First 20 targets (eval):",
                        target_obj_score[:20],
                    )
                    objectness_score = objectness_score.clamp(min=-10, max=10)
                    target_obj_score = target_obj_score.to(objectness_score.device)
                    weights = weights.to(objectness_score.device)
                    pt_mask = pt_mask.to(objectness_score.device)
                    loss_obj_score = torch.nn.BCEWithLogitsLoss(weight=weights)(
                        objectness_score.flatten()[pt_mask], target_obj_score.flatten()
                    ).cpu().item()
                    obj_score_targets.append(target_obj_score)
                    # NOTE(review): the "val_" prefix is added again when logging
                    # below, so this shows up as "val_val_loss_obj_score".
                    k = "val_loss_obj_score"
                    if k not in total_loss_dict:
                        total_loss_dict[k] = 0
                    total_loss_dict[k] += loss_obj_score
                    predictions["event_clusters_idx"].append(torch.tensor(event_idx_clusters) + last_event_idx)

                # clustering_labels is already a tensor; re-wrapping it in
                # torch.tensor() only triggers a copy-construct warning.
                predictions["model_cluster"].append(clustering_labels)
                last_event_idx = count
                if event_idx.max().item() + 1 != last_event_idx:
                    print(f"event_idx.max() = {event_idx.max().item()}, last_event_idx = {last_event_idx} - the eval would have failed here before the update")

    if local_rank == 0 and not predict:
        wandb.log({"val_loss": total_loss / n_batches}, step=step)
        wandb.log({"val_" + key: value / n_batches for key, value in total_loss_dict.items()}, step=step)
    time_diff = time.time() - start_time
    _logger.info(
        "Evaluated on %d samples in total (avg. speed %.1f samples/s)"
        % (count, count / time_diff)
    )

    predictions = {key: torch.cat(chunks, dim=0) for key, chunks in predictions.items()}
    if model_obj_score is not None:
        return predictions, torch.cat(obj_score_predictions), torch.cat(obj_score_targets)
    return predictions