from typing import Tuple, Union | |
import numpy as np | |
import torch | |
from torch_scatter import scatter_max, scatter_add, scatter_mean | |
from src.layers.loss_fill_space_torch import LLFillSpace | |
def safe_index(arr, index):
    # Return the 1-based position of `index` in `arr`, or 0 if it is not present
    if index not in arr:
        return 0
    else:
        return arr.index(index) + 1
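# Illustrative example (not part of the original code): with arr = ["e", "mu", "pi"],
# safe_index(arr, "mu") returns 2 and safe_index(arr, "K") returns 0.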
def assert_no_nans(x): | |
""" | |
Raises AssertionError if there is a nan in the tensor | |
""" | |
if torch.isnan(x).any(): | |
print(x) | |
assert not torch.isnan(x).any() | |
# FIXME: Use a logger instead of this | |
DEBUG = False | |
def debug(*args, **kwargs): | |
if DEBUG: | |
print(*args, **kwargs) | |
def calc_LV_Lbeta( | |
original_coords, | |
g, | |
distance_threshold, | |
beta: torch.Tensor, | |
cluster_space_coords: torch.Tensor, # Predicted by model | |
cluster_index_per_event: torch.Tensor, # Truth hit->cluster index, e.g. [0, 1, 1, 0, 1, -1, 0, 1, 1] | |
batch: torch.Tensor, # E.g. [0, 0, 0, 0, 1, 1, 1, 1, 1] | |
# From here on just parameters | |
qmin: float = 0.1, | |
s_B: float = 1.0, | |
    noise_cluster_index: int = 0,  # cluster_index entries with this value are noise
beta_stabilizing="soft_q_scaling", | |
huberize_norm_for_V_attractive=False, | |
beta_term_option="paper", | |
    frac_combinations=0,  # fraction of all possible pairs used for the clustering loss
attr_weight=1.0, | |
repul_weight=1.0, | |
use_average_cc_pos=0.0, | |
loss_type="hgcalimplementation", | |
tracking=False, | |
dis=False, | |
beta_type="default", | |
noise_logits=None, | |
lorentz_norm=False, | |
spatial_part_only=False, | |
) -> Union[Tuple[torch.Tensor, torch.Tensor], dict]: | |
""" | |
Calculates the L_V and L_beta object condensation losses. | |
Concepts: | |
- A hit belongs to exactly one cluster (cluster_index_per_event is (n_hits,)), | |
and to exactly one event (batch is (n_hits,)) | |
- A cluster index of `noise_cluster_index` means the cluster is a noise cluster. | |
There is typically one noise cluster per event. Any hit in a noise cluster | |
is a 'noise hit'. A hit in an object is called a 'signal hit' for lack of a | |
better term. | |
- An 'object' is a cluster that is *not* a noise cluster. | |
beta_stabilizing: Choices are ['paper', 'clip', 'soft_q_scaling']: | |
paper: beta is sigmoid(model_output), q = beta.arctanh()**2 + qmin | |
clip: beta is clipped to 1-1e-4, q = beta.arctanh()**2 + qmin | |
soft_q_scaling: beta is sigmoid(model_output), q = (clip(beta)/1.002).arctanh()**2 + qmin | |
huberize_norm_for_V_attractive: Huberizes the norms when used in the attractive potential | |
beta_term_option: Choices are ['paper', 'short-range-potential']: | |
Choosing 'short-range-potential' introduces a short range potential around high | |
beta points, acting like V_attractive. | |
Note this function has modifications w.r.t. the implementation in 2002.03605: | |
- The norms for V_repulsive are now Gaussian (instead of linear hinge) | |
- Noise_logits: If set to an array, it is the output of the noise classifier (whether a particle belongs to a jet or not) | |
""" | |
# remove dummy rows added for dataloader #TODO think of better way to do this | |
device = beta.device | |
    if torch.isnan(beta).any():
        print("There are NaNs in beta:", int(torch.isnan(beta).sum()))
        beta = torch.nan_to_num(beta, nan=0.0)
assert_no_nans(beta) | |
# ________________________________ | |
# Calculate a bunch of needed counts and indices locally | |
# cluster_index: unique index over events | |
# E.g. cluster_index_per_event=[ 0, 0, 1, 2, 0, 0, 1], batch=[0, 0, 0, 0, 1, 1, 1] | |
# -> cluster_index=[ 0, 0, 1, 2, 3, 3, 4 ] | |
cluster_index, n_clusters_per_event = batch_cluster_indices( | |
cluster_index_per_event, batch | |
) | |
n_clusters = n_clusters_per_event.sum() | |
n_hits, cluster_space_dim = cluster_space_coords.size() | |
batch_size = batch.max().detach().cpu().item() + 1 | |
n_hits_per_event = scatter_count(batch) | |
# Index of cluster -> event (n_clusters,) | |
batch_cluster = scatter_counts_to_indices(n_clusters_per_event) | |
# Per-hit boolean, indicating whether hit is sig or noise | |
is_noise = cluster_index_per_event == noise_cluster_index | |
is_sig = ~is_noise | |
n_hits_sig = is_sig.sum() | |
n_sig_hits_per_event = scatter_count(batch[is_sig]) | |
# Per-cluster boolean, indicating whether cluster is an object or noise | |
is_object = scatter_max(is_sig.long(), cluster_index)[0].bool() | |
is_noise_cluster = ~is_object | |
# FIXME: This assumes noise_cluster_index == 0!! | |
# Not sure how to do this in a performant way in case noise_cluster_index != 0 | |
if noise_cluster_index != 0: | |
raise NotImplementedError | |
object_index_per_event = cluster_index_per_event[is_sig] - 1 | |
object_index, n_objects_per_event = batch_cluster_indices( | |
object_index_per_event, batch[is_sig] | |
) | |
n_hits_per_object = scatter_count(object_index) | |
# print("n_hits_per_object", n_hits_per_object) | |
batch_object = batch_cluster[is_object] | |
n_objects = is_object.sum() | |
assert object_index.size() == (n_hits_sig,) | |
assert is_object.size() == (n_clusters,) | |
assert torch.all(n_hits_per_object > 0) | |
assert object_index.max() + 1 == n_objects | |
# ________________________________ | |
# L_V term | |
# Calculate q | |
if beta_type == "default": | |
if loss_type == "hgcalimplementation" or loss_type == "vrepweighted": | |
q = (beta.clip(0.0, 1 - 1e-4).arctanh() / 1.01) ** 2 + qmin | |
elif beta_stabilizing == "paper": | |
q = beta.arctanh() ** 2 + qmin | |
elif beta_stabilizing == "clip": | |
beta = beta.clip(0.0, 1 - 1e-4) | |
q = beta.arctanh() ** 2 + qmin | |
elif beta_stabilizing == "soft_q_scaling": | |
q = (beta.clip(0.0, 1 - 1e-4) / 1.002).arctanh() ** 2 + qmin | |
        else:
            raise ValueError(f"beta_stabilizing mode {beta_stabilizing} is not known")
elif beta_type == "pt": | |
q = beta | |
elif beta_type == "pt+bc": | |
q = beta | |
#if beta_type in ["pt", "pt+bc"]: | |
# q[q<0.5] = 0.5 # cap the q | |
assert_no_nans(q) | |
assert q.device == device | |
assert q.size() == (n_hits,) | |
# Calculate q_alpha, the max q per object, and the indices of said maxima | |
# assert hit_energies.shape == q.shape | |
# q_alpha, index_alpha = scatter_max(hit_energies[is_sig], object_index) | |
q_alpha, index_alpha = scatter_max(q[is_sig], object_index) | |
assert q_alpha.size() == (n_objects,) | |
# Get the cluster space coordinates and betas for these maxima hits too | |
x_alpha = cluster_space_coords[is_sig][index_alpha] | |
#x_alpha_original = original_coords[is_sig][index_alpha] | |
    if use_average_cc_pos > 0:
        #! this is a func of beta and q so maybe we could also do it with only q
        # Use cluster_space_dim instead of a hard-coded 3 so this also works for
        # cluster spaces of a different dimensionality
        x_alpha_sum = scatter_add(
            q[is_sig].view(-1, 1).repeat(1, cluster_space_dim)
            * cluster_space_coords[is_sig],
            object_index,
            dim=0,
        )  # * beta[is_sig].view(-1, 1).repeat(1, cluster_space_dim)
        qbeta_alpha_sum = scatter_add(q[is_sig], object_index) + 1e-9  # * beta[is_sig]
        div_fac = 1 / qbeta_alpha_sum
        div_fac = torch.nan_to_num(div_fac, nan=0)
        x_alpha_mean = torch.mul(
            x_alpha_sum, div_fac.view(-1, 1).repeat(1, cluster_space_dim)
        )
        x_alpha = use_average_cc_pos * x_alpha_mean + (1 - use_average_cc_pos) * x_alpha
if dis: | |
phi_sum = scatter_add( | |
beta[is_sig].view(-1) * distance_threshold[is_sig].view(-1), | |
object_index, | |
dim=0, | |
) | |
phi_alpha_sum = scatter_add(beta[is_sig].view(-1), object_index) + 1e-9 | |
phi_alpha = phi_sum / phi_alpha_sum | |
beta_alpha = beta[is_sig][index_alpha] | |
assert x_alpha.size() == (n_objects, cluster_space_dim) | |
assert beta_alpha.size() == (n_objects,) | |
# Connectivity matrix from hit (row) -> cluster (column) | |
# Index to matrix, e.g.: | |
# [1, 3, 1, 0] --> [ | |
# [0, 1, 0, 0], | |
# [0, 0, 0, 1], | |
# [0, 1, 0, 0], | |
# [1, 0, 0, 0] | |
# ] | |
M = torch.nn.functional.one_hot(cluster_index).long() | |
# Anti-connectivity matrix; be sure not to connect hits to clusters in different events! | |
M_inv = get_inter_event_norms_mask(batch, n_clusters_per_event) - M | |
# Throw away noise cluster columns; we never need them | |
M = M[:, is_object] | |
M_inv = M_inv[:, is_object] | |
assert M.size() == (n_hits, n_objects) | |
assert M_inv.size() == (n_hits, n_objects) | |
# Calculate all norms | |
# Warning: Should not be used without a mask! | |
# Contains norms between hits and objects from different events | |
# (n_hits, 1, cluster_space_dim) - (1, n_objects, cluster_space_dim) | |
# gives (n_hits, n_objects, cluster_space_dim) | |
norms = (cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)).norm(dim=-1) | |
assert norms.size() == (n_hits, n_objects) | |
L_clusters = torch.tensor(0.0).to(device) | |
if frac_combinations != 0: | |
L_clusters = L_clusters_calc( | |
batch, cluster_space_coords, cluster_index, frac_combinations, q | |
) | |
# ------- | |
# Attractive potential term | |
# First get all the relevant norms: We only want norms of signal hits | |
# w.r.t. the object they belong to, i.e. no noise hits and no noise clusters. | |
# First select all norms of all signal hits w.r.t. all objects, mask out later | |
if loss_type == "hgcalimplementation" or loss_type == "vrepweighted": | |
# if dis: | |
# N_k = torch.sum(M, dim=0) # number of hits per object | |
# norms = torch.sum( | |
# torch.square(cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)), | |
# dim=-1, | |
# ) | |
# norms_att = norms[is_sig] | |
# norms_att = norms_att / (2 * phi_alpha.unsqueeze(0) ** 2 + 1e-6) | |
# #! att func as in line 159 of object condensation | |
# norms_att = torch.log( | |
# torch.exp(torch.Tensor([1]).to(norms_att.device)) * norms_att + 1 | |
# ) | |
N_k = torch.sum(M, dim=0) # number of hits per object | |
if lorentz_norm: | |
diff = cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0) | |
norms = diff[:, :, 0]**2 - torch.sum(diff[:, :, 1:] ** 2, dim=-1) | |
norms = norms.abs() ## ??? Why is this needed? wrong convention? | |
#print("Norms", norms[:15]) | |
else: | |
if spatial_part_only: | |
norms = torch.sum( | |
torch.square(cluster_space_coords[:, 1:4].unsqueeze(1) - x_alpha[:, 1:4].unsqueeze(0)), | |
dim=-1, | |
) | |
else: | |
norms = torch.sum( | |
torch.square(cluster_space_coords.unsqueeze(1) - x_alpha.unsqueeze(0)), | |
dim=-1, | |
) # Take the norm squared | |
norms_att = norms[is_sig] | |
#! att func as in line 159 of object condensation | |
norms_att = torch.log( | |
torch.exp(torch.Tensor([1]).to(norms_att.device)) * norms_att / 2 + 1 | |
) | |
elif huberize_norm_for_V_attractive: | |
norms_att = norms[is_sig] | |
# Huberized version (linear but times 4) | |
# Be sure to not move 'off-diagonal' away from zero | |
# (i.e. norms of hits w.r.t. clusters they do _not_ belong to) | |
norms_att = huber(norms_att + 1e-5, 4.0) | |
else: | |
norms_att = norms[is_sig] | |
# Paper version is simply norms squared (no need for mask) | |
norms_att = norms_att**2 | |
assert norms_att.size() == (n_hits_sig, n_objects) | |
# Now apply the mask to keep only norms of signal hits w.r.t. to the object | |
# they belong to | |
norms_att *= M[is_sig] | |
# Sum over hits, then sum per event, then divide by n_hits_per_event, then sum over events | |
if loss_type == "hgcalimplementation": | |
# Final potential term | |
# (n_sig_hits, 1) * (1, n_objects) * (n_sig_hits, n_objects) | |
# hit_type = (g.ndata["hit_type"][is_sig].view(-1)==3)*4+1 #weight 5 for hadronic hits, 1 for | |
# tracks = g.ndata["hit_type"][is_sig]==1 | |
# hit_type[tracks] = 250 | |
# total_sum_hits_types = scatter_add(hit_type.view(-1), object_index) | |
V_attractive = q[is_sig].unsqueeze(-1) * q_alpha.unsqueeze(0) * norms_att | |
assert V_attractive.size() == (n_hits_sig, n_objects) | |
#! each shower is account for separately | |
V_attractive = V_attractive.sum(dim=0) # K objects | |
#! divide by the number of accounted points | |
V_attractive = V_attractive.view(-1) / (N_k.view(-1) + 1e-3) | |
# V_attractive = V_attractive.view(-1) / (total_sum_hits_types.view(-1) + 1e-3) | |
# L_V_attractive = torch.mean(V_attractive) | |
## multiply by a weight that depends on the energy of the shower: | |
# print("e_hits", e_hits) | |
# print("weight_att", weight_att) | |
# L_V_attractive = torch.sum(V_attractive*weight_att) | |
L_V_attractive = torch.mean(V_attractive) | |
# L_V_attractive = L_V_attractive / torch.sum(weight_att) | |
L_V_attractive_2 = torch.sum(V_attractive) | |
elif loss_type == "vrepweighted": | |
if tracking: | |
# weight the vtx hits inside the shower | |
V_attractive = ( | |
g.ndata["weights"][is_sig].unsqueeze(-1) | |
* q[is_sig].unsqueeze(-1) | |
* q_alpha.unsqueeze(0) | |
* norms_att | |
) | |
assert V_attractive.size() == (n_hits_sig, n_objects) | |
V_attractive = V_attractive.sum(dim=0) # K objects | |
L_V_attractive = torch.mean(V_attractive.view(-1)) | |
else: | |
# # weight per hit per shower to compensate for ecal hcal unbalance in hadronic showers | |
# ecal_hits = scatter_add( | |
# 1 * (g.ndata["hit_type"][is_sig] == 2), object_index | |
# ) | |
# hcal_hits = scatter_add( | |
# 1 * (g.ndata["hit_type"][is_sig] == 3), object_index | |
# ) | |
# weights = torch.ones_like(g.ndata["hit_type"][is_sig]) | |
# weight_ecal_per_object = 1.0 * ecal_hits.clone() + 1 | |
# weight_hcal_per_object = 1.0 * ecal_hits.clone() + 1 | |
# mask = (ecal_hits > 2) * (hcal_hits > 2) | |
# weight_ecal_per_object[mask] = (ecal_hits + hcal_hits)[mask] / ( | |
# 2 * ecal_hits | |
# )[mask] | |
# weight_hcal_per_object[mask] = (ecal_hits + hcal_hits)[mask] / ( | |
# 2 * hcal_hits | |
# )[mask] | |
# weights[g.ndata["hit_type"][is_sig] == 2] = weight_ecal_per_object[ | |
# object_index | |
# ] | |
# weights[g.ndata["hit_type"][is_sig] == 3] = weight_hcal_per_object[ | |
# object_index | |
# ] | |
# # weight with an energy log of the hits | |
# e_hits = g.ndata["e_hits"][is_sig].view(-1) | |
# p_hits = g.ndata["h"][:, -1][is_sig].view(-1) | |
# log_scale_s = torch.log(e_hits + p_hits) + 10 | |
# e_sum_hits = scatter_add(log_scale_s, object_index) | |
# # need to take out the weight of alpha otherwise it won't add up to 1 | |
# e_sum_hits = e_sum_hits - (log_scale_s[index_alpha]) | |
# e_rel = (log_scale_s) / e_sum_hits[object_index] | |
# weight of the hit depending on the radial distance: | |
# this weight should help to seed | |
# weight_radial_distance = torch.exp( | |
# -g.ndata["radial_distance"][is_sig] / 100 | |
# ) | |
# weight_per_object = scatter_add(weight_radial_distance, object_index) | |
# weight_radial_distance = ( | |
# weight_radial_distance / weight_per_object[object_index] | |
# ) | |
V_attractive = ( | |
q[is_sig].unsqueeze(-1) ## weight_radial_distance.unsqueeze(-1) | |
* q_alpha.unsqueeze(0) | |
* norms_att | |
) | |
# weight modified showers with a higher weight | |
modified_showers = scatter_max(g.ndata["hit_link_modified"], object_index)[ | |
0 | |
] | |
n_modified = torch.sum(modified_showers) | |
weight_modified = len(modified_showers) / (2 * n_modified) | |
weight_unmodified = len(modified_showers) / ( | |
2 * (len(modified_showers) - n_modified) | |
) | |
modified_showers[modified_showers > 0] = weight_modified | |
modified_showers[modified_showers == 0] = weight_unmodified | |
assert V_attractive.size() == (n_hits_sig, n_objects) | |
V_attractive = V_attractive.sum(dim=0) # K objects | |
L_V_attractive = torch.sum( | |
modified_showers.view(-1) * V_attractive.view(-1) | |
) / len(modified_showers) | |
else: | |
# Final potential term | |
# (n_sig_hits, 1) * (1, n_objects) * (n_sig_hits, n_objects) | |
V_attractive = q[is_sig].unsqueeze(-1) * q_alpha.unsqueeze(0) * norms_att | |
assert V_attractive.size() == (n_hits_sig, n_objects) | |
#! in comparison this works per hit | |
V_attractive = ( | |
scatter_add(V_attractive.sum(dim=0), batch_object) / n_hits_per_event | |
) | |
assert V_attractive.size() == (batch_size,) | |
L_V_attractive = V_attractive.sum() | |
# ------- | |
# Repulsive potential term | |
# Get all the relevant norms: We want norms of any hit w.r.t. to | |
# objects they do *not* belong to, i.e. no noise clusters. | |
# We do however want to keep norms of noise hits w.r.t. objects | |
# Power-scale the norms: Gaussian scaling term instead of a cone | |
# Mask out the norms of hits w.r.t. the cluster they belong to | |
if loss_type == "hgcalimplementation" or loss_type == "vrepweighted": | |
if dis: | |
norms = norms / (2 * phi_alpha.unsqueeze(0) ** 2 + 1e-6) | |
norms_rep = torch.exp(-(norms)) * M_inv | |
norms_rep2 = torch.exp(-(norms) * 10) * M_inv | |
else: | |
norms_rep = torch.exp(-(norms) / 2) * M_inv | |
# norms_rep2 = torch.exp(-(norms) * 10) * M_inv | |
norms_rep2 = torch.exp(-(norms) * 10) * M_inv | |
else: | |
norms_rep = torch.exp(-4.0 * norms**2) * M_inv | |
    # (n_hits, 1) * (1, n_objects) * (n_hits, n_objects)
    V_repulsive = q.unsqueeze(1) * q_alpha.unsqueeze(0) * norms_rep
# No need to apply a V = max(0, V); by construction V>=0 | |
assert V_repulsive.size() == (n_hits, n_objects) | |
# Sum over hits, then sum per event, then divide by n_hits_per_event, then sum up events | |
nope = n_objects_per_event - 1 | |
nope[nope == 0] = 1 | |
if loss_type == "hgcalimplementation" or loss_type == "vrepweighted": | |
#! sum each object repulsive terms | |
L_V_repulsive = V_repulsive.sum(dim=0) # size number of objects | |
number_of_repulsive_terms_per_object = torch.sum(M_inv, dim=0) | |
L_V_repulsive = L_V_repulsive.view( | |
-1 | |
) / number_of_repulsive_terms_per_object.view(-1) | |
V_repulsive2 = q.unsqueeze(1) * q_alpha.unsqueeze(0) * norms_rep2 | |
L_V_repulsive2 = V_repulsive2.sum(dim=0) # size number of objects | |
L_V_repulsive2 = L_V_repulsive2.view(-1) | |
L_V_attractive_2 = L_V_attractive_2.view(-1) | |
# if not tracking: | |
# #! add to terms function (divide by total number of showers per event) | |
# # L_V_repulsive = scatter_add(L_V_repulsive, object_index) / n_objects | |
# per_shower_weight = torch.exp(1 / (e_particles_pred_per_object + 0.4)) | |
# soft_m = torch.nn.Softmax(dim=0) | |
# per_shower_weight = soft_m(per_shower_weight) * len(L_V_repulsive) | |
# L_V_repulsive = torch.mean(L_V_repulsive * per_shower_weight) | |
# else: | |
# if tracking: | |
# L_V_repulsive = torch.mean(L_V_repulsive * per_shower_weight) | |
# else: | |
if loss_type == "vrepweighted": | |
L_V_repulsive = torch.sum( | |
modified_showers.view(-1) * L_V_repulsive.view(-1) | |
) / len(modified_showers) | |
L_V_repulsive2 = torch.sum( | |
modified_showers.view(-1) * L_V_repulsive2.view(-1) | |
) / len(modified_showers) | |
else: | |
L_V_repulsive = torch.mean(L_V_repulsive) | |
L_V_repulsive2 = torch.mean(L_V_repulsive2) | |
else: | |
L_V_repulsive = ( | |
scatter_add(V_repulsive.sum(dim=0), batch_object) | |
/ (n_hits_per_event * nope) | |
).sum() | |
    L_V = attr_weight * L_V_attractive + repul_weight * L_V_repulsive
n_noise_hits_per_event = scatter_count(batch[is_noise]) | |
n_noise_hits_per_event[n_noise_hits_per_event == 0] = 1 | |
L_beta_noise = ( | |
s_B | |
* ( | |
(scatter_add(beta[is_noise], batch[is_noise])) / n_noise_hits_per_event | |
).sum() | |
) | |
if loss_type == "hgcalimplementation": | |
beta_per_object_c = scatter_add(beta[is_sig], object_index) | |
beta_alpha = beta[is_sig][index_alpha] | |
L_beta_sig = torch.mean( | |
1 - beta_alpha + 1 - torch.clip(beta_per_object_c, 0, 1) | |
) | |
L_beta_noise = L_beta_noise / 4 | |
# ? note: the training that worked quite well was dividing this by the batch size (1/4) | |
elif loss_type == "vrepweighted": | |
# version one: | |
beta_per_object_c = scatter_add(beta[is_sig], object_index) | |
beta_alpha = beta[is_sig][index_alpha] | |
L_beta_sig = 1 - beta_alpha + 1 - torch.clip(beta_per_object_c, 0, 1) | |
L_beta_sig = torch.sum(L_beta_sig.view(-1) * modified_showers.view(-1)) | |
L_beta_sig = L_beta_sig / len(modified_showers) | |
L_beta_noise = L_beta_noise / batch_size | |
# ? note: the training that worked quite well was dividing this by the batch size (1/4) | |
elif beta_term_option == "paper": | |
beta_alpha = beta[is_sig][index_alpha] | |
L_beta_sig = torch.sum( # maybe 0.5 for less aggressive loss | |
scatter_add((1 - beta_alpha), batch_object) / n_objects_per_event | |
) | |
# print("L_beta_sig", L_beta_sig / batch_size) | |
# beta_exp = beta[is_sig] | |
# beta_exp[index_alpha] = 0 | |
# # L_exp = torch.mean(beta_exp) | |
# beta_exp = torch.exp(0.5 * beta_exp) | |
# L_exp = torch.mean(scatter_add(beta_exp, batch) / n_hits_per_event) | |
elif beta_term_option == "short-range-potential": | |
# First collect the norms: We only want norms of hits w.r.t. the object they | |
# belong to (like in V_attractive) | |
# Apply transformation first, and then apply mask to keep only the norms we want, | |
# then sum over hits, so the result is (n_objects,) | |
norms_beta_sig = (1.0 / (20.0 * norms[is_sig] ** 2 + 1.0) * M[is_sig]).sum( | |
dim=0 | |
) | |
assert torch.all(norms_beta_sig >= 1.0) and torch.all( | |
norms_beta_sig <= n_hits_per_object | |
) | |
# Subtract from 1. to remove self interaction, divide by number of hits per object | |
norms_beta_sig = (1.0 - norms_beta_sig) / n_hits_per_object | |
assert torch.all(norms_beta_sig >= -1.0) and torch.all(norms_beta_sig <= 0.0) | |
norms_beta_sig *= beta_alpha | |
# Conclusion: | |
# lower beta --> higher loss (less negative) | |
# higher norms --> higher loss | |
# Sum over objects, divide by number of objects per event, then sum over events | |
L_beta_norms_term = ( | |
scatter_add(norms_beta_sig, batch_object) / n_objects_per_event | |
).sum() | |
assert L_beta_norms_term >= -batch_size and L_beta_norms_term <= 0.0 | |
# Logbeta term: Take -.2*torch.log(beta_alpha[is_object]+1e-9), sum it over objects, | |
# divide by n_objects_per_event, then sum over events (same pattern as above) | |
# lower beta --> higher loss | |
L_beta_logbeta_term = ( | |
scatter_add(-0.2 * torch.log(beta_alpha + 1e-9), batch_object) | |
/ n_objects_per_event | |
).sum() | |
# Final L_beta term | |
L_beta_sig = L_beta_norms_term + L_beta_logbeta_term | |
else: | |
valid_options = ["paper", "short-range-potential"] | |
raise ValueError( | |
f'beta_term_option "{beta_term_option}" is not valid, choose from {valid_options}' | |
) | |
L_beta = L_beta_noise + L_beta_sig | |
if beta_type == "pt" or beta_type == "pt+bc": | |
L_beta = torch.tensor(0.) | |
L_beta_sig = torch.tensor(0.) | |
L_beta_noise = torch.tensor(0.) | |
#L_alpha_coordinates = torch.mean(torch.norm(x_alpha_original - x_alpha, p=2, dim=1)) | |
x_original = original_coords / torch.norm(original_coords, p=2, dim=1).view(-1, 1) | |
x_virtual = cluster_space_coords / torch.norm(cluster_space_coords, p=2, dim=1).view(-1, 1) | |
loss_coord = torch.mean(torch.norm(x_original - x_virtual, p=2, dim=1)) # We just compare the direction | |
if beta_type == "pt+bc": | |
assert noise_logits is not None | |
y_true_noise = 1 - is_noise.float() | |
num_positives = torch.sum(y_true_noise).item() | |
num_negatives = len(y_true_noise) - num_positives | |
num_all = len(y_true_noise) | |
# Compute weights | |
pos_weight = num_all / num_positives if num_positives > 0 else 0 | |
neg_weight = num_all / num_negatives if num_negatives > 0 else 0 | |
weight = pos_weight * y_true_noise + neg_weight * (1 - y_true_noise) | |
L_bc = torch.nn.BCELoss(weight=weight)( | |
noise_logits, 1-is_noise.float() | |
) | |
#if torch.isnan(L_beta / batch_size): | |
# print("isnan!!!") | |
# print(L_beta, batch_size) | |
# print("L_beta_noise", L_beta_noise) | |
# print("L_beta_sig", L_beta_sig) | |
result = { | |
"loss_potential": L_V, # 0 | |
"loss_beta": L_beta, | |
"loss_beta_sig": L_beta_sig, # signal part of the betas | |
"loss_beta_noise": L_beta_noise, # noise part of the betas | |
"loss_attractive": L_V_attractive, | |
"loss_repulsive": L_V_repulsive, | |
"loss_coord": loss_coord, | |
} | |
if beta_type == "pt+bc": | |
result["loss_noise_classification"] = L_bc | |
return result | |
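# Minimal usage sketch (not part of the original code; all names and values are made
# up): two toy events with cluster index 0 as the noise cluster. With the default
# loss_type="hgcalimplementation", dis=False and tracking=False the graph argument `g`
# is not read, so None is passed here purely for illustration.
def _demo_calc_LV_Lbeta():
    torch.manual_seed(0)
    n_hits, dim = 12, 3
    cluster_index_per_event = torch.tensor([0, 1, 1, 2, 2, 2, 0, 1, 1, 1, 2, 2])
    batch = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    losses = calc_LV_Lbeta(
        original_coords=torch.randn(n_hits, dim),
        g=None,
        distance_threshold=0,
        beta=torch.sigmoid(torch.randn(n_hits)),
        cluster_space_coords=torch.randn(n_hits, dim),
        cluster_index_per_event=cluster_index_per_event,
        batch=batch,
        qmin=0.1,
    )
    for name, value in losses.items():
        print(name, float(value))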
def huber(d, delta): | |
""" | |
See: https://en.wikipedia.org/wiki/Huber_loss#Definition | |
Multiplied by 2 w.r.t Wikipedia version (aligning with Jan's definition) | |
""" | |
return torch.where( | |
torch.abs(d) <= delta, d**2, 2.0 * delta * (torch.abs(d) - delta) | |
) | |
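# Quick illustration (not part of the original code): inside |d| <= delta the result is
# the plain square, outside it grows linearly with slope 2*delta.
def _demo_huber():
    d = torch.tensor([0.5, 2.0, 10.0])
    print(huber(d, 4.0))  # tensor([ 0.25,  4.00, 48.00]) since 2*4*(10-4) = 48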
def batch_cluster_indices( | |
cluster_id: torch.Tensor, batch: torch.Tensor | |
) -> Tuple[torch.LongTensor, torch.LongTensor]: | |
""" | |
Turns cluster indices per event to an index in the whole batch | |
Example: | |
cluster_id = torch.LongTensor([0, 0, 1, 1, 2, 0, 0, 1, 1, 1, 0, 0, 1]) | |
batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2]) | |
--> | |
offset = torch.LongTensor([0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 5, 5, 5]) | |
output = torch.LongTensor([0, 0, 1, 1, 2, 3, 3, 4, 4, 4, 5, 5, 6]) | |
""" | |
device = cluster_id.device | |
assert cluster_id.device == batch.device | |
# Count the number of clusters per entry in the batch | |
n_clusters_per_event = scatter_max(cluster_id, batch, dim=-1)[0] + 1 | |
# Offsets are then a cumulative sum | |
offset_values_nozero = n_clusters_per_event[:-1].cumsum(dim=-1) | |
# Prefix a zero | |
offset_values = torch.cat((torch.zeros(1, device=device), offset_values_nozero)) | |
# Fill it per hit | |
offset = torch.gather(offset_values, 0, batch).long() | |
return offset + cluster_id, n_clusters_per_event | |
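# Runnable version of the docstring example above (illustrative only).
def _demo_batch_cluster_indices():
    cluster_id = torch.LongTensor([0, 0, 1, 1, 2, 0, 0, 1, 1, 1, 0, 0, 1])
    batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
    global_index, n_clusters_per_event = batch_cluster_indices(cluster_id, batch)
    print(global_index)          # tensor([0, 0, 1, 1, 2, 3, 3, 4, 4, 4, 5, 5, 6])
    print(n_clusters_per_event)  # tensor([3, 2, 2])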
def get_clustering_np(
    betas: np.ndarray, X: np.ndarray, tbeta: float = 0.1, td: float = 1.0
) -> np.ndarray:
""" | |
Returns a clustering of hits -> cluster_index, based on the GravNet model | |
output (predicted betas and cluster space coordinates) and the clustering | |
parameters tbeta and td. | |
Takes numpy arrays as input. | |
""" | |
n_points = betas.shape[0] | |
select_condpoints = betas > tbeta | |
# Get indices passing the threshold | |
indices_condpoints = np.nonzero(select_condpoints)[0] | |
# Order them by decreasing beta value | |
indices_condpoints = indices_condpoints[np.argsort(-betas[select_condpoints])] | |
# Assign points to condensation points | |
# Only assign previously unassigned points (no overwriting) | |
# Points unassigned at the end are bkg (-1) | |
unassigned = np.arange(n_points) | |
clustering = -1 * np.ones(n_points, dtype=np.int32) | |
for index_condpoint in indices_condpoints: | |
d = np.linalg.norm(X[unassigned] - X[index_condpoint], axis=-1) | |
assigned_to_this_condpoint = unassigned[d < td] | |
clustering[assigned_to_this_condpoint] = index_condpoint | |
unassigned = unassigned[~(d < td)] | |
return clustering | |
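# Illustrative sketch with made-up values: two well separated blobs in a 2D cluster
# space. The highest-beta hit of each blob passes tbeta, becomes a condensation point
# and collects every hit within td of it; the remaining hit stays background (-1).
def _demo_get_clustering_np():
    X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0], [10.0, -10.0]])
    betas = np.array([0.9, 0.2, 0.8, 0.3, 0.05])
    print(get_clustering_np(betas, X, tbeta=0.5, td=1.0))  # -> [0 0 2 2 -1]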
def get_clustering(betas: torch.Tensor, X: torch.Tensor, tbeta=0.1, td=1.0): | |
""" | |
Returns a clustering of hits -> cluster_index, based on the GravNet model | |
output (predicted betas and cluster space coordinates) and the clustering | |
parameters tbeta and td. | |
Takes torch.Tensors as input. | |
""" | |
n_points = betas.size(0) | |
select_condpoints = betas > tbeta | |
# Get indices passing the threshold | |
indices_condpoints = select_condpoints.nonzero() | |
# Order them by decreasing beta value | |
indices_condpoints = indices_condpoints[(-betas[select_condpoints]).argsort()] | |
# Assign points to condensation points | |
# Only assign previously unassigned points (no overwriting) | |
# Points unassigned at the end are bkg (-1) | |
unassigned = torch.arange(n_points) | |
clustering = -1 * torch.ones(n_points, dtype=torch.long) | |
for index_condpoint in indices_condpoints: | |
d = torch.norm(X[unassigned] - X[index_condpoint][0], dim=-1) | |
assigned_to_this_condpoint = unassigned[d < td] | |
clustering[assigned_to_this_condpoint] = index_condpoint[0] | |
unassigned = unassigned[~(d < td)] | |
return clustering | |
def scatter_count(input: torch.Tensor): | |
""" | |
Returns ordered counts over an index array | |
Example: | |
>>> scatter_count(torch.Tensor([0, 0, 0, 1, 1, 2, 2])) # input | |
>>> [3, 2, 2] | |
Index assumptions work like in torch_scatter, so: | |
>>> scatter_count(torch.Tensor([1, 1, 1, 2, 2, 4, 4])) | |
>>> tensor([0, 3, 2, 0, 2]) | |
""" | |
return scatter_add(torch.ones_like(input, dtype=torch.long), input.long()) | |
def scatter_counts_to_indices(input: torch.LongTensor) -> torch.LongTensor: | |
""" | |
Converts counts to indices. This is the inverse operation of scatter_count | |
Example: | |
input: [3, 2, 2] | |
output: [0, 0, 0, 1, 1, 2, 2] | |
""" | |
return torch.repeat_interleave( | |
torch.arange(input.size(0), device=input.device), input | |
).long() | |
def get_inter_event_norms_mask( | |
batch: torch.LongTensor, nclusters_per_event: torch.LongTensor | |
): | |
""" | |
Creates mask of (nhits x nclusters) that is only 1 if hit i is in the same event as cluster j | |
Example: | |
cluster_id_per_event = torch.LongTensor([0, 0, 1, 1, 2, 0, 0, 1, 1, 1, 0, 0, 1]) | |
batch = torch.LongTensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2]) | |
Should return: | |
torch.LongTensor([ | |
[1, 1, 1, 0, 0, 0, 0], | |
[1, 1, 1, 0, 0, 0, 0], | |
[1, 1, 1, 0, 0, 0, 0], | |
[1, 1, 1, 0, 0, 0, 0], | |
[1, 1, 1, 0, 0, 0, 0], | |
[0, 0, 0, 1, 1, 0, 0], | |
[0, 0, 0, 1, 1, 0, 0], | |
[0, 0, 0, 1, 1, 0, 0], | |
[0, 0, 0, 1, 1, 0, 0], | |
[0, 0, 0, 1, 1, 0, 0], | |
[0, 0, 0, 0, 0, 1, 1], | |
[0, 0, 0, 0, 0, 1, 1], | |
[0, 0, 0, 0, 0, 1, 1], | |
]) | |
""" | |
device = batch.device | |
# Following the example: | |
# Expand batch to the following (nhits x nevents) matrix (little hacky, boolean mask -> long): | |
# [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], | |
# [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0], | |
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]] | |
batch_expanded_as_ones = ( | |
batch | |
== torch.arange(batch.max() + 1, dtype=torch.long, device=device).unsqueeze(-1) | |
).long() | |
# Then repeat_interleave it to expand it to nclusters rows, and transpose to get (nhits x nclusters) | |
return batch_expanded_as_ones.repeat_interleave(nclusters_per_event, dim=0).T | |
def isin(ar1, ar2): | |
"""To be replaced by torch.isin for newer releases of torch""" | |
return (ar1[..., None] == ar2).any(-1) | |
def reincrementalize(y: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: | |
"""Re-indexes y so that missing clusters are no longer counted. | |
Example: | |
>>> y = torch.LongTensor([ | |
0, 0, 0, 1, 1, 3, 3, | |
0, 0, 0, 0, 0, 2, 2, 3, 3, | |
0, 0, 1, 1 | |
]) | |
>>> batch = torch.LongTensor([ | |
0, 0, 0, 0, 0, 0, 0, | |
1, 1, 1, 1, 1, 1, 1, 1, 1, | |
2, 2, 2, 2, | |
]) | |
>>> print(reincrementalize(y, batch)) | |
tensor([0, 0, 0, 1, 1, 2, 2, 0, 0, 0, 0, 0, 1, 1, 2, 2, 0, 0, 1, 1]) | |
""" | |
y_offset, n_per_event = batch_cluster_indices(y, batch) | |
offset = y_offset - y | |
n_clusters = n_per_event.sum() | |
holes = ( | |
(~isin(torch.arange(n_clusters, device=y.device), y_offset)) | |
.nonzero() | |
.squeeze(-1) | |
) | |
n_per_event_without_holes = n_per_event.clone() | |
n_per_event_cumsum = n_per_event.cumsum(0) | |
for hole in holes.sort(descending=True).values: | |
y_offset[y_offset > hole] -= 1 | |
i_event = (hole > n_per_event_cumsum).long().argmin() | |
n_per_event_without_holes[i_event] -= 1 | |
offset_per_event = torch.zeros_like(n_per_event_without_holes) | |
offset_per_event[1:] = n_per_event_without_holes.cumsum(0)[:-1] | |
offset_without_holes = torch.gather(offset_per_event, 0, batch).long() | |
reincrementalized = y_offset - offset_without_holes | |
return reincrementalized | |
def L_clusters_calc(batch, cluster_space_coords, cluster_index, frac_combinations, q):
    number_of_pairs = 0
    # Initialize outside the event loop so the loss accumulates over all events
    # and is defined even if every event has <= 1 cluster
    L_clusters = torch.tensor(0.0).to(q.device)
    for batch_id in batch.unique():
        # do all possible pairs...
        bmask = batch == batch_id
        clust_space_filt = cluster_space_coords[bmask]
        pos_pairs_all = []
        neg_pairs_all = []
        if len(cluster_index[bmask].unique()) <= 1:
            continue
for cluster in cluster_index[bmask].unique(): | |
coords_pos = clust_space_filt[cluster_index[bmask] == cluster] | |
coords_neg = clust_space_filt[cluster_index[bmask] != cluster] | |
if len(coords_neg) == 0: | |
continue | |
clust_idx = cluster_index[bmask] == cluster | |
# all_ones = torch.ones_like((clust_idx, clust_idx)) | |
# pos_pairs = [[i, j] for i in range(len(coords_pos)) for j in range (len(coords_pos)) if i < j] | |
total_num = (len(coords_pos) ** 2) / 2 | |
num = int(frac_combinations * total_num) | |
pos_pairs = [] | |
for i in range(num): | |
pos_pairs.append( | |
[ | |
np.random.randint(len(coords_pos)), | |
np.random.randint(len(coords_pos)), | |
] | |
) | |
neg_pairs = [] | |
for i in range(len(pos_pairs)): | |
neg_pairs.append( | |
[ | |
np.random.randint(len(coords_pos)), | |
np.random.randint(len(coords_neg)), | |
] | |
) | |
pos_pairs_all += pos_pairs | |
neg_pairs_all += neg_pairs | |
pos_pairs = torch.tensor(pos_pairs_all) | |
neg_pairs = torch.tensor(neg_pairs_all) | |
"""# do just a small sample of the pairs. ... | |
bmask = batch == batch_id | |
#L_clusters = 0 # Loss of randomly sampled distances between points inside and outside clusters | |
pos_idx, neg_idx = [], [] | |
for cluster in cluster_index[bmask].unique(): | |
clust_idx = (cluster_index == cluster)[bmask] | |
perm = torch.randperm(clust_idx.sum()) | |
perm1 = torch.randperm((~clust_idx).sum()) | |
perm2 = torch.randperm(clust_idx.sum()) | |
#cutoff = clust_idx.sum()//2 | |
pos_lst = clust_idx.nonzero()[perm] | |
neg_lst = (~clust_idx).nonzero()[perm1] | |
neg_lst_second = clust_idx.nonzero()[perm2] | |
if len(pos_lst) % 2: | |
pos_lst = pos_lst[:-1] | |
if len(neg_lst) % 2: | |
neg_lst = neg_lst[:-1] | |
len_cap = min(len(pos_lst), len(neg_lst), len(neg_lst_second)) | |
if len_cap % 2: | |
len_cap -= 1 | |
pos_lst = pos_lst[:len_cap] | |
neg_lst = neg_lst[:len_cap] | |
neg_lst_second = neg_lst_second[:len_cap] | |
pos_pairs = pos_lst.reshape(-1, 2) | |
neg_pairs = torch.cat([neg_lst, neg_lst_second], dim=1) | |
neg_pairs = neg_pairs[:pos_lst.shape[0]//2, :] | |
pos_idx.append(pos_pairs) | |
neg_idx.append(neg_pairs) | |
pos_idx = torch.cat(pos_idx) | |
neg_idx = torch.cat(neg_idx)""" | |
assert pos_pairs.shape == neg_pairs.shape | |
if len(pos_pairs) == 0: | |
continue | |
cluster_space_coords_filtered = cluster_space_coords[bmask] | |
qs_filtered = q[bmask] | |
pos_norms = ( | |
cluster_space_coords_filtered[pos_pairs[:, 0]] | |
- cluster_space_coords_filtered[pos_pairs[:, 1]] | |
).norm(dim=-1) | |
neg_norms = ( | |
cluster_space_coords_filtered[neg_pairs[:, 0]] | |
- cluster_space_coords_filtered[neg_pairs[:, 1]] | |
).norm(dim=-1) | |
q_pos = qs_filtered[pos_pairs[:, 0]] | |
q_neg = qs_filtered[neg_pairs[:, 0]] | |
q_s = torch.cat([q_pos, q_neg]) | |
norms_pos = torch.cat([pos_norms, neg_norms]) | |
ys = torch.cat([torch.ones_like(pos_norms), -torch.ones_like(neg_norms)]) | |
        L_clusters += torch.sum(
            q_s * torch.nn.HingeEmbeddingLoss(reduction="none")(norms_pos, ys)
        )
number_of_pairs += norms_pos.shape[0] | |
if number_of_pairs > 0: | |
L_clusters = L_clusters / number_of_pairs | |
return L_clusters | |
def calc_eta_phi(coords, return_stacked=True): | |
""" | |
Calculate eta and phi from cartesian coordinates | |
""" | |
x = coords[:, 0] | |
y = coords[:, 1] | |
z = coords[:, 2] | |
#eta, phi = torch.atan2(y, x), torch.asin(z / coords.norm(dim=1)) | |
phi = torch.arctan2(y, x) | |
eta = torch.arctanh(z / torch.sqrt(x**2 + y**2 + z**2)) | |
if not return_stacked: | |
return eta, phi | |
return torch.stack([eta, phi], dim=1) | |
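# Sanity-check example (not part of the original code): a point on the +y axis has
# phi = pi/2 and eta = 0; adding a +z component increases eta while phi stays the same.
def _demo_calc_eta_phi():
    coords = torch.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 1.0]])
    print(calc_eta_phi(coords))  # approx [[0.0000, 1.5708], [0.8814, 1.5708]]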
def loss_func_aug(y_pred, y_pred_aug, batch, batch_aug, event, event_aug): | |
coords_pred = y_pred[:, :3] | |
coords_pred_aug = y_pred_aug[:, :3] | |
original_particle_mapping = batch_aug.original_particle_mapping | |
#print("N in batch:", event.pfcands.batch_number) | |
#print("N in batch aug:", event_aug.pfcands.batch_number) | |
to_add_to_batch = event.pfcands.batch_number[:-1] | |
aug_batch_num = event_aug.pfcands.batch_number | |
print("Original particle mapping: (before sum)", original_particle_mapping.tolist()) | |
filt_idx = torch.where(original_particle_mapping != -1)[0].tolist() | |
for i in range(len(aug_batch_num)-1): | |
for item in filt_idx: | |
if item >= aug_batch_num[i] and item < aug_batch_num[i+1]: | |
assert original_particle_mapping[item] != -1, "Original particle mapping should not be -1" | |
assert to_add_to_batch[i] >= 0, "Batch number should be >= 0: " + str(to_add_to_batch[i]) | |
original_particle_mapping[item] += to_add_to_batch[i] # Try this due to some indexing issues | |
#original_particle_mapping[aug_batch_num[i]:aug_batch_num[i+1]][filt] += to_add_to_batch[i] | |
#print("Original particle mapping:", original_particle_mapping[original_particle_mapping != -1]) | |
#original_particle_mapping[original_particle_mapping != -1] += batch_idx[original_particle_mapping != -1] | |
if not original_particle_mapping.max() < len(coords_pred): | |
print("Coords shapes", coords_pred.shape, coords_pred_aug.shape) | |
print("Original particle mapping:", original_particle_mapping[original_particle_mapping != -1], original_particle_mapping.shape, original_particle_mapping[original_particle_mapping!=-1].max()) | |
print("Batch number in event:", event.pfcands.batch_number) | |
print("Batch number in event aug:", event_aug.pfcands.batch_number) | |
print("Len batch", batch.input_vectors.shape, "len batch_aug", batch_aug.input_vectors.shape) | |
raise ValueError("Original particle mapping out of bounds") | |
assert original_particle_mapping.max() < len(coords_pred) | |
coords_pred_aug_target = coords_pred[original_particle_mapping[original_particle_mapping != -1]] | |
coords_pred_aug_output = coords_pred_aug[original_particle_mapping != -1] | |
print("Output:", coords_pred_aug_output[:5], "Target:", coords_pred_aug_target[:5]) | |
loss = torch.nn.MSELoss()(coords_pred_aug_output, coords_pred_aug_target) | |
return loss | |
def object_condensation_loss( | |
batch, # input event | |
pred, | |
labels, | |
batch_numbers, | |
q_min=3.0, | |
frac_clustering_loss=0.1, | |
attr_weight=1.0, | |
repul_weight=1.0, | |
fill_loss_weight=1.0, | |
use_average_cc_pos=0.0, | |
loss_type="hgcalimplementation", | |
clust_space_norm="none", | |
dis=False, | |
coord_weight=0.0, | |
beta_type="default", | |
lorentz_norm=False, | |
spatial_part_only=False, | |
loss_quark_distance=False, | |
oc_scalars=False, | |
loss_obj_score=False | |
): | |
""" | |
:param batch: Model input | |
:param pred: Model output, containing regressed coordinates + betas | |
:param clust_space_dim: Number of dimensions in the cluster space | |
:return: | |
""" | |
_, S = pred.shape | |
noise_logits = None | |
if beta_type == "default": | |
clust_space_dim = S - 1 | |
bj = torch.sigmoid(torch.reshape(pred[:, clust_space_dim], [-1, 1])) # betas | |
elif beta_type == "pt": | |
bj = batch.pt | |
clust_space_dim = S | |
elif beta_type == "pt+bc": | |
bj = batch.pt | |
clust_space_dim = S - 1 | |
noise_logits = pred[:, clust_space_dim] | |
original_coords = batch.input_vectors | |
if oc_scalars: | |
original_coords = original_coords[:, 1:4] | |
if dis: | |
distance_threshold = torch.reshape(pred[:, -1], [-1, 1]) | |
else: | |
distance_threshold = 0 | |
xj = pred[:, :clust_space_dim] # Coordinates in clustering space | |
#xj = calc_eta_phi(xj) | |
if clust_space_norm == "twonorm": | |
xj = torch.nn.functional.normalize(xj, dim=1) | |
elif clust_space_norm == "tanh": | |
xj = torch.tanh(xj) | |
elif clust_space_norm == "none": | |
pass | |
else: | |
raise NotImplementedError | |
if not loss_quark_distance: | |
clustering_index_l = labels | |
if loss_obj_score: | |
clustering_index_l = labels.labels+1 | |
a = calc_LV_Lbeta( | |
original_coords, | |
batch, | |
distance_threshold, | |
beta=bj.view(-1), | |
cluster_space_coords=xj, # Predicted by model | |
cluster_index_per_event=clustering_index_l.view( | |
-1 | |
).long(), # Truth hit->cluster index | |
batch=batch_numbers.long(), | |
qmin=q_min, | |
attr_weight=attr_weight, | |
repul_weight=repul_weight, | |
use_average_cc_pos=use_average_cc_pos, | |
loss_type=loss_type, | |
dis=dis, | |
beta_type=beta_type, | |
noise_logits=noise_logits, | |
lorentz_norm=lorentz_norm, | |
spatial_part_only=spatial_part_only | |
) | |
loss = a["loss_potential"] + a["loss_beta"] | |
if coord_weight > 0: | |
loss += a["loss_coord"] * coord_weight | |
else: | |
# quark distance loss | |
target_coords = labels.labels_coordinates[labels.labels[labels.labels != -1]] | |
if lorentz_norm: | |
diff = xj[labels.labels != -1] - labels.labels_coordinates[labels.labels != -1] | |
norms = diff[:, :, 0]**2 - torch.sum(diff[:, :, 1:] ** 2, dim=-1) | |
norms = norms.abs() | |
else: | |
if spatial_part_only: | |
x_coords = xj[labels.labels != -1, 1:4] | |
x_true = target_coords[:, 1:4] | |
else: | |
x_coords = xj[labels.labels != -1] | |
x_true = target_coords | |
#norms = torch.norm(x_coords - x_true, p=2, dim=1) | |
# cosine similarity | |
norms = 2 - (torch.nn.functional.cosine_similarity(x_coords, x_true[:, 1:4], dim=1) + 1) | |
a = {"norms_loss": torch.mean(norms)} | |
loss = a["norms_loss"] | |
if beta_type == "pt+bc": | |
# TODO: polish this, it's another loss that should be computed outside calc_LV_Lbeta | |
assert noise_logits is not None | |
is_noise = labels.labels == -1 | |
y_true_noise = 1 - is_noise.float() | |
num_positives = torch.sum(y_true_noise).item() | |
num_negatives = len(y_true_noise) - num_positives | |
num_all = len(y_true_noise) | |
# Compute weights | |
pos_weight = num_all / num_positives if num_positives > 0 else 0 | |
neg_weight = num_all / num_negatives if num_negatives > 0 else 0 | |
weight = pos_weight * y_true_noise + neg_weight * (1 - y_true_noise) | |
L_bc = torch.nn.BCELoss(weight=weight)( | |
noise_logits, 1 - is_noise.float() | |
) | |
a["loss_noise_classification"] = L_bc | |
if beta_type == "pt+bc": | |
loss += a["loss_noise_classification"] | |
return loss, a | |
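# Minimal end-to-end sketch (illustrative only; all names, shapes and values are
# assumptions). With the defaults (beta_type="default", loss_quark_distance=False,
# dis=False, oc_scalars=False) the wrapper only needs `pred`, per-hit truth labels,
# per-hit event numbers and an input object exposing `input_vectors`, so a
# SimpleNamespace stands in for the real event batch here.
def _demo_object_condensation_loss():
    from types import SimpleNamespace

    torch.manual_seed(0)
    n_hits, clust_dim = 12, 3
    pred = torch.randn(n_hits, clust_dim + 1)  # cluster-space coordinates + beta logit
    labels = torch.tensor([0, 1, 1, 2, 2, 2, 0, 1, 1, 1, 2, 2])  # 0 = noise
    batch_numbers = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    fake_batch = SimpleNamespace(input_vectors=torch.randn(n_hits, clust_dim))
    loss, parts = object_condensation_loss(
        fake_batch, pred, labels, batch_numbers, q_min=0.1
    )
    print(float(loss), {name: float(value) for name, value in parts.items()})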