            # alternate between optimising the beta loss and the clustering loss
            # every `alternate_steps` training steps
            if alternate_steps is not None:
                if not hasattr(model.mod, "current_state_alternate_steps"):
                    model.mod.current_state_alternate_steps = 0
                if step_count % alternate_steps == 0:
                    print("Flipping steps")
                    state = model.mod.current_state_alternate_steps
                    state = 1 - state
                    model.mod.current_state_alternate_steps = state
                    wandb.log(
                        {
                            "current_state_alternate_steps": model.mod.current_state_alternate_steps
                        }
                    )
                    if state == 0:
                        print("Switched to beta loss")
                        model.mod.beta_weight = 1.0  # set to 0.0 for no beta loss (when it's frozen)
                        model.mod.beta_exp_weight = 1.0
                        model.mod.attr_rep_weight = 0.0
                    else:
                        print("Switched to clustering loss")
                        model.mod.beta_weight = 0.0  # beta loss frozen while optimising clustering
                        model.mod.beta_exp_weight = 0.0
                        model.mod.attr_rep_weight = 1.0

            # if clust_loss_only and calc_e_frac_loss and logwandb and local_rank == 0:
            #     wandb.log(
            #         {
            #             "loss e frac": loss_E_frac,
            #             "loss e frac true": loss_E_frac_true,
            #         }
            #     )

            if tb_helper:
                print("tb_helper!", tb_helper)
                tb_helper.write_scalars(
                    [
                        ("Loss/train", loss, tb_helper.batch_train_count + num_batches),
                    ]
                )
                if tb_helper.custom_fn:
                    with torch.no_grad():
                        tb_helper.custom_fn(
                            model_output=model_output,
                            model=model,
                            epoch=epoch,
                            i_batch=num_batches,
                            mode="train",
                        )

            # fig, ax = plt.subplots()
            # repulsive, attractive = (
            #     lst_nonzero(losses[16].detach().cpu().flatten()),
            #     lst_nonzero(losses[17].detach().cpu().flatten()),
            # )
            # ax.hist(
            #     repulsive.view(-1),
            #     bins=100,
            #     alpha=0.5,
            #     label="repulsive",
            #     color="r",
            # )
            # ax.hist(
            #     attractive.view(-1),
            #     bins=100,
            #     alpha=0.5,
            #     label="attractive",
            #     color="b",
            # )
            # ax.set_yscale("log")
            # ax.legend()
            # wandb.log({"rep. and att. norms": wandb.Image(fig)})
            # plt.close(fig)

    if tb_helper:
        tb_helper.write_scalars(
            [
                ("Loss/train (epoch)", total_loss / num_batches, epoch),
                ("MSE/train (epoch)", sum_sqr_err / count, epoch),
                ("MAE/train (epoch)", sum_abs_err / count, epoch),
            ]
        )
        if tb_helper.custom_fn:
            with torch.no_grad():
                tb_helper.custom_fn(
                    model_output=model_output,
                    model=model,
                    epoch=epoch,
                    i_batch=-1,
                    mode="train",
                )
        # update the batch state
        tb_helper.batch_train_count += num_batches


def inference_statistics(
    model,
    train_loader,
    dev,
    grad_scaler=None,
    loss_terms=[],
    args=None,
    radius=0.7,
    total_num_batches=10,
    save_ckpt_to_folder=None,
):
    model.eval()
    clust_loss_only = loss_terms[0]
    add_energy_loss = loss_terms[1]
    num_batches = 0
    loss_E_fracs = []
    loss_E_fracs_true = []
    loss_E_fracs_true_nopart = []
    loss_E_fracs_nopart = []
    part_E_true = []
    part_PID_true = []
    betas_list = []
    figs = []
    reco_counts, non_reco_counts = {}, {}
    total_counts = {}
    with tqdm.tqdm(train_loader) as tq:
        for batch_g, y in tq:
            with torch.cuda.amp.autocast(enabled=grad_scaler is not None):
                batch_g = batch_g.to(dev)
                if args.loss_regularization:
                    model_output, loss_regularizing_neig, loss_ll = model(batch_g)
                else:
                    model_output = model(batch_g, 1)
                preds = model_output.squeeze()
                (
                    loss,
                    losses,
                    loss_E_frac,
                    loss_E_frac_true,
                    loss_E_frac_nopart,
                    loss_E_frac_true_nopart,
                ) = object_condensation_loss2(
                    batch_g,
                    model_output,
                    y,
                    clust_loss_only=clust_loss_only,
                    add_energy_loss=add_energy_loss,
                    calc_e_frac_loss=True,
                    e_frac_loss_return_particles=True,
                    q_min=args.qmin,
                    frac_clustering_loss=args.frac_cluster_loss,
                    attr_weight=args.L_attractive_weight,
                    repul_weight=args.L_repulsive_weight,
                    fill_loss_weight=args.fill_loss_weight,
                    use_average_cc_pos=args.use_average_cc_pos,
                    hgcalloss=args.hgcalloss,
                    e_frac_loss_radius=radius,
                )
                (
                    loss_E_frac_true,
                    particle_ids_all,
                    reco_count,
                    non_reco_count,
                    total_count,
                ) = loss_E_frac_true
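                # the per-batch count dicts (computed without the particle-matching
                # variant) are accumulated into reco_counts, non_reco_counts and
                # total_counts below via update_dict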
                (
                    loss_E_frac_true_nopart,
                    particle_ids_all_nopart,
                    reco_count_nopart,
                    non_reco_count_nopart,
                    total_count_nopart,
                ) = loss_E_frac_true_nopart
                update_dict(reco_counts, reco_count_nopart)
                update_dict(total_counts, total_count_nopart)
                if len(reco_count):
                    assert len(reco_counts) >= len(reco_count_nopart)
                update_dict(non_reco_counts, non_reco_count_nopart)
                loss_E_fracs.append([x.cpu() for x in loss_E_frac])
                loss_E_fracs_true.append([x.cpu() for x in loss_E_frac_true])
                loss_E_fracs_true_nopart.append(
                    [x.cpu() for x in loss_E_frac_true_nopart]
                )
                loss_E_fracs_nopart.append([x.cpu() for x in loss_E_frac_nopart])
                part_PID_true.append(
                    [
                        y[torch.tensor(pidall) - 1, 6].long()
                        for pidall in particle_ids_all
                    ]
                )
                part_E_true.append(
                    [y[torch.tensor(pidall) - 1, 3] for pidall in particle_ids_all]
                )
                if clust_loss_only:
                    clust_space_dim = 3
                else:
                    clust_space_dim = model.mod.output_dim - 28
                xj = model_output[:, 0:clust_space_dim]
                # if model.mod.clust_space_norm == "twonorm":
                #     xj = torch.nn.functional.normalize(xj, dim=1)
                # elif model.mod.clust_space_norm == "tanh":
                #     xj = torch.tanh(xj)
                # elif model.mod.clust_space_norm == "none":
                #     pass
                bj = torch.sigmoid(
                    torch.reshape(model_output[:, clust_space_dim], [-1, 1])
                )  # 3: betas
                bj = bj.clip(0.0, 1 - 1e-4)
                q = bj.arctanh() ** 2 + args.qmin
                fig, ax = plot_clust(
                    batch_g,
                    q,
                    xj,
                    y=y,
                    radius=radius,
                    loss_e_frac=loss_E_fracs[-1],
                    betas=bj,
                )
                betas = (
                    torch.sigmoid(
                        torch.reshape(preds[:, args.clustering_space_dim], [-1, 1])
                    )
                    .detach()
                    .cpu()
                    .numpy()
                )
                # figs.append(fig)
                betas_list.append(betas)
            num_batches += 1
            if num_batches % 5 == 0 and save_ckpt_to_folder is not None:
                Path(save_ckpt_to_folder).mkdir(parents=True, exist_ok=True)
                loss_E_fracs_fold = [
                    item for sublist in loss_E_fracs for item in sublist
                ]
                loss_E_fracs_fold = torch.concat(loss_E_fracs_fold).flatten()
                loss_E_fracs_true_fold = [
                    item for sublist in loss_E_fracs_true for item in sublist
                ]
                loss_E_fracs_true_fold = torch.concat(loss_E_fracs_true_fold).flatten()
                part_E_true_fold = [item for sublist in part_E_true for item in sublist]
                part_E_true_fold = torch.concat(part_E_true_fold).flatten()
                part_PID_true_fold = [
                    item for sublist in part_PID_true for item in sublist
                ]
                part_PID_true_fold = torch.concat(part_PID_true_fold).flatten()
                loss_E_fracs_nopart_fold = [
                    item for sublist in loss_E_fracs_nopart for item in sublist
                ]
                loss_E_fracs_true_nopart_fold = [
                    item for sublist in loss_E_fracs_true_nopart for item in sublist
                ]
                obj = {
                    "loss_e_fracs_nopart": loss_E_fracs_nopart_fold,
                    "loss_e_fracs_true_nopart": loss_E_fracs_true_nopart_fold,
                    "loss_e_fracs": loss_E_fracs_fold,
                    "loss_e_fracs_true": loss_E_fracs_true_fold,
                    "part_E_true": part_E_true_fold,
                    "part_PID_true": part_PID_true_fold,
                    "reco_counts": reco_counts,
                    "non_reco_counts": non_reco_counts,
                    "total_counts": total_counts,
                }
                file_to_save = os.path.join(save_ckpt_to_folder, "temp_ckpt.pkl")
                with open(file_to_save, "wb") as f:
                    pickle.dump(obj, f)
            if num_batches >= total_num_batches:
                break
    # flatten the lists
    if save_ckpt_to_folder is not None:
        return
    loss_E_fracs = [item for sublist in loss_E_fracs for item in sublist]
    loss_E_fracs = torch.concat(loss_E_fracs).flatten()
    loss_E_fracs_true = [item for sublist in loss_E_fracs_true for item in sublist]
    loss_E_fracs_true = torch.concat(loss_E_fracs_true).flatten()
    part_E_true = [item for sublist in part_E_true for item in sublist]
    part_E_true = torch.concat(part_E_true).flatten()
    part_PID_true = [item for sublist in part_PID_true for item in sublist]
    part_PID_true = torch.concat(part_PID_true).flatten()
    loss_E_fracs_nopart = [
        item for sublist in loss_E_fracs_nopart for item in sublist
    ]
    loss_E_fracs_true_nopart = [
        item for sublist in loss_E_fracs_true_nopart for item in sublist
    ]
    return {
        "loss_e_fracs": loss_E_fracs,
        "loss_e_fracs_true": loss_E_fracs_true,
        "loss_e_fracs_nopart": loss_E_fracs_nopart,
        "loss_e_fracs_true_nopart": loss_E_fracs_true_nopart,
        "betas": betas_list,
        "part_E_true": part_E_true,
        "part_PID_true": part_PID_true,
        "reco_counts": reco_counts,
        "non_reco_counts": non_reco_counts,
        "total_counts": total_counts,
    }


def inference(model, test_loader, dev):
    """
    Similar to evaluate_regression, but without the ground truth labels.
    """
    model.eval()
    num_batches = 0
    count = 0
    results = []
    start_time = time.time()
    with torch.no_grad():
        with tqdm.tqdm(test_loader) as tq:
            for batch_g, _ in tq:
                batch_g = batch_g.to(dev)
                model_output = model(batch_g)
                # preds = model_output.squeeze().float()
                preds = model.mod.object_condensation_inference(batch_g, model_output)
                num_batches += 1
                count += batch_g.batch_size  # events per batch (assumes a batched DGL graph)
                results.append(preds)
    time_diff = time.time() - start_time
    _logger.info(
        "Processed %d entries in total (avg. speed %.1f entries/s)"
        % (count, count / time_diff)
    )
    return results


#! create output graph with shower id ndata and store it for each event
# if args.store_output:
#     print("calculating clustering and matching showers")
#     if step == 0 and local_rank == 0:
#         create_and_store_graph_output(
#             batch_g,
#             model_output,
#             y,
#             local_rank,
#             step,
#             epoch,
#             path_save=args.model_prefix + "/showers_df",
#             store=True,
#         )
# losses_cpu = [
#     x.detach().to("cpu") if isinstance(x, torch.Tensor) else x
#     for x in losses
# ]
# all_val_losses.append(losses_cpu)
# all_val_loss.append(loss.detach().to("cpu").item())
# pid_true, pid_pred = torch.cat(
#     [torch.tensor(x[7]) for x in all_val_losses]
# ), torch.cat([torch.tensor(x[8]) for x in all_val_losses])
# pid_true, pid_pred = pid_true.tolist(), pid_pred.tolist()
# , step=step)
# if clust_loss_only and calc_e_frac_loss:
#     wandb.log(
#         {
#             "loss e frac val": loss_E_frac,
#             "loss e frac true val": loss_E_frac_true,
#         }
#     )
# ks = sorted(list(all_val_losses[0][9].keys()))
# concatenated = {}
# for key in ks:
#     concatenated[key] = np.concatenate([x[9][key] for x in all_val_losses])
# tables = {}
# for key in ks:
#     tables[key] = concatenated[key]
#     # wandb.Table(data=[[x] for x in concatenated[key]], columns=[key])
# wandb.log(
#     {
#         "val " + key: wandb.Histogram(clip_list(tables[key]), num_bins=100)
#         for key in ks
#     }
# )  # , step=step)
# scores = np.concatenate(scores)
# labels = {k: _concat(v) for k, v in labels.items()}
# metric_results = evaluate_metrics(labels[data_config.label_names[0]], scores, eval_metrics=eval_metrics)
# _logger.info('Evaluation metrics: \n%s', '\n'.join(
#     [' - %s: \n%s' % (k, str(v)) for k, v in metric_results.items()]))


def plot_regression_resolution(model, test_loader, dev, **kwargs):
    # note: `args` is expected to be available in the enclosing module scope
    model.eval()
    results = []  # resolution results
    pid_classification_results = []
    with torch.no_grad():
        with tqdm.tqdm(test_loader) as tq:
            for batch_g, y in tq:
                batch_g = batch_g.to(dev)
                if args.loss_regularization:
                    model_output, loss_regularizing_neig = model(batch_g)
                else:
                    model_output = model(batch_g)
                resolutions, pid_true, pid_pred = model.mod.object_condensation_loss2(
                    batch_g,
                    model_output,
                    y,
                    return_resolution=True,
                    q_min=args.qmin,
                    frac_clustering_loss=0,
                    use_average_cc_pos=args.use_average_cc_pos,
                    hgcalloss=args.hgcalloss,
                )
                results.append(resolutions)
                pid_classification_results.append((pid_true, pid_pred))
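    # aggregate the per-batch resolution dictionaries into one flat array per key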
    result_dict = {}
    for key in results[0]:
        result_dict[key] = np.concatenate([r[key] for r in results])
    result_dict["event_by_event_accuracy"] = [
        accuracy_score(pid_true.argmax(dim=0), pid_pred.argmax(dim=0))
        for pid_true, pid_pred in pid_classification_results
    ]
    # just plot all for now
    result = {}
    for key in results[0]:
        data = result_dict[key]
        fig, ax = plt.subplots()
        ax.hist(data, bins=100, range=(-1.5, 1.5), histtype="step", label=key)
        ax.set_xlabel("resolution")
        ax.set_ylabel("count")
        ax.legend()
        result[key] = fig
    # confusion matrix; note that pid_true / pid_pred here refer to the last batch only
    conf_mat = confusion_matrix(pid_true.argmax(dim=0), pid_pred.argmax(dim=0))
    fig, ax = plt.subplots(figsize=(7.5, 7.5))
    # use onehot_particles_arr as class names
    class_names = onehot_particles_arr
    im = ax.matshow(conf_mat, cmap=plt.cm.Blues)
    ax.set_xticks(np.arange(len(class_names)), class_names, rotation=45)
    ax.set_yticks(np.arange(len(class_names)), class_names)
    result["PID_confusion_matrix"] = fig
    return result


# if args.loss_regularization:
#     wandb.log({"loss regul neigh": loss_regularizing_neig})
#     wandb.log({"loss ll": loss_ll})
# if (num_batches - 1) % 100 == 0:
#     if clust_loss_only:
#         clust_space_dim = 3  # model.mod.output_dim - 1
#     else:
#         clust_space_dim = model.mod.output_dim - 28
#     bj = torch.sigmoid(
#         torch.reshape(model_output[:, clust_space_dim], [-1, 1])
#     )  # 3: betas
#     xj = model_output[:, 0:clust_space_dim]  # xj: cluster space coords
#     # assert len(bj) == len(xj)
#     # if model.mod.clust_space_norm == "twonorm":
#     #     xj = torch.nn.functional.normalize(
#     #         xj, dim=1
#     #     )  # 0, 1, 2: cluster space coords
#     # elif model.mod.clust_space_norm == "tanh":
#     #     xj = torch.tanh(xj)
#     # elif model.mod.clust_space_norm == "none":
#     #     pass
#     bj = bj.clip(0.0, 1 - 1e-4)
#     q = bj.arctanh() ** 2 + args.qmin
#     assert q.shape[0] == xj.shape[0]
#     assert batch_g.ndata["h"].shape[0] == xj.shape[0]
#     fig, ax = plot_clust(
#         batch_g,
#         q,
#         xj,
#         title_prefix="train ep. {}, batch {}".format(epoch, num_batches),
#         y=y,
#         betas=bj,
#     )
#     wandb.log({"clust": wandb.Image(fig)})
#     fig.clf()
#     # if (num_batches - 1) % 500 == 0:
#     #     wandb.log(
#     #         {
#     #             "conf_mat_train": wandb.plot.confusion_matrix(
#     #                 y_true=pid_true,
#     #                 preds=pid_pred,
#     #                 class_names=class_names,
#     #             )
#     #         }
#     #     )
# ks = sorted(list(losses[9].keys()))
# losses_cpu = [
#     x.detach().to("cpu") if isinstance(x, torch.Tensor) else x
#     for x in losses
# ]
# tables = {}
# for key in ks:
#     tables[key] = losses[9][key]
#     # wandb.Table(data=[[x] for x in losses[9][key]], columns=[key])
# if local_rank == 0:
#     wandb.log(
#         {
#             key: wandb.Histogram(clip_list(tables[key]), num_bins=100)
#             for key, val in losses_cpu[9].items()
#         }
#     )  # , step=step_count)
# return loss_epoch_total, losses_epoch_total
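
# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original pipeline): one way to load the
# intermediate checkpoint that `inference_statistics` writes to
# <save_ckpt_to_folder>/temp_ckpt.pkl and histogram the stored energy-fraction
# losses. The dictionary keys match the `obj` saved above; the helper name,
# binning and axis labels are arbitrary choices made for this example only.
def _example_load_inference_ckpt(folder):
    import os
    import pickle

    import matplotlib.pyplot as plt

    with open(os.path.join(folder, "temp_ckpt.pkl"), "rb") as f:
        obj = pickle.load(f)
    # obj["loss_e_fracs"] is a flat CPU tensor of per-shower energy-fraction losses
    fig, ax = plt.subplots()
    ax.hist(
        obj["loss_e_fracs"].numpy(), bins=100, histtype="step", label="e frac loss"
    )
    ax.set_xlabel("energy-fraction loss")
    ax.set_ylabel("count")
    ax.legend()
    return fig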