# jetclustering/src/dataset/functions_data.py
import numpy as np
import torch
from torch_scatter import scatter_add, scatter_sum  # used by get_ratios, scatter_count, find_mask_no_energy
from sklearn.preprocessing import StandardScaler  # used by standardize_coordinates
def get_ratios(e_hits, part_idx, y):
"""Obtain the percentage of energy of the particle present in the hits
Args:
e_hits (_type_): _description_
part_idx (_type_): _description_
y (_type_): _description_
Returns:
_type_: _description_
"""
energy_from_showers = scatter_sum(e_hits, part_idx.long(), dim=0)
# y_energy = y[:, 3]
y_energy = y.E
energy_from_showers = energy_from_showers[1:]
assert len(energy_from_showers) > 0
    # flatten both sides so a [N, 1]-shaped y.E cannot broadcast to [N, N]
    return (energy_from_showers.flatten() / y_energy.flatten()).tolist()
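# A minimal hedged sketch of get_ratios (hypothetical values; assumes an object
# with a .E attribute, like Particles_GT):
#     >>> e_hits = torch.tensor([[1.0], [2.0], [3.0]])
#     >>> part_idx = torch.tensor([1, 1, 2])  # 0 is reserved for noise
#     >>> # with y.E = torch.tensor([4.0, 3.0]) this returns [0.75, 1.0]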
def get_number_hits(e_hits, part_idx):
number_of_hits = scatter_sum(torch.ones_like(e_hits), part_idx.long(), dim=0)
return (number_of_hits[1:].flatten()).tolist()
def get_e_reco(e_hits, part_idx):
number_of_hits = scatter_sum(e_hits, part_idx.long(), dim=0)
return number_of_hits[1:].flatten()
def get_number_of_daughters(hit_type_feature, hit_particle_link, daughters):
    # hit_type_feature is accepted for call-site compatibility but unused here
    a = hit_particle_link
    b = daughters
    a_u = torch.unique(a)
    number_of_p = torch.zeros_like(a_u)
    for p, i in enumerate(a_u):  # p: position in a_u, i: particle id
        mask2 = a == i
        number_of_p[p] = torch.sum(torch.unique(b[mask2]) != -1)
    return number_of_p
def find_mask_no_energy(
hit_particle_link,
hit_type_a,
hit_energies,
y,
daughters,
predict=False,
is_Ks=False,
):
"""This function remove particles with tracks only and remove particles with low fractions
# Remove 2212 going to multiple particles without tracks for now
# remove particles below energy cut
# remove particles that decayed in the tracker
# remove particles with two tracks (due to bad tracking)
# remove particles with daughters for the moment
Args:
hit_particle_link (_type_): _description_
hit_type_a (_type_): _description_
hit_energies (_type_): _description_
y (_type_): _description_
Returns:
_type_: _description_
"""
number_of_daughters = get_number_of_daughters(
hit_type_a, hit_particle_link, daughters
)
list_p = np.unique(hit_particle_link)
list_remove = []
part_frac = torch.tensor(get_ratios(hit_energies, hit_particle_link, y))
number_of_hits = get_number_hits(hit_energies, hit_particle_link)
    if predict:
        energy_cut = 0.1
    else:
        energy_cut = 0.01
    filt1 = (torch.where(part_frac >= energy_cut)[0] + 1).long().tolist()
number_of_tracks = scatter_add(1 * (hit_type_a == 1), hit_particle_link.long())[1:]
    if not is_Ks:
for index, p in enumerate(list_p):
mask = hit_particle_link == p
hit_types = np.unique(hit_type_a[mask])
if predict:
if (
np.array_equal(hit_types, [0, 1])
or int(p) not in filt1
or (number_of_hits[index] < 2)
or (y.decayed_in_tracker[index] == 1)
or number_of_tracks[index] == 2
or number_of_daughters[index] > 1
):
list_remove.append(p)
else:
if (
np.array_equal(hit_types, [0, 1])
or int(p) not in filt1
or (number_of_hits[index] < 2)
or number_of_tracks[index] == 2
or number_of_daughters[index] > 1
):
list_remove.append(p)
    mask = torch.tensor(np.full((len(hit_particle_link)), False, dtype=bool))
    for p in list_remove:
        mask = mask | (hit_particle_link == p)
    mask_particles = np.full((len(list_p)), False, dtype=bool)
    for p in list_remove:
        mask_particles = mask_particles | (list_p == p)
return mask, mask_particles
class CachedIndexList:
def __init__(self, lst):
self.lst = lst
self.cache = {}
def index(self, value):
if value in self.cache:
return self.cache[value]
else:
idx = self.lst.index(value)
self.cache[value] = idx
return idx
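# CachedIndexList memoizes list.index lookups, which pays off when the same
# particle id is looked up once per hit:
#     >>> cil = CachedIndexList([5, 7, 9])
#     >>> cil.index(7)  # 1; subsequent calls hit the cache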
def find_cluster_id(hit_particle_link):
unique_list_particles = list(np.unique(hit_particle_link))
    if np.sum(np.array(unique_list_particles) == -1) > 0:
        # hits with link -1 are noise: real particles get ids 1..N, noise gets 0
        non_noise_idx = torch.where(hit_particle_link != -1)[0]
        noise_idx = torch.where(hit_particle_link == -1)[0]
        unique_list_particles1 = torch.unique(hit_particle_link)[1:]
        cluster_id_ = torch.searchsorted(
            unique_list_particles1, hit_particle_link[non_noise_idx], right=False
        )
        cluster_id_small = 1.0 * cluster_id_ + 1
        cluster_id = hit_particle_link.clone()
        cluster_id[non_noise_idx] = cluster_id_small
        cluster_id[noise_idx] = 0
else:
c_unique_list_particles = CachedIndexList(unique_list_particles)
cluster_id = map(
lambda x: c_unique_list_particles.index(x), hit_particle_link.tolist()
)
cluster_id = torch.Tensor(list(cluster_id)) + 1
return cluster_id, unique_list_particles
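# Hedged example of find_cluster_id (hypothetical values): noise hits (link -1)
# map to cluster 0 and real particles are numbered from 1 upward:
#     >>> link = torch.tensor([-1.0, 7.0, 7.0, 9.0])
#     >>> cluster_id, uniques = find_cluster_id(link)
#     >>> # cluster_id -> tensor([0., 1., 1., 2.])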
def scatter_count(input: torch.Tensor):
return scatter_add(torch.ones_like(input, dtype=torch.long), input.long())
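# scatter_count returns per-index occupancy, e.g. (hypothetical values):
#     >>> scatter_count(torch.tensor([0, 0, 2]))
#     >>> # tensor([2, 0, 1])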
def get_particle_features(unique_list_particles, output, prediction, connection_list):
unique_list_particles = torch.Tensor(unique_list_particles).to(torch.int64)
    if prediction:
        number_particle_features = 12 - 2
    else:
        number_particle_features = 9 - 2
    # rows 0-1 of pf_features are skipped in the slice below, hence the "- 2"
    if output["pf_features"].shape[0] == 18:
        number_particle_features += 8  # add vertex information
features_particles = torch.permute(
torch.tensor(
output["pf_features"][
2:number_particle_features, list(unique_list_particles)
]
),
(1, 0),
) #
# particle_coord are just features 10, 11, 12
    if features_particles.shape[1] == 16:  # config with part_pxyz and part_vertex_xyz
        particle_coord = features_particles[:, 10:13]
        vertex_coord = features_particles[:, 13:16]
        # NOTE: particle coordinates are deliberately left unnormalized here
else:
particle_coord = spherical_to_cartesian(
features_particles[:, 1],
features_particles[:, 0], # theta and phi are mixed!!!
features_particles[:, 2],
normalized=True,
)
vertex_coord = torch.zeros_like(particle_coord)
y_mass = features_particles[:, 3].view(-1).unsqueeze(1)
y_mom = features_particles[:, 2].view(-1).unsqueeze(1)
y_energy = torch.sqrt(y_mass**2 + y_mom**2)
y_pid = features_particles[:, 4].view(-1).unsqueeze(1)
if prediction:
y_data_graph = Particles_GT(
particle_coord,
y_energy,
y_mom,
y_mass,
y_pid,
features_particles[:, 5].view(-1).unsqueeze(1),
features_particles[:, 6].view(-1).unsqueeze(1),
unique_list_particles=unique_list_particles,
vertex=vertex_coord,
)
else:
y_data_graph = Particles_GT(
particle_coord,
y_energy,
y_mom,
y_mass,
y_pid,
unique_list_particles=unique_list_particles,
vertex=vertex_coord,
)
return y_data_graph
def modify_index_link_for_gamma_e(
hit_type_feature, hit_particle_link, daughters, output, number_part, is_Ks=False
):
"""Split all particles that have daughters, mostly for brems and conversions but also for protons and neutrons
Returns:
hit_particle_link: new link
hit_link_modified: bool for modified hits
"""
hit_link_modified = torch.zeros_like(hit_particle_link).to(hit_particle_link.device)
mask = hit_type_feature > 1
a = hit_particle_link[mask]
b = daughters[mask]
a_u = torch.unique(a)
number_of_p = torch.zeros_like(a_u)
connections_list = []
    for p, i in enumerate(a_u):  # p: position in a_u, i: particle id
        mask2 = a == i
        list_of_daughters = torch.unique(b[mask2])
        number_of_p[p] = len(list_of_daughters)
        if (number_of_p[p] > 1) and (torch.sum(list_of_daughters == i) > 0):
            connections_list.append([i, list_of_daughters])
pid_particles = torch.tensor(output["pf_features"][6, 0:number_part])
electron_photon_mask = (torch.abs(pid_particles[a_u.long()]) == 11) + (
pid_particles[a_u.long()] == 22
)
    # keep only electrons/photons that link to more than one daughter
    electron_photon_mask = electron_photon_mask & (number_of_p > 1)
if is_Ks:
index_change = a_u # [electron_photon_mask]
else:
index_change = a_u[electron_photon_mask]
for i in index_change:
mask_n = mask * (hit_particle_link == i)
hit_particle_link[mask_n] = daughters[mask_n]
hit_link_modified[mask_n] = 1
return hit_particle_link, hit_link_modified, connections_list
def get_hit_features(
output, number_hits, prediction, number_part, hit_chis, pos_pxpy, is_Ks=False
):
hit_particle_link = torch.tensor(output["pf_vectoronly"][0, 0:number_hits])
    if prediction:
        indx_daughters = 3
    else:
        indx_daughters = 1
    daughters = torch.tensor(output["pf_vectoronly"][indx_daughters, 0:number_hits])
if prediction:
pandora_cluster = torch.tensor(output["pf_vectoronly"][1, 0:number_hits])
pandora_pfo_link = torch.tensor(output["pf_vectoronly"][2, 0:number_hits])
if is_Ks:
pandora_mom = torch.permute(
torch.tensor(output["pf_points_pfo"][0:3, 0:number_hits]), (1, 0)
)
pandora_ref_point = torch.permute(
torch.tensor(output["pf_points_pfo"][3:6, 0:number_hits]), (1, 0)
)
if output["pf_points_pfo"].shape[0] > 6:
pandora_pid = torch.tensor(output["pf_points_pfo"][6, 0:number_hits])
            else:
                # Pandora PID not stored in this sample; fall back to zeros
                pandora_pid = torch.zeros(number_hits)
else:
pandora_mom = None
pandora_ref_point = None
pandora_pid = None
if is_Ks:
pandora_cluster_energy = torch.tensor(
output["pf_features"][9, 0:number_hits]
)
pfo_energy = torch.tensor(output["pf_features"][10, 0:number_hits])
chi_squared_tracks = torch.tensor(output["pf_features"][11, 0:number_hits])
elif hit_chis:
pandora_cluster_energy = torch.tensor(
output["pf_features"][-3, 0:number_hits]
)
pfo_energy = torch.tensor(output["pf_features"][-2, 0:number_hits])
chi_squared_tracks = torch.tensor(output["pf_features"][-1, 0:number_hits])
else:
pandora_cluster_energy = torch.tensor(
output["pf_features"][-2, 0:number_hits]
)
pfo_energy = torch.tensor(output["pf_features"][-1, 0:number_hits])
chi_squared_tracks = None
else:
pandora_cluster = None
pandora_pfo_link = None
pandora_cluster_energy = None
pfo_energy = None
chi_squared_tracks = None
pandora_mom = None
pandora_ref_point = None
pandora_pid = None
# hit type
hit_type_feature = torch.permute(
torch.tensor(output["pf_vectors"][:, 0:number_hits]), (1, 0)
)[:, 0].to(torch.int64)
(
hit_particle_link,
hit_link_modified,
connection_list,
) = modify_index_link_for_gamma_e(
hit_type_feature, hit_particle_link, daughters, output, number_part, is_Ks
)
cluster_id, unique_list_particles = find_cluster_id(hit_particle_link)
# position, e, p
pos_xyz_hits = torch.permute(
torch.tensor(output["pf_points"][0:3, 0:number_hits]), (1, 0)
)
pf_features_hits = torch.permute(
torch.tensor(output["pf_features"][0:2, 0:number_hits]), (1, 0)
) # removed theta, phi
p_hits = pf_features_hits[:, 0].unsqueeze(1)
p_hits[p_hits == -1] = 0 # correct p of Hcal hits to be 0
e_hits = pf_features_hits[:, 1].unsqueeze(1)
e_hits[e_hits == -1] = 0 # correct the energy of the tracks to be 0
if pos_pxpy:
pos_pxpypz = torch.permute(
torch.tensor(output["pf_points"][3:, 0:number_hits]), (1, 0)
)
else:
pos_pxpypz = pos_xyz_hits
# pos_pxpypz = pos_theta_phi
return (
pos_xyz_hits,
pos_pxpypz,
p_hits,
e_hits,
hit_particle_link,
pandora_cluster,
pandora_cluster_energy,
pfo_energy,
pandora_mom,
pandora_ref_point,
pandora_pid,
unique_list_particles,
cluster_id,
hit_type_feature,
pandora_pfo_link,
daughters,
hit_link_modified,
connection_list,
chi_squared_tracks,
)
def standardize_coordinates(coord_cart_hits):
if len(coord_cart_hits) == 0:
return coord_cart_hits, None
std_scaler = StandardScaler()
coord_cart_hits = std_scaler.fit_transform(coord_cart_hits)
return torch.tensor(coord_cart_hits).float(), std_scaler
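# Hedged usage sketch for standardize_coordinates (hypothetical shapes):
#     >>> coords, scaler = standardize_coordinates(torch.randn(100, 3))
#     >>> # coords has per-column zero mean / unit variance;
#     >>> # scaler.inverse_transform undoes the scaling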
def create_dif_interactions(i, j, pos, number_p):
x_interactions = pos
x_interactions = torch.reshape(x_interactions, [number_p, 1, 2])
x_interactions = x_interactions.repeat(1, number_p, 1)
xi = x_interactions[i, j, :]
xj = x_interactions[j, i, :]
x_interactions_m = xi - xj
return x_interactions_m
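# create_dif_interactions yields pairwise differences pos[i] - pos[j]; a hedged
# sketch with 3 points in 2-D (i, j assumed to be index grids):
#     >>> # i, j = torch.meshgrid(torch.arange(3), torch.arange(3), indexing="ij")
#     >>> # diffs = create_dif_interactions(i, j, pos, 3)  # pos: shape [3, 2]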
def spherical_to_cartesian(phi, theta, r, normalized=False):
if normalized:
r = torch.ones_like(phi)
x = r * torch.sin(theta) * torch.cos(phi)
y = r * torch.sin(theta) * torch.sin(phi)
z = r * torch.cos(theta)
return torch.cat((x.unsqueeze(1), y.unsqueeze(1), z.unsqueeze(1)), dim=1)
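# Sanity-check sketch: theta = pi/2, phi = 0 should land on the +x axis:
#     >>> spherical_to_cartesian(torch.tensor([0.0]), torch.tensor([torch.pi / 2]),
#     ...                        torch.tensor([1.0]))
#     >>> # tensor([[1., 0., 0.]]) up to floating-point rounding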
def calculate_distance_to_boundary(g):
    r = 2150  # barrel radius (same units as pos_hits_xyz; presumably mm)
    r_in_endcap = 2307  # endcap |z| plane (same units, assumed)
    mask_endcap = (torch.abs(g.ndata["pos_hits_xyz"][:, 2]) - r_in_endcap) > 0
    mask_barrel = ~mask_endcap
    weight = torch.ones_like(g.ndata["pos_hits_xyz"][:, 0])
    C = g.ndata["pos_hits_xyz"]
    A = torch.Tensor([0, 0, 1]).to(C.device)
    # P: each hit scaled along its ray so its transverse radius equals r
    # (|cross(z_hat, C)| is the transverse radius of C)
    P = (
        r
        / (torch.norm(torch.cross(A.view(1, -1), C, dim=-1), dim=1)).unsqueeze(1)
        * C
    )
    # P1: each hit scaled so it lies on the plane |z| = r_in_endcap
    P1 = torch.abs(r_in_endcap / g.ndata["pos_hits_xyz"][:, 2].unsqueeze(1)) * C
    weight[mask_barrel] = torch.norm(P - C, dim=1)[mask_barrel]
    weight[mask_endcap] = torch.norm(P1[mask_endcap] - C[mask_endcap], dim=1)
g.ndata["radial_distance"] = weight
weight_ = torch.exp(-(weight / 1000))
g.ndata["radial_distance_exp"] = weight_
return g
class EventCollection:
    def mask(self, mask):
        for k in self.__dict__:
            if getattr(self, k) is not None:
                if isinstance(getattr(self, k), list):
                    if getattr(self, k)[0] is not None:
                        setattr(self, k, getattr(self, k)[mask])
                elif not isinstance(getattr(self, k), dict):
                    setattr(self, k, getattr(self, k)[mask])
                else:
                    raise NotImplementedError("Need to implement correct indexing")
# TODO: for the mapping pfcands_idx to jet_idx
def copy(self):
obj = type(self).__new__(self.__class__)
obj.__dict__.update(self.__dict__)
return obj
def serialize(self):
# get all the self.init_attrs and concat them together. Also return batch_number
res = []
for attr in self.init_attrs:
if attr == "status" and not hasattr(self, attr):
continue
res.append(getattr(self, attr))
data = torch.stack(res).T
#data = torch.stack([getattr(self, attr) for attr in self.init_attrs]).T
assert data.shape[0] == self.batch_number.max().item()
return data, self.batch_number
def __getitem__(self, i):
data = {}
s, e = self.batch_number[i], self.batch_number[i + 1]
for attr in type(self).init_attrs:
if attr == "status" and not hasattr(self, attr):
continue
data[attr] = getattr(self, attr)[s:e]
return type(self)(**data)
@staticmethod
def deserialize(data_matrix, batch_number, cls):
data = {}
filt = None
for i, key in enumerate(cls.init_attrs):
if i >= data_matrix.shape[1]:
break # For some PFCands, 'status' is not populated
data[key] = data_matrix[:, i]
#if key == "pid" and pid_filter:
# filt = ~np.bool(np.abs(data[key]) >= 10000 + (np.abs(data[key]) >= 50 * np.abs(data[key]) <= 60))
return cls(**data, batch_number=batch_number)
def concat_event_collection(list_event_collection, nobatch=False):
c = list_event_collection[0]
list_of_attrs = c.init_attrs
#for k in c.__dict__:
# if getattr(c, k) is not None:
# if isinstance(getattr(c, k), torch.Tensor):
# list_of_attrs.append(k)
result = {}
for attr in list_of_attrs:
if hasattr(c, attr):
result[attr] = torch.cat([getattr(c, attr) for c in list_event_collection], dim=0)
if hasattr(c, "original_particle_mapping") and c.original_particle_mapping is not None:
result["original_particle_mapping"] = torch.cat([c.original_particle_mapping for c in list_event_collection], dim=0)
if not nobatch:
batch_number, to_add_idx = add_batch_number(list_event_collection, attr=list_of_attrs[0])
#if hasattr(c, "original_particle_mapping") and c.original_particle_mapping is not None:
# #filt = result["original_particle_mapping"] != -1
# #result["original_particle_mapping"][filt] += to_add_idx[filt]
return type(c)(**result, batch_number=batch_number)
else:
return type(c)(**result)
def concat_events(list_events):
attrs = list_events[0].init_attrs
result = {}
for attr in attrs:
result[attr] = concat_event_collection([getattr(e, attr) for e in list_events])
# assert result[attr].batch_number.max() == len(list_events)# sometimes the event is empty (e.g. no found jets)
return Event(**result, n_events=len(list_events))
def renumber_clusters(tensor):
    # map the unique values in `tensor` to contiguous indices 0..n-1;
    # cast to long so any integer/float dtype can be used for indexing
    unique = tensor.unique()
    mapping = torch.zeros(unique.max().long().item() + 1, dtype=torch.long)
    for i, u in enumerate(unique):
        mapping[u.long()] = i
    return mapping[tensor.long()]
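# Example: gaps in the cluster numbering are closed,
#     >>> renumber_clusters(torch.tensor([0, 0, 3, 5]))
#     >>> # tensor([0, 0, 1, 2])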
class TensorCollection:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def to(self, device):
# Move all tensors to device
for k, v in self.__dict__.items():
if torch.is_tensor(v):
setattr(self, k, v.to(device))
return self
def dict_rep(self):
d = {}
for k, v in self.__dict__.items():
if torch.is_tensor(v):
d[k] = v
return d
#def __getitem__(self, i):
# return TensorCollection(**{k: v[i] for k, v in self.__dict__.items()})
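# Minimal usage sketch: TensorCollection is a lightweight bag of named tensors,
#     >>> tc = TensorCollection(labels=torch.tensor([1, 0]), score=torch.tensor([0.9, 0.1]))
#     >>> tc.dict_rep()
#     >>> # {'labels': tensor([1, 0]), 'score': tensor([0.9000, 0.1000])}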
def get_corrected_batch(event_batch, cluster_idx, test):
# return a batch with fake nodes in it (as .fake_nodes_idx property) and cluster_idx should be set to -1 for the nodes that don't belong anywhere
# cluster_idx should be a tensor of the same length as the input vectors
clusters = torch.where(torch.tensor(cluster_idx) != -1)[0]
new_batch_idx = torch.tensor(cluster_idx[clusters])
# for each cluster, add a fake node that has zeros for vectors, scalars and pt
batch_idx_fake_nodes = torch.sort(new_batch_idx.unique())[0]
vectors_fake_nodes = torch.zeros(len(batch_idx_fake_nodes), event_batch.input_vectors.shape[1])
vectors_fake_nodes = vectors_fake_nodes.to(event_batch.input_vectors.device)
scalars_fake_nodes = torch.zeros(len(batch_idx_fake_nodes), event_batch.input_scalars.shape[1])
scalars_fake_nodes = scalars_fake_nodes.to(event_batch.input_scalars.device)
pt_fake_nodes = torch.zeros(len(batch_idx_fake_nodes))
pt_fake_nodes = pt_fake_nodes.to(event_batch.pt.device)
#event_batch.input_vectors[clusters]
#event_batch.input_scalars[clusters]
#event_batch.pt[clusters]
#
input_vectors = torch.cat([event_batch.input_vectors[clusters], vectors_fake_nodes], dim=0)
input_scalars = torch.cat([event_batch.input_scalars[clusters], scalars_fake_nodes], dim=0)
pt = torch.cat([event_batch.pt[clusters], pt_fake_nodes], dim=0)
batch_idx = torch.cat([new_batch_idx, batch_idx_fake_nodes], dim=0)
batch_sort_idx = torch.argsort(batch_idx) # the models need batch idx in ascending order in order to correctly construct the attention mask
#return EventBatch(
# input_vectors=input_vectors[batch_sort_idx],
# input_scalars=input_scalars[batch_sort_idx],
# pt=pt[batch_sort_idx],
# batch_idx=batch_idx[batch_sort_idx],
# fake_nodes_idx=batch_idx_fake_nodes + len(new_batch_idx),
#)
    # The commented-out return above would include the fake nodes; the active
    # return below deliberately drops them.
    return EventBatch(
        input_vectors=event_batch.input_vectors[clusters],
        input_scalars=event_batch.input_scalars[clusters],
        pt=event_batch.pt[clusters],
        batch_idx=new_batch_idx,
        renumber=not test,  # EventBatch takes `renumber`, not `renumber_clusters`
    )
def get_batch(event, batch_config, y, test=False, external_batch_filter=None):
    # Returns an EventBatch with the configured vectors/scalars, plus the filtered labels.
    # If test=True, all events are kept, i.e. events without signal are not filtered out.
pfcands = event.pfcands
if batch_config.get("parton_level", False):
pfcands = event.final_parton_level_particles
if batch_config.get("gen_level", False):
pfcands = event.final_gen_particles
batch_idx_pfcands = torch.zeros(len(pfcands)).long()
#batch_idx_special_pfcands = torch.zeros(len(event.special_pfcands)).long()
for i in range(len(pfcands.batch_number) - 1):
batch_idx_pfcands[pfcands.batch_number[i]:pfcands.batch_number[i+1]] = i
batch_filter = []
if batch_config.get("quark_dist_loss", False):
lbl = y.labels
elif batch_config.get("obj_score", False):
lbl = y.labels
dq_coords = y.dq_coords
dq_coords_batch_idx = y.dq_coords_batch_idx
else:
lbl = y
    if not (test or batch_config.get("quark_dist_loss", False)):  # don't filter for quark-distance loss
for i in batch_idx_pfcands.unique().tolist():
if (lbl[batch_idx_pfcands == i] == -1).all():
batch_filter.append(i)
#for i in range(len(event.special_pfcands.batch_number) - 1):
# batch_idx_special_pfcands[event.special_pfcands.batch_number[i]:event.special_pfcands.batch_number[i+1]] = i
#batch_idx = torch.cat([batch_idx_pfcands, batch_idx_special_pfcands])
batch_idx = batch_idx_pfcands
batch_idx = batch_idx.to(pfcands.pt.device)
if batch_config.get("use_p_xyz", False):
#batch_vectors = torch.cat([event.pfcands.pxyz, event.special_pfcands.pxyz], dim=0)
batch_vectors = pfcands.pxyz
elif batch_config.get("use_four_momenta", False):
batch_vectors = torch.cat([pfcands.E.unsqueeze(-1), pfcands.pxyz], dim=1)
assert batch_vectors.shape[0] == pfcands.E.shape[0]
else:
raise NotImplementedError
chg = pfcands.charge.unsqueeze(1)
if batch_config.get("no_pid", False):
batch_scalars_pfcands = chg
else:
pids = batch_config.get("pids", [11, 13, 22, 130, 211, 0, 1, 2, 3]) # 0, 1, 2, 3 are the special PFcands
# onehot encode pids of event.pfcands.pid
pids_onehot = torch.zeros(len(pfcands), len(pids))
        for i in pfcands.pid:
            if abs(i).item() not in pids:
                raise ValueError(f"PID {i} not in configured pid list {pids}")
for i, pid in enumerate(pids):
pids_onehot[:, i] = (pfcands.pid.abs() == pid).float()
assert (pids_onehot.sum(dim=1) == 1).all()
batch_scalars_pfcands = torch.cat([chg, pids_onehot], dim=1)
#if batch_config.get("use_p_xyz", False):
# # also add pt as a scalar
batch_scalars_pfcands = torch.cat([batch_scalars_pfcands, pfcands.pt.unsqueeze(1), pfcands.E.unsqueeze(1)], dim=1)
#pids_onehot_special_pfcands = torch.zeros(len(event.special_pfcands), len(pids))
#for i, pid in enumerate(pids):
# pids_onehot_special_pfcands[:, i] = (event.special_pfcands.pid.abs() == pid).float()
#assert (pids_onehot_special_pfcands.sum(dim=1) == 1).all()
#batch_scalars_special_pfcands =event.special_pfcands.charge.unsqueeze(1) #torch.cat([event.special_pfcands.charge.unsqueeze(1), pids_onehot_special_pfcands], dim=1)
batch_scalars = batch_scalars_pfcands # torch.cat([batch_scalars_pfcands, batch_scalars_special_pfcands], dim=0)
if batch_idx.max() != event.n_events - 1:
print("Error!!")
print("Batch idx", batch_idx.max(), batch_idx.tolist())
print("N events", event.n_events)
print("Batch number:", pfcands.batch_number)
#assert batch_idx.max() == event.n_events - 1
filt = ~torch.isin(batch_idx_pfcands, torch.tensor(batch_filter))
if batch_config.get("obj_score", False):
filt_dq = ~torch.isin(dq_coords_batch_idx, torch.tensor(batch_filter))
dropped_batches = batch_idx[~filt].unique()
#if (~filt).sum() > 0:
# #print("Found events with no signal!!! Dropping it in training", (~filt).sum() / len(filt), batch_filter)
# #print("Renumbered", renumber_clusters(batch_idx[filt]).unique())
# #print("Original", batch_idx[filt].unique())
# #print("ALL", batch_idx.unique())
if batch_config.get("quark_dist_loss", False):
y_filt = y
elif batch_config.get("obj_score", False):
#print(dq_coords[0].shape, filt_dq.shape, lbl.shape, filt.shape, dq_coords[1].shape)
#print(dq_coords_batch_idx[filt_dq])
y_filt = TensorCollection(labels=lbl[filt], dq_eta=dq_coords[0][filt_dq], dq_phi=dq_coords[1][filt_dq],
dq_coords_batch_idx=renumber_clusters(dq_coords_batch_idx[filt_dq].int()))
else:
y_filt = y[filt]
#print("Filtering y!" , len(y[filt]), len(batch_vectors[filt]))
print("------- Dropped batches:", dropped_batches)
    if pfcands.original_particle_mapping is not None:
        opm = pfcands.original_particle_mapping[filt]
    else:
        opm = None
return EventBatch(
input_vectors=batch_vectors[filt],
input_scalars=batch_scalars[filt],
batch_idx=batch_idx[filt],
pt=pfcands.pt[filt],
filter=filt,
dropped_batches=dropped_batches,
renumber=not test,
original_particle_mapping=opm
), y_filt
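# batch_config keys consumed above (all optional; every flag defaults to False
# and "pids" to the list shown). A hedged call sketch, assuming a populated Event:
#     >>> batch_config = {"use_four_momenta": True, "no_pid": False}
#     >>> # batch, y_filt = get_batch(event, batch_config, y)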
def to_tensor(item):
if isinstance(item, torch.Tensor):
# if it's float, change to double
if item.dtype == torch.float32:
return item.double()
return item
item = torch.tensor(item)
if item.dtype == torch.float32:
return item.double()
return item
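# to_tensor promotes float32 to float64 and leaves other dtypes untouched:
#     >>> to_tensor([1.0, 2.0]).dtype   # torch.float64
#     >>> to_tensor(torch.tensor([1, 2])).dtype  # torch.int64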
class EventPFCands(EventCollection):
init_attrs = ["pt", "eta", "phi", "mass", "charge", "pid", "pf_cand_jet_idx", "status"]
def __init__(
self,
pt,
eta,
phi,
mass,
charge,
pid,
jet_idx=None,
pfcands_idx=None,
batch_number=None,
offline=False,
pf_cand_jet_idx=None, # Optional: provide either this or pfcands_idx & jet_idx
status=None, # optional
        pid_filter=True,  # if True, remove invisible GenParticles (abs(pid) > 10000, or 50 <= abs(pid) <= 60)
original_particle_mapping=None
):
#print("Jet idx:", jet_idx)
#print("PFCands_idx:", pfcands_idx)
self.pt = to_tensor(pt)
self.eta = to_tensor(eta)
self.theta = 2 * torch.atan(torch.exp(-self.eta))
self.p = self.pt / torch.sin(self.theta)
self.phi = to_tensor(phi)
self.pxyz = torch.stack(
(self.p * torch.cos(self.phi) * torch.sin(self.theta),
self.p * torch.sin(self.phi) * torch.sin(self.theta),
self.p * torch.cos(self.theta)),
dim=1
)
#assert (torch.abs(torch.norm(self.pxyz, dim=1) - self.p) < 0.1).all(), (torch.abs(torch.norm(self.pxyz, dim=1) - self.p).max())
if not (torch.abs(torch.norm(self.pxyz, dim=1) - self.p) < 0.05).all():
print("!!!!!", (torch.abs(torch.norm(self.pxyz, dim=1) - self.p)).max())
# argmax
am = torch.argmax(torch.abs(torch.norm(self.pxyz, dim=1) - self.p))
print("pt", self.pt[am], "eta", self.eta[am], "phi", self.phi[am], "mass", mass[am], "batch_number", batch_number)
#print("pt", self.pt, "eta", self.eta, "phi", self.phi, "mass", mass, "batch_number", batch_number)
self.mass = to_tensor(mass)
self.E = torch.sqrt(self.mass ** 2 + self.p ** 2)
self.charge = to_tensor(charge)
self.pid = to_tensor(pid)
if original_particle_mapping is not None:
self.original_particle_mapping = to_tensor(original_particle_mapping)
else:
self.original_particle_mapping = original_particle_mapping
if status is not None:
self.status = to_tensor(status)
#self.init_attrs.append("status")
if pf_cand_jet_idx is not None:
self.pf_cand_jet_idx = to_tensor(pf_cand_jet_idx)
        else:
            # build the mapping from pfcands_idx / jet_idx, which are assumed
            # to be provided whenever pf_cand_jet_idx is None
            self.pf_cand_jet_idx = torch.ones(len(self.pt)).int() * -1
            for i, pfcand_idx in enumerate(pfcands_idx):
                if int(pfcand_idx) >= len(self.pt):
                    print("Out of bounds")
                    if not offline:
                        raise Exception("pfcand index out of bounds")
                else:
                    self.pf_cand_jet_idx[int(pfcand_idx)] = int(jet_idx[i])
if batch_number is not None:
self.batch_number = batch_number
def __len__(self):
return len(self.pt)
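# Hedged construction sketch (hypothetical values) for a two-candidate event:
#     >>> cands = EventPFCands(pt=[10.0, 20.0], eta=[0.1, -0.2], phi=[0.0, 1.5],
#     ...                      mass=[0.139, 0.139], charge=[1, -1], pid=[211, -211],
#     ...                      pf_cand_jet_idx=[0, 0])
#     >>> len(cands)  # 2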
class EventMetadataAndMET(EventCollection):
# Extra info belonging to the event: MET, trigger info etc.
init_attrs = ["pt", "phi", "scouting_trig", "offline_trig", "veto_trig"]
def __init__(self, pt, phi, scouting_trig, offline_trig, veto_trig, batch_number=None):
self.pt = to_tensor(pt)
self.phi = to_tensor(phi)
self.scouting_trig = to_tensor(scouting_trig)
self.offline_trig = to_tensor(offline_trig)
self.veto_trig = to_tensor(veto_trig)
if batch_number is not None:
self.batch_number = to_tensor(batch_number)
def __len__(self):
return len(self.pt)
class EventJets(EventCollection):
init_attrs = ["pt", "eta", "phi", "mass"]
def __init__(
self,
pt,
eta,
phi,
mass,
area=None,
obj_score=None,
target_obj_score=None,
batch_number=None
):
self.pt = to_tensor(pt)
self.eta = to_tensor(eta)
self.theta = 2 * torch.atan(torch.exp(-self.eta))
        self.p = self.pt / torch.sin(self.theta)  # use the tensor-converted pt for a consistent dtype
self.phi = to_tensor(phi)
self.pxyz = torch.stack(
(self.p * torch.cos(self.phi) * torch.sin(self.theta),
self.p * torch.sin(self.phi) * torch.sin(self.theta),
self.p * torch.cos(self.theta)),
dim=1
)
if obj_score is not None:
self.obj_score = to_tensor(obj_score)
if target_obj_score is not None:
self.target_obj_score = to_tensor(target_obj_score)
tst = torch.abs(torch.norm(self.pxyz, dim=1) - self.p)
#if not (tst[~torch.isnan(tst)] < 1e-2).all():
# print("!!!!!", (torch.abs(torch.norm(self.pxyz, dim=1) - self.p)).max())
# print("pt", self.pt, "eta", self.eta, "phi", self.phi, "mass", mass, "batch_number", batch_number)
# assert False
self.mass = to_tensor(mass)
self.area = area
self.E = torch.sqrt(self.mass ** 2 + self.p ** 2)
if self.area is not None:
self.area = to_tensor(self.area)
if batch_number is not None:
self.batch_number = to_tensor(batch_number)
def __len__(self):
return len(self.pt)
class Particles_GT:
def __init__(
self,
coordinates,
energy,
momentum,
mass,
pid,
decayed_in_calo=None,
decayed_in_tracker=None,
batch_number=None,
unique_list_particles=None,
energy_corrected=None,
vertex=None,
):
self.coord = coordinates
self.E = energy
        # clone so that the in-place updates in calculate_corrected_E do not
        # also modify self.E through aliasing
        self.E_corrected = energy.clone() if torch.is_tensor(energy) else energy
        if energy_corrected is not None:
            self.E_corrected = energy_corrected
assert len(coordinates) == len(energy)
        self.m = momentum  # momentum magnitude; the mass is stored separately in self.mass
self.mass = mass
self.pid = pid
self.vertex = vertex
if unique_list_particles is not None:
self.unique_list_particles = unique_list_particles
if decayed_in_calo is not None:
self.decayed_in_calo = decayed_in_calo
if decayed_in_tracker is not None:
self.decayed_in_tracker = decayed_in_tracker
if batch_number is not None:
self.batch_number = batch_number
def __len__(self):
return len(self.E)
    def mask(self, mask):
        for k in self.__dict__:
            if getattr(self, k) is not None:
                if isinstance(getattr(self, k), list):
                    if getattr(self, k)[0] is not None:
                        setattr(self, k, getattr(self, k)[mask])
                else:
                    setattr(self, k, getattr(self, k)[mask])
def copy(self):
obj = type(self).__new__(self.__class__)
obj.__dict__.update(self.__dict__)
return obj
    def calculate_corrected_E(self, g, connections_list):
        for element in connections_list:
            # only correct parents that have an associated track
            parent_particle = element[0]
            mask_i = g.ndata["particle_number_nomap"] == parent_particle
            track_number = torch.sum(g.ndata["hit_type"][mask_i] == 1)
            if track_number > 0:
                # find the parent's index in the particle list
                index_parent = torch.argmax(
                    1 * (self.unique_list_particles == parent_particle)
                )
                energy_daughters = 0
                for daughter in element[1]:
                    if daughter != parent_particle:
                        if torch.sum(self.unique_list_particles == daughter) > 0:
                            index_daughters = torch.argmax(
                                1 * (self.unique_list_particles == daughter)
                            )
                            energy_daughters = (
                                self.E[index_daughters] + energy_daughters
                            )
                self.E_corrected[index_parent] = (
                    self.E_corrected[index_parent] - energy_daughters
                )
                self.coord[index_parent] *= (
                    1 - energy_daughters / torch.norm(self.coord[index_parent])
                )
def concatenate_Particles_GT(list_of_Particles_GT):
list_coord = [p[1].coord for p in list_of_Particles_GT]
list_vertex = [p[1].vertex for p in list_of_Particles_GT]
list_coord = torch.cat(list_coord, dim=0)
list_E = [p[1].E for p in list_of_Particles_GT]
list_E = torch.cat(list_E, dim=0)
list_E_corr = [p[1].E_corrected for p in list_of_Particles_GT]
list_E_corr = torch.cat(list_E_corr, dim=0)
list_m = [p[1].m for p in list_of_Particles_GT]
list_m = torch.cat(list_m, dim=0)
list_mass = [p[1].mass for p in list_of_Particles_GT]
list_mass = torch.cat(list_mass, dim=0)
list_pid = [p[1].pid for p in list_of_Particles_GT]
list_pid = torch.cat(list_pid, dim=0)
    if list_vertex[0] is not None:
        list_vertex = torch.cat(list_vertex, dim=0)
    else:
        list_vertex = None
    # the elements of list_of_Particles_GT are indexed as p[1] everywhere above,
    # so the attribute check must also look at the [1] entry
    if hasattr(list_of_Particles_GT[0][1], "decayed_in_calo"):
        list_dec_calo = [p[1].decayed_in_calo for p in list_of_Particles_GT]
        list_dec_track = [p[1].decayed_in_tracker for p in list_of_Particles_GT]
        list_dec_calo = torch.cat(list_dec_calo, dim=0)
        list_dec_track = torch.cat(list_dec_track, dim=0)
    else:
        list_dec_calo = None
        list_dec_track = None
    # add_batch_number expects an attribute name and returns (batch_number, offsets);
    # "E" is used here since every Particles_GT has one energy entry per particle
    batch_number, _ = add_batch_number([p[1] for p in list_of_Particles_GT], attr="E")
return Particles_GT(
list_coord,
list_E,
list_m,
list_mass,
list_pid,
list_dec_calo,
list_dec_track,
batch_number,
energy_corrected=list_E_corr,
vertex=list_vertex,
)
def add_batch_number(list_event_collections, attr):
list_y = []
list_y_to_add = [] # Computes a list of numbers to add to the original_particle_idx or similar fields
idx = 0
list_y.append(idx)
for i, el in enumerate(list_event_collections):
num_in_batch = el.__dict__[attr].shape[0]
list_y.append(idx + num_in_batch)
list_y_to_add += [idx] * num_in_batch
idx += num_in_batch
list_y = torch.tensor(list_y)
return list_y, torch.tensor(list_y_to_add)
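# Sketch of the returned offsets, for two hypothetical collections with 2 and 3
# rows in the chosen attribute:
#     >>> # add_batch_number([c1, c2], attr="pt")
#     >>> # -> (tensor([0, 2, 5]), tensor([0, 0, 2, 2, 2]))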
def create_noise_label(hit_energies, hit_particle_link, y, cluster_id):
unique_p_numbers = torch.unique(cluster_id)
number_of_hits = get_number_hits(hit_energies, cluster_id)
e_reco = get_e_reco(hit_energies, cluster_id)
    mask_hits = to_tensor(number_of_hits) < 6
    mask_p = e_reco < 0.10
    mask_all = mask_hits.view(-1) | mask_p.view(-1)
    list_remove = unique_p_numbers[mask_all]
    mask = to_tensor(np.full((len(cluster_id)), False, dtype=bool))
    for p in list_remove:
        mask = mask | (cluster_id == p)
    list_p = unique_p_numbers
    mask_particles = to_tensor(np.full((len(list_p)), False, dtype=bool))
    for p in list_remove:
        mask_particles = mask_particles | (list_p == p)
return mask.to(bool), ~mask_particles.to(bool)
class EventBatch:
def __init__(self, input_vectors, input_scalars, batch_idx, pt, original_particle_mapping=None, filter=None, dropped_batches=None, fake_nodes_idx=None, batch_idx_events=None, renumber=False):
self.input_vectors = input_vectors
self.input_scalars = input_scalars
        self.batch_idx = batch_idx
if renumber:
self.batch_idx = renumber_clusters(batch_idx)
self.pt = pt
self.filter = filter
self.dropped_batches = dropped_batches
self.original_particle_mapping = original_particle_mapping
if fake_nodes_idx is not None:
self.fake_nodes_idx = fake_nodes_idx
if batch_idx_events is not None:
            self.batch_idx_events = batch_idx_events
def to(self, device):
self.input_vectors = self.input_vectors.to(device)
self.input_scalars = self.input_scalars.to(device)
self.batch_idx = self.batch_idx.to(device)
self.pt = self.pt.to(device)
if self.filter is not None:
self.filter = self.filter.to(device)
if self.original_particle_mapping is not None:
self.original_particle_mapping = self.original_particle_mapping.to(device)
return self
def cpu(self):
return self.to(torch.device("cpu"))
class Event:
evt_collections = {"jets": EventJets, "genjets": EventJets, "pfcands": EventPFCands,
"offline_pfcands": EventPFCands, "MET": EventMetadataAndMET, "fatjets": EventJets,
"special_pfcands": EventPFCands, "matrix_element_gen_particles": EventPFCands,
"model_jets": EventJets, "final_gen_particles": EventPFCands,
"final_parton_level_particles": EventPFCands}
def __init__(self, jets=None, genjets=None, pfcands=None, offline_pfcands=None, MET=None, fatjets=None,
special_pfcands=None, matrix_element_gen_particles=None, model_jets=None, model_jets_unfiltered=None,
n_events=1, fastjet_jets=None, final_gen_particles=None, final_parton_level_particles=None):
self.jets = jets
self.genjets = genjets
self.pfcands = pfcands
self.offline_pfcands = offline_pfcands
self.MET = MET
self.fatjets = fatjets
self.fastjet_jets = fastjet_jets
self.special_pfcands = special_pfcands
self.matrix_element_gen_particles = matrix_element_gen_particles
self.model_jets = model_jets
self.model_jets_unfiltered = model_jets_unfiltered
self.init_attrs = []
self.n_events = n_events
self.final_gen_particles = final_gen_particles
self.final_parton_level_particles = final_parton_level_particles
        for attr in [
            "jets", "genjets", "pfcands", "offline_pfcands", "MET", "fatjets",
            "special_pfcands", "matrix_element_gen_particles", "model_jets",
            "model_jets_unfiltered", "final_gen_particles", "final_parton_level_particles",
        ]:
            if getattr(self, attr) is not None:
                self.init_attrs.append(attr)
#if fastjet_jets is not None:
# self.init_attrs.append("fastjet_jets")
''' @staticmethod
def deserialize(result, result_metadata, event_idx=None):
# 'result' arrays can be mmap-ed.
# If event_idx is not None and is set to a list, only the selected event_idx will be returned.
n_events = result_metadata["n_events"]
attrs = result.keys()
if event_idx is None:
event_idx = to_tensor(list(range(n_events)))
else:
event_idx = to_tensor(event_idx)
assert (event_idx < n_events).all()
return Event(**{key: result[key][torch.isin(result_metadata[key + "_batch_idx"], event_idx)] for key in attrs}, n_events=n_events)
'''
def __len__(self):
return self.n_events
def serialize(self):
result = {}
result_metadata = {"n_events": self.n_events, "attrs": self.init_attrs}
for key in self.init_attrs:
s = getattr(self, key).serialize()
result[key] = s[0]
result_metadata[key + "_batch_idx"] = s[1]
return result, result_metadata
def __getitem__(self, i):
dic = {}
for key in self.init_attrs:
#s, e = getattr(self, key).batch_number[i], getattr(self, key).batch_number[i + 1]
dic[key] = getattr(self, key)[i]
return Event(**dic, n_events=1)
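# Hedged end-to-end sketch (assumes collections built elsewhere, e.g. the
# EventPFCands example above, each carrying a batch_number):
#     >>> # event = Event(pfcands=cands, n_events=1)
#     >>> # result, meta = event.serialize()
#     >>> # first = event[0]  # slices every registered collection to event 0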