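# NOTE: the block below appears to be a fragment of the training loop. When
# `alternate_steps` is set, training flips every `alternate_steps` optimizer steps
# between a "beta" phase (beta_weight / beta_exp_weight active, attractive-repulsive
# term off) and a "clustering" phase (attractive-repulsive term active, beta terms
# off) by toggling the loss-weight attributes on `model.mod`.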
if alternate_steps is not None:
    if not hasattr(model.mod, "current_state_alternate_steps"):
        model.mod.current_state_alternate_steps = 0
if alternate_steps is not None and step_count % alternate_steps == 0:
    print("Flipping steps")
    state = model.mod.current_state_alternate_steps
    state = 1 - state
    model.mod.current_state_alternate_steps = state
    wandb.log(
        {"current_state_alternate_steps": model.mod.current_state_alternate_steps}
    )
    if state == 0:
        print("Switched to beta loss")
        model.mod.beta_weight = (
            1.0  # set this to zero for no beta loss (when it's frozen)
        )
        model.mod.beta_exp_weight = 1.0
        model.mod.attr_rep_weight = 0.0
    else:
        print("Switched to clustering loss")
        model.mod.beta_weight = (
            0.0  # set this to zero for no beta loss (when it's frozen)
        )
        model.mod.beta_exp_weight = 0.0
        model.mod.attr_rep_weight = 1.0
# if clust_loss_only and calc_e_frac_loss and logwandb and local_rank == 0:
#     wandb.log(
#         {
#             "loss e frac": loss_E_frac,
#             "loss e frac true": loss_E_frac_true,
#         }
#     )
if tb_helper:
    print("tb_helper!", tb_helper)
    tb_helper.write_scalars(
        [
            ("Loss/train", loss, tb_helper.batch_train_count + num_batches),
        ]
    )
    if tb_helper.custom_fn:
        with torch.no_grad():
            tb_helper.custom_fn(
                model_output=model_output,
                model=model,
                epoch=epoch,
                i_batch=num_batches,
                mode="train",
            )
# fig, ax = plt.subplots()
# repulsive, attractive = (
#     lst_nonzero(losses[16].detach().cpu().flatten()),
#     lst_nonzero(losses[17].detach().cpu().flatten()),
# )
# ax.hist(
#     repulsive.view(-1),
#     bins=100,
#     alpha=0.5,
#     label="repulsive",
#     color="r",
# )
# ax.hist(
#     attractive.view(-1),
#     bins=100,
#     alpha=0.5,
#     label="attractive",
#     color="b",
# )
# ax.set_yscale("log")
# ax.legend()
# wandb.log({"rep. and att. norms": wandb.Image(fig)})
# plt.close(fig)
if tb_helper:
    tb_helper.write_scalars(
        [
            ("Loss/train (epoch)", total_loss / num_batches, epoch),
            ("MSE/train (epoch)", sum_sqr_err / count, epoch),
            ("MAE/train (epoch)", sum_abs_err / count, epoch),
        ]
    )
    if tb_helper.custom_fn:
        with torch.no_grad():
            tb_helper.custom_fn(
                model_output=model_output,
                model=model,
                epoch=epoch,
                i_batch=-1,
                mode="train",
            )
    # update the batch state
    tb_helper.batch_train_count += num_batches
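# The two `tb_helper` blocks above log per-batch and per-epoch training scalars (and
# run an optional custom hook); `tb_helper` is presumably the TensorBoard helper
# object provided by the surrounding training framework.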
def inference_statistics(
    model,
    train_loader,
    dev,
    grad_scaler=None,
    loss_terms=[],
    args=None,
    radius=0.7,
    total_num_batches=10,
    save_ckpt_to_folder=None,
):
    """Run the model in eval mode over up to ``total_num_batches`` batches and collect
    object-condensation statistics (energy-fraction losses, true particle energy and
    PID, betas, reconstruction counts). If ``save_ckpt_to_folder`` is given,
    intermediate results are pickled there every 5 batches and the function returns
    None; otherwise a dict of flattened tensors is returned.
    """
    model.eval()
    clust_loss_only = loss_terms[0]
    add_energy_loss = loss_terms[1]
    num_batches = 0
    loss_E_fracs = []
    loss_E_fracs_true = []
    loss_E_fracs_true_nopart = []
    loss_E_fracs_nopart = []
    part_E_true = []
    part_PID_true = []
    betas_list = []
    figs = []
    reco_counts, non_reco_counts = {}, {}
    total_counts = {}
    with tqdm.tqdm(train_loader) as tq:
        for batch_g, y in tq:
            with torch.cuda.amp.autocast(enabled=grad_scaler is not None):
                batch_g = batch_g.to(dev)
                if args.loss_regularization:
                    model_output, loss_regularizing_neig, loss_ll = model(batch_g)
                else:
                    model_output = model(batch_g, 1)
                preds = model_output.squeeze()
                (
                    loss,
                    losses,
                    loss_E_frac,
                    loss_E_frac_true,
                    loss_E_frac_nopart,
                    loss_E_frac_true_nopart,
                ) = object_condensation_loss2(
                    batch_g,
                    model_output,
                    y,
                    clust_loss_only=clust_loss_only,
                    add_energy_loss=add_energy_loss,
                    calc_e_frac_loss=True,
                    e_frac_loss_return_particles=True,
                    q_min=args.qmin,
                    frac_clustering_loss=args.frac_cluster_loss,
                    attr_weight=args.L_attractive_weight,
                    repul_weight=args.L_repulsive_weight,
                    fill_loss_weight=args.fill_loss_weight,
                    use_average_cc_pos=args.use_average_cc_pos,
                    hgcalloss=args.hgcalloss,
                    e_frac_loss_radius=radius,
                )
                (
                    loss_E_frac_true,
                    particle_ids_all,
                    reco_count,
                    non_reco_count,
                    total_count,
                ) = loss_E_frac_true
                (
                    loss_E_frac_true_nopart,
                    particle_ids_all_nopart,
                    reco_count_nopart,
                    non_reco_count_nopart,
                    total_count_nopart,
                ) = loss_E_frac_true_nopart
                update_dict(reco_counts, reco_count_nopart)
                update_dict(total_counts, total_count_nopart)
                if len(reco_count):
                    assert len(reco_counts) >= len(reco_count_nopart)
                update_dict(non_reco_counts, non_reco_count_nopart)
                loss_E_fracs.append([x.cpu() for x in loss_E_frac])
                loss_E_fracs_true.append([x.cpu() for x in loss_E_frac_true])
                loss_E_fracs_true_nopart.append(
                    [x.cpu() for x in loss_E_frac_true_nopart]
                )
                loss_E_fracs_nopart.append([x.cpu() for x in loss_E_frac_nopart])
                part_PID_true.append(
                    [
                        y[torch.tensor(pidall) - 1, 6].long()
                        for pidall in particle_ids_all
                    ]
                )
                part_E_true.append(
                    [y[torch.tensor(pidall) - 1, 3] for pidall in particle_ids_all]
                )
                if clust_loss_only:
                    clust_space_dim = 3
                else:
                    clust_space_dim = model.mod.output_dim - 28
                xj = model_output[:, 0:clust_space_dim]
                # if model.mod.clust_space_norm == "twonorm":
                #     xj = torch.nn.functional.normalize(xj, dim=1)
                # elif model.mod.clust_space_norm == "tanh":
                #     xj = torch.tanh(xj)
                # elif model.mod.clust_space_norm == "none":
                #     pass
                bj = torch.sigmoid(
                    torch.reshape(model_output[:, clust_space_dim], [-1, 1])
                )  # 3: betas
                bj = bj.clip(0.0, 1 - 1e-4)
                q = bj.arctanh() ** 2 + args.qmin
                fig, ax = plot_clust(
                    batch_g,
                    q,
                    xj,
                    y=y,
                    radius=radius,
                    loss_e_frac=loss_E_fracs[-1],
                    betas=bj,
                )
                betas = (
                    torch.sigmoid(
                        torch.reshape(preds[:, args.clustering_space_dim], [-1, 1])
                    )
                    .detach()
                    .cpu()
                    .numpy()
                )
                # figs.append(fig)
                betas_list.append(betas)
            num_batches += 1
            if num_batches % 5 == 0 and save_ckpt_to_folder is not None:
                Path(save_ckpt_to_folder).mkdir(parents=True, exist_ok=True)
                loss_E_fracs_fold = [
                    item for sublist in loss_E_fracs for item in sublist
                ]
                loss_E_fracs_fold = torch.concat(loss_E_fracs_fold).flatten()
                loss_E_fracs_true_fold = [
                    item for sublist in loss_E_fracs_true for item in sublist
                ]
                loss_E_fracs_true_fold = torch.concat(loss_E_fracs_true_fold).flatten()
                part_E_true_fold = [item for sublist in part_E_true for item in sublist]
                part_E_true_fold = torch.concat(part_E_true_fold).flatten()
                part_PID_true_fold = [
                    item for sublist in part_PID_true for item in sublist
                ]
                part_PID_true_fold = torch.concat(part_PID_true_fold).flatten()
                loss_E_fracs_nopart_fold = [
                    item for sublist in loss_E_fracs_nopart for item in sublist
                ]
                loss_E_fracs_true_nopart_fold = [
                    item for sublist in loss_E_fracs_true_nopart for item in sublist
                ]
                obj = {
                    "loss_e_fracs_nopart": loss_E_fracs_nopart_fold,
                    "loss_e_fracs_true_nopart": loss_E_fracs_true_nopart_fold,
                    "loss_e_fracs": loss_E_fracs_fold,
                    "loss_e_fracs_true": loss_E_fracs_true_fold,
                    "part_E_true": part_E_true_fold,
                    "part_PID_true": part_PID_true_fold,
                    "reco_counts": reco_counts,
                    "non_reco_counts": non_reco_counts,
                    "total_counts": total_counts,
                }
                file_to_save = os.path.join(save_ckpt_to_folder, "temp_ckpt" + ".pkl")
                with open(file_to_save, "wb") as f:
                    pickle.dump(obj, f)
            if num_batches >= total_num_batches:
                break
    # flatten the lists
    if save_ckpt_to_folder is not None:
        return
    loss_E_fracs = [item for sublist in loss_E_fracs for item in sublist]
    loss_E_fracs = torch.concat(loss_E_fracs).flatten()
    loss_E_fracs_true = [item for sublist in loss_E_fracs_true for item in sublist]
    loss_E_fracs_true = torch.concat(loss_E_fracs_true).flatten()
    part_E_true = [item for sublist in part_E_true for item in sublist]
    part_E_true = torch.concat(part_E_true).flatten()
    part_PID_true = [item for sublist in part_PID_true for item in sublist]
    part_PID_true = torch.concat(part_PID_true).flatten()
    loss_E_fracs_nopart = [
        item for sublist in loss_E_fracs_nopart for item in sublist
    ]
    loss_E_fracs_true_nopart = [
        item for sublist in loss_E_fracs_true_nopart for item in sublist
    ]
    return {
        "loss_e_fracs": loss_E_fracs,
        "loss_e_fracs_true": loss_E_fracs_true,
        "loss_e_fracs_nopart": loss_E_fracs_nopart,
        "loss_e_fracs_true_nopart": loss_E_fracs_true_nopart,
        "betas": betas_list,
        "part_E_true": part_E_true,
        "part_PID_true": part_PID_true,
        "reco_counts": reco_counts,
        "non_reco_counts": non_reco_counts,
        "total_counts": total_counts,
    }
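# `update_dict` (used in `inference_statistics` above) is defined elsewhere in this
# codebase. The sketch below is an assumption consistent with how it is called
# (accumulating per-key counts across batches), not the actual implementation:
#
# def update_dict(acc, new_counts):
#     """Add the counts in `new_counts` into the running dictionary `acc` in place."""
#     for key, value in new_counts.items():
#         acc[key] = acc.get(key, 0) + value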
def inference(model, test_loader, dev):
    """
    Similar to evaluate_regression, but without the ground truth labels.
    """
    model.eval()
    num_batches = 0
    count = 0
    results = []
    start_time = time.time()
    with torch.no_grad():
        with tqdm.tqdm(test_loader) as tq:
            for batch_g, _ in tq:
                batch_g = batch_g.to(dev)
                model_output = model(batch_g)
                # preds = model_output.squeeze().float()
                preds = model.mod.object_condensation_inference(batch_g, model_output)
                num_batches += 1
                # count processed events; assumes `batch_g` is a DGL batched graph
                # (previously `count` was never incremented, so the log reported 0)
                count += batch_g.batch_size
                results.append(preds)
    time_diff = time.time() - start_time
    _logger.info(
        "Processed %d entries in total (avg. speed %.1f entries/s)"
        % (count, count / time_diff)
    )
    return results
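# Hedged example: the intermediate checkpoints written by `inference_statistics`
# (``temp_ckpt.pkl``) can be inspected offline. `plot_e_frac_from_ckpt` below is an
# illustrative helper, not part of the original pipeline; it relies only on the
# dictionary keys written above.
def plot_e_frac_from_ckpt(ckpt_folder):
    import os
    import pickle

    import matplotlib.pyplot as plt

    # load the pickled statistics dict produced every 5 batches
    with open(os.path.join(ckpt_folder, "temp_ckpt.pkl"), "rb") as f:
        obj = pickle.load(f)
    fig, ax = plt.subplots()
    # histogram the flattened energy-fraction losses
    ax.hist(obj["loss_e_fracs"].numpy(), bins=100, histtype="step", label="loss_e_fracs")
    ax.hist(
        obj["loss_e_fracs_true"].numpy(), bins=100, histtype="step", label="loss_e_fracs_true"
    )
    ax.set_xlabel("energy-fraction loss")
    ax.set_ylabel("count")
    ax.legend()
    return fig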
#! create output graph with shower id ndata and store it for each event
# if args.store_output:
#     print("calculating clustering and matching showers")
#     if step == 0 and local_rank == 0:
#         create_and_store_graph_output(
#             batch_g,
#             model_output,
#             y,
#             local_rank,
#             step,
#             epoch,
#             path_save=args.model_prefix + "/showers_df",
#             store=True,
#         )
# losses_cpu = [
#     x.detach().to("cpu") if isinstance(x, torch.Tensor) else x
#     for x in losses
# ]
# all_val_losses.append(losses_cpu)
# all_val_loss.append(loss.detach().to("cpu").item())
# pid_true, pid_pred = torch.cat(
#     [torch.tensor(x[7]) for x in all_val_losses]
# ), torch.cat([torch.tensor(x[8]) for x in all_val_losses])
# pid_true, pid_pred = pid_true.tolist(), pid_pred.tolist()
# , step=step)
# if clust_loss_only and calc_e_frac_loss:
#     wandb.log(
#         {
#             "loss e frac val": loss_E_frac,
#             "loss e frac true val": loss_E_frac_true,
#         }
#     )
# ks = sorted(list(all_val_losses[0][9].keys()))
# concatenated = {}
# for key in ks:
#     concatenated[key] = np.concatenate([x[9][key] for x in all_val_losses])
# tables = {}
# for key in ks:
#     tables[key] = concatenated[
#         key
#     ]  # wandb.Table(data=[[x] for x in concatenated[key]], columns=[key])
# wandb.log(
#     {
#         "val " + key: wandb.Histogram(clip_list(tables[key]), num_bins=100)
#         for key in ks
#     }
# )  # , step=step)
# scores = np.concatenate(scores)
# labels = {k: _concat(v) for k, v in labels.items()}
# metric_results = evaluate_metrics(labels[data_config.label_names[0]], scores, eval_metrics=eval_metrics)
# _logger.info('Evaluation metrics: \n%s', '\n'.join(
#     [' - %s: \n%s' % (k, str(v)) for k, v in metric_results.items()]))
def plot_regression_resolution(model, test_loader, dev, **kwargs):
    model.eval()
    results = []  # resolution results
    pid_classification_results = []
    with torch.no_grad():
        with tqdm.tqdm(test_loader) as tq:
            for batch_g, y in tq:
                batch_g = batch_g.to(dev)
                # NOTE: `args` is not passed to this function; it is expected to
                # exist at module scope.
                if args.loss_regularization:
                    model_output, loss_regularizing_neig = model(batch_g)
                else:
                    model_output = model(batch_g)
                resolutions, pid_true, pid_pred = model.mod.object_condensation_loss2(
                    batch_g,
                    model_output,
                    y,
                    return_resolution=True,
                    q_min=args.qmin,
                    frac_clustering_loss=0,
                    use_average_cc_pos=args.use_average_cc_pos,
                    hgcalloss=args.hgcalloss,
                )
                results.append(resolutions)
                pid_classification_results.append((pid_true, pid_pred))
    result_dict = {}
    for key in results[0]:
        result_dict[key] = np.concatenate([r[key] for r in results])
    result_dict["event_by_event_accuracy"] = [
        accuracy_score(pid_true.argmax(dim=0), pid_pred.argmax(dim=0))
        for pid_true, pid_pred in pid_classification_results
    ]
    # just plot all for now
    result = {}
    for key in results[0]:
        data = result_dict[key]
        fig, ax = plt.subplots()
        ax.hist(data, bins=100, range=(-1.5, 1.5), histtype="step", label=key)
        ax.set_xlabel("resolution")
        ax.set_ylabel("count")
        ax.legend()
        result[key] = fig
    conf_mat = confusion_matrix(pid_true.argmax(dim=0), pid_pred.argmax(dim=0))
    # confusion matrix
    fig, ax = plt.subplots(figsize=(7.5, 7.5))
    # add onehot_particle_arr as class names
    class_names = onehot_particles_arr
    im = ax.matshow(conf_mat, cmap=plt.cm.Blues)
    ax.set_xticks(np.arange(len(class_names)), class_names, rotation=45)
    ax.set_yticks(np.arange(len(class_names)), class_names)
    result["PID_confusion_matrix"] = fig
    return result
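# Usage sketch (assumption: `test_loader` yields (batched_graph, y) pairs as above).
# The function returns a dict of matplotlib figures keyed by resolution name plus
# "PID_confusion_matrix", which can be saved directly:
#
# figures = plot_regression_resolution(model, test_loader, dev)
# for name, fig in figures.items():
#     fig.savefig("resolution_{}.png".format(name.replace("/", "_")))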
# if args.loss_regularization:
#     wandb.log({"loss regul neigh": loss_regularizing_neig})
#     wandb.log({"loss ll": loss_ll})
# if (num_batches - 1) % 100 == 0:
#     if clust_loss_only:
#         clust_space_dim = 3  # model.mod.output_dim - 1
#     else:
#         clust_space_dim = model.mod.output_dim - 28
#     bj = torch.sigmoid(
#         torch.reshape(model_output[:, clust_space_dim], [-1, 1])
#     )  # 3: betas
#     xj = model_output[:, 0:clust_space_dim]  # xj: cluster space coords
#     # assert len(bj) == len(xj)
#     # if model.mod.clust_space_norm == "twonorm":
#     #     xj = torch.nn.functional.normalize(
#     #         xj, dim=1
#     #     )  # 0, 1, 2: cluster space coords
#     # elif model.mod.clust_space_norm == "tanh":
#     #     xj = torch.tanh(xj)
#     # elif model.mod.clust_space_norm == "none":
#     #     pass
#     bj = bj.clip(0.0, 1 - 1e-4)
#     q = bj.arctanh() ** 2 + args.qmin
#     assert q.shape[0] == xj.shape[0]
#     assert batch_g.ndata["h"].shape[0] == xj.shape[0]
#     fig, ax = plot_clust(
#         batch_g,
#         q,
#         xj,
#         title_prefix="train ep. {}, batch {}".format(
#             epoch, num_batches
#         ),
#         y=y,
#         betas=bj,
#     )
#     wandb.log({"clust": wandb.Image(fig)})
#     fig.clf()
#     # if (num_batches - 1) % 500 == 0:
#     #     wandb.log(
#     #         {
#     #             "conf_mat_train": wandb.plot.confusion_matrix(
#     #                 y_true=pid_true,
#     #                 preds=pid_pred,
#     #                 class_names=class_names,
#     #             )
#     #         }
#     #     )
# ks = sorted(list(losses[9].keys()))
# losses_cpu = [
#     x.detach().to("cpu") if isinstance(x, torch.Tensor) else x
#     for x in losses
# ]
# tables = {}
# for key in ks:
#     tables[key] = losses[9][
#         key
#     ]  # wandb.Table(data=[[x] for x in losses[9][key]], columns=[key])
# if local_rank == 0:
#     wandb.log(
#         {
#             key: wandb.Histogram(clip_list(tables[key]), num_bins=100)
#             for key, val in losses_cpu[9].items()
#         }
#     )  # , step=step_count)
# return loss_epoch_total, losses_epoch_total