import numpy as np
import torch

# from torch_scatter import scatter_add, scatter_sum, scatter_mean
from src.dataset.functions_data import (
    get_ratios,
    find_mask_no_energy,
    find_cluster_id,
    get_particle_features,
    get_hit_features,
    calculate_distance_to_boundary,
    concatenate_Particles_GT,
    create_noise_label,
    EventJets,
    EventPFCands,
    EventCollection,
    Event,
    EventMetadataAndMET,
    concat_event_collection,
)

# NOTE(review): `dgl` is referenced by remove_hittype0, create_graph and
# graph_batch_func but is not imported in this file, so those functions raise
# NameError unless `dgl` is made available at module level. Confirm whether
# the graph-building path is still in use and, if so, restore `import dgl`.


def create_inputs_from_table(
    output, hits_only, prediction=False, hit_chis=False, pos_pxpy=False, is_Ks=False
):
    """Collect the per-hit and per-particle inputs used to build a graph.

    Args:
        output: event record from the ROOT reader; ``output["pf_mask"]`` rows
            0 and 1 are summed to obtain the number of hits and particles.
        hits_only (bool): if True, hits of type 0 and 1 (the track hits) are
            dropped and the one-hot hit type has 2 classes (``hit_type - 2``);
            otherwise only hits of type 10 are dropped and 5 classes are used.
        prediction (bool, optional): eval mode; the Pandora per-hit tensors
            are then filtered with the same hit mask as the other features.
        hit_chis (bool, optional): if True the track chi-squared tensor is
            included in the result, otherwise ``None`` takes its place.
        pos_pxpy (bool, optional): forwarded to ``get_hit_features``.
        is_Ks (bool, optional): forwarded to the helpers; also skips the
            "no energy" particle masking and, together with ``prediction``,
            adds the per-hit ``daughters`` tensor to the result.

    Returns:
        list: ``[None]`` when any particle id exceeds 20000 (event is
        rejected). Otherwise a list whose first element is the particle
        ground truth, followed by the filtered hit tensors, the chi-squared
        entry (or ``None``), the one-hot hit type and the connection list.
        The list has 20 elements when ``prediction and is_Ks`` (it then
        contains ``daughters``) and 19 otherwise.
    """
    number_hits = np.int32(np.sum(output["pf_mask"][0]))
    number_part = np.int32(np.sum(output["pf_mask"][1]))
    (
        pos_xyz_hits,
        pos_pxpypz,
        p_hits,
        e_hits,
        hit_particle_link,
        pandora_cluster,
        pandora_cluster_energy,
        pfo_energy,
        pandora_mom,
        pandora_ref_point,
        pandora_pid,
        unique_list_particles,
        cluster_id,
        hit_type_feature,
        pandora_pfo_link,
        daughters,
        hit_link_modified,
        connection_list,
        chi_squared_tracks,
    ) = get_hit_features(
        output,
        number_hits,
        prediction,
        number_part,
        hit_chis=hit_chis,
        pos_pxpy=pos_pxpy,
        is_Ks=is_Ks,
    )

    # Events containing particle ids above 20000 are treated as empty graphs.
    if torch.sum(torch.Tensor(unique_list_particles) > 20000) > 0:
        return [None]

    # features particles
    y_data_graph = get_particle_features(
        unique_list_particles, output, prediction, connection_list
    )
    assert len(y_data_graph) == len(unique_list_particles)

    if not is_Ks:
        # remove particles that have no energy, no hits or only track hits
        mask_hits, mask_particles = find_mask_no_energy(
            cluster_id,
            hit_type_feature,
            e_hits,
            y_data_graph,
            daughters,
            prediction,
            is_Ks=is_Ks,
        )
        # create mapping from links to number of particles in the event
        cluster_id, unique_list_particles = find_cluster_id(
            hit_particle_link[~mask_hits]
        )
        y_data_graph.mask(~mask_particles)
    else:
        mask_hits = torch.zeros_like(e_hits).bool().view(-1)

    keep = ~mask_hits
    with_daughters = prediction and is_Ks

    # The Pandora block is the only part that differs between the modes.
    if with_daughters:
        pandora_fields = [
            pandora_cluster[keep],
            pandora_cluster_energy[keep],
            pandora_mom[keep],
            pandora_ref_point[keep],
            pandora_pid[keep],
            pfo_energy[keep],
            pandora_pfo_link[keep],
        ]
    elif prediction:
        pandora_fields = [
            pandora_cluster[keep],
            pandora_cluster_energy[keep],
            pandora_mom,
            pandora_ref_point,
            pandora_pid,
            pfo_energy[keep],
            pandora_pfo_link[keep],
        ]
    else:
        pandora_fields = [
            pandora_cluster,
            pandora_cluster_energy,
            pandora_mom,
            pandora_ref_point,
            pandora_pid,
            pfo_energy,
            pandora_pfo_link,
        ]

    result = [
        y_data_graph,
        p_hits[keep],
        e_hits[keep],
        cluster_id,
        hit_particle_link[keep],
        pos_xyz_hits[keep],
        pos_pxpypz[keep],
        *pandora_fields,
        hit_type_feature[keep],
        hit_link_modified[keep],
    ]
    if with_daughters:
        result.append(daughters[keep])
    result.append(chi_squared_tracks[keep] if hit_chis else None)

    hit_type = hit_type_feature[keep]
    if hits_only:
        # if hits only remove tracks (hit types 0 and 1)
        hit_mask = ~((hit_type == 0) | (hit_type == 1))
        one_hot_offset, one_hot_classes = 2, 2
    else:
        # if we want the tracks keep only 1 track hit per charged particle.
        hit_mask = ~(hit_type == 10)
        one_hot_offset, one_hot_classes = 0, 5

    # Element 0 is the particle ground truth and is not indexed per hit.
    result[1:] = [
        entry if entry is None else entry[hit_mask] for entry in result[1:]
    ]
    hit_type_one_hot = torch.nn.functional.one_hot(
        hit_type[hit_mask] - one_hot_offset, num_classes=one_hot_classes
    )
    result.append(hit_type_one_hot)
    result.append(connection_list)
    return result


def remove_hittype0(graph):
    """Return ``graph`` with every node whose ``hit_type`` equals 0 removed."""
    filt = graph.ndata["hit_type"] == 0
    return dgl.remove_nodes(graph, torch.where(filt)[0])


def store_track_at_vertex_at_track_at_calo(graph):
    """Move the momentum of type-0 track nodes onto the type-1 track nodes.

    To make the graph compatible with clustering, the hit-type-0 nodes are
    removed; their ``pos_pxpypz`` is first stored on the hit-type-1 nodes as
    ``pos_pxpypz_at_vertex`` (zeros on every other node).

    Raises:
        AssertionError: if the particle numbers of the type-1 and type-0
            nodes do not match element-wise.
    """
    hit_type = graph.ndata["hit_type"]
    tracks_at_calo = hit_type == 1
    tracks_at_vertex = hit_type == 0
    part = graph.ndata["particle_number"].long()
    assert (part[tracks_at_calo] == part[tracks_at_vertex]).all()

    # Build the tensor locally and assign it once, rather than writing
    # through the object returned by ``graph.ndata[...]``.
    momenta = graph.ndata["pos_pxpypz"]
    at_vertex = torch.zeros_like(momenta)
    at_vertex[tracks_at_calo] = momenta[tracks_at_vertex]
    graph.ndata["pos_pxpypz_at_vertex"] = at_vertex
    return remove_hittype0(graph)


def _build_gen_particle_collections(genp):
    """Split a Delphes gen-particle table into three ``EventPFCands``.

    ``genp`` has one particle per row. Column 6 is read as the status code,
    column 0 as eta and column 2 as pt. Returns the status-23 particles, the
    status 51..59 particles within ``|eta| < 2.4`` and ``pt > 0.5``, and the
    status-1 particles within the same acceptance, in that order.
    """
    status = genp[:, 6]
    eta = genp[:, 0]
    pt = genp[:, 2]

    def _select(mask):
        # NOTE(review): charge is taken as sign(column 4) and pid as column 5;
        # confirm against the GenParticles column layout.
        return EventPFCands(
            genp[mask, 2],
            genp[mask, 0],
            genp[mask, 1],
            genp[mask, 3],
            np.sign(genp[mask, 4]),
            genp[mask, 5],
            pf_cand_jet_idx=-1 * np.ones_like(genp[mask, 0]),
        )

    filter_dq = status == 23
    filter_partons = (
        (status >= 51) & (status <= 59) & (np.abs(eta) < 2.4) & (pt > 0.5)
    )
    filter_final = (status == 1) & (np.abs(eta) < 2.4) & (pt > 0.5)

    matrix_element_gen_particles = _select(filter_dq)
    parton_level_particles = _select(filter_partons)
    final_gen_particles = _select(filter_final)
    if len(final_gen_particles) == 0:
        print("No gen particles in this event?")
        print(status, len(status))
    return matrix_element_gen_particles, parton_level_particles, final_gen_particles


def create_jets_outputs_Delphes2(output):
    """Build an ``Event`` from a Delphes record (v2 data loading config).

    Reads ``n_PFCands``/``PFCands`` and ``NParticles``/``GenParticles`` from
    ``output``. PF candidates are kept only if ``pt > 0.5`` and
    ``|eta| < 2.4``.

    Returns:
        Event: with ``pfcands``, ``matrix_element_gen_particles``,
        ``final_gen_particles`` and ``final_parton_level_particles`` set.
    """
    n_pf = int(output["n_PFCands"][0, 0])
    n_genp = int(output["NParticles"][0, 0])
    # Slicing already clamps to the stored width, so no re-slice is needed.
    genp = output["GenParticles"][:, :n_genp].T
    pfcands = output["PFCands"][:, :n_pf].T

    pfcands = EventPFCands(
        pt=pfcands[:, 2],
        eta=pfcands[:, 0],
        phi=pfcands[:, 1],
        mass=pfcands[:, 3],
        charge=pfcands[:, 4],
        pid=pfcands[:, 5],
        pf_cand_jet_idx=[-1] * len(pfcands),
    )
    filter_pfcands = (pfcands.pt > 0.5) & (torch.abs(pfcands.eta) < 2.4)
    pfcands.mask(filter_pfcands)

    (
        matrix_element_gen_particles,
        parton_level_particles,
        final_gen_particles,
    ) = _build_gen_particle_collections(genp)
    return Event(
        pfcands=pfcands,
        matrix_element_gen_particles=matrix_element_gen_particles,
        final_gen_particles=final_gen_particles,
        final_parton_level_particles=parton_level_particles,
    )


def create_jets_outputs_Delphes(output):
    """Build an ``Event`` from a Delphes record with separate CH/NH/photons.

    Charged hadrons (``CH``), neutral hadrons (``NH``) and photons
    (``EFlowPhoton``) are turned into ``EventPFCands`` with fixed pid values
    (211, 2112 and 22), concatenated, and filtered to ``pt > 0.5`` and
    ``|eta| < 2.4``. Gen particles are split by status code.

    Returns:
        Event: with ``pfcands``, ``matrix_element_gen_particles``,
        ``final_gen_particles`` and ``final_parton_level_particles`` set.
    """
    n_ch = int(output["n_CH"][0, 0])
    n_nh = int(output["n_NH"][0, 0])
    n_photons = int(output["n_photon"][0, 0])
    n_genp = int(output["NParticles"][0, 0])
    ch = output["CH"][:, :n_ch]
    nh = output["NH"][:, :n_nh]
    photons = output["EFlowPhoton"][:, :n_photons]
    genp = output["GenParticles"][:, :n_genp]

    # The stored arrays may hold fewer columns than the advertised counts.
    n_nh = min(n_nh, nh.shape[1])
    n_ch = min(n_ch, ch.shape[1])
    n_photons = min(n_photons, photons.shape[1])

    nh_mass = [0.135] * n_nh  # pion mass hypothesis
    # NOTE(review): the original also computed sqrt(ET^2 - m^2) (falling back
    # to ET on NaN) but never used it; ET itself is passed as pt below. The
    # dead computation was dropped - confirm ET-as-pt is intended.
    nh_ET = nh[2, :]
    nh_charge = [0] * n_nh
    nh_pid = [2112] * n_nh
    nh_jets = [-1] * n_nh

    ch_charge = ch[4, :]
    ch_pid = [211] * n_ch
    ch_jets = [-1] * n_ch

    photons_jets = [-1] * n_photons
    photons_mass = [0] * n_photons
    photons_charge = [0] * n_photons
    photons_pid = [22] * n_photons

    nh = nh.T
    ch = ch.T
    photons = photons.T
    genp = genp.T

    nh_data = EventPFCands(
        nh_ET, nh[:, 0], nh[:, 1], nh_mass, nh_charge, nh_pid,
        pf_cand_jet_idx=nh_jets,
    )
    ch_data = EventPFCands(
        ch[:, 2], ch[:, 0], ch[:, 1], ch[:, 3], ch_charge, ch_pid,
        pf_cand_jet_idx=ch_jets,
    )
    photon_data = EventPFCands(
        photons[:, 2], photons[:, 0], photons[:, 1], photons_mass,
        photons_charge, photons_pid,
        pf_cand_jet_idx=photons_jets,
    )
    pfcands = concat_event_collection([nh_data, ch_data, photon_data], nobatch=1)
    filter_pfcands = (pfcands.pt > 0.5) & (torch.abs(pfcands.eta) < 2.4)
    pfcands.mask(filter_pfcands)

    (
        matrix_element_gen_particles,
        parton_level_particles,
        final_gen_particles,
    ) = _build_gen_particle_collections(genp)
    return Event(
        pfcands=pfcands,
        matrix_element_gen_particles=matrix_element_gen_particles,
        final_gen_particles=final_gen_particles,
        final_parton_level_particles=parton_level_particles,
    )


def create_jets_outputs(
    output,
    config=None,
):
    """Return the raw jets, gen-jets and fat-jets arrays of an event.

    Each array is truncated along axis 1 to the count stored in
    ``n_jets``, ``n_genjets`` and ``n_fat_jets`` respectively.
    ``config`` is accepted for interface compatibility and is not used.
    """
    n_jets = int(output["n_jets"][0, 0])
    n_genjets = int(output["n_genjets"][0, 0])
    n_fat_jets = int(output["n_fat_jets"][0, 0])
    jets_data = output["jets"][:, :n_jets]
    genjets_data = output["genjets"][:, :n_genjets]
    fat_jets_data = output["fat_jets"][:, :n_fat_jets]
    return jets_data, genjets_data, fat_jets_data


def _count_unpadded_rows(pt_column):
    """Return how many leading rows precede the padding in ``pt_column``.

    The original code used ``argmin(pt)`` as the row count, i.e. the first
    occurrence of the smallest pt. That is kept whenever the minimum is not
    positive. When every pt is positive there is nothing that looks like
    padding, so all rows are kept instead of truncating at the softest
    particle; an empty column yields 0 instead of raising.
    NOTE(review): assumes padding rows have pt <= 0 - confirm with the writer.
    """
    pt = torch.as_tensor(pt_column)
    if pt.numel() == 0:
        return 0
    first_min = torch.argmin(pt).item()
    if pt[first_min] > 0:
        return pt.shape[0]
    return first_min


def _gen_table_to_pfcands(table, **extra):
    """Wrap a (rows x columns) gen-particle table as ``EventPFCands``.

    Columns 0..3 are pt, eta, phi, mass; column 4 is used both as pid and,
    through its sign, as charge. ``extra`` is forwarded to ``EventPFCands``.
    """
    return EventPFCands(
        pt=table[:, 0],
        eta=table[:, 1],
        phi=table[:, 2],
        mass=table[:, 3],
        charge=np.sign(table[:, 4]),
        pid=table[:, 4],
        pf_cand_jet_idx=-1 * np.ones_like(table[:, 0]),
        **extra,
    )


def create_jets_outputs_new(
    output,
    separate_special_pfcands=False
):
    """Build an ``Event`` with jets, PF candidates, MET and gen particles.

    Args:
        output: event record; must provide the jets / genjets / fat_jets /
            pfcands / electrons / muons / photons arrays with their ``n_*``
            counters, ``pfcands_jet_mapping``, ``MET`` and
            ``matrix_element_gen_particles``. If ``final_gen_particles`` is
            present, ``final_parton_level_particles`` is read as well.
        separate_special_pfcands (bool): if False (default) electrons, muons
            and photons are appended to the PF candidates and
            ``special_pfcands`` is ``None``; if True they are returned
            separately.

    Returns:
        Event
    """
    n_jets = int(output["n_jets"][0, 0])
    n_genjets = int(output["n_genjets"][0, 0])
    n_pfcands = int(output["n_pfcands"][0, 0])
    n_fat_jets = int(output["n_fat_jets"][0, 0])
    n_electrons = int(output["n_electrons"][0, 0])
    n_muons = int(output["n_muons"][0, 0])
    n_photons = int(output["n_photons"][0, 0])

    jets_data = output["jets"][:, :n_jets].T
    genjets_data = output["genjets"][:, :n_genjets].T
    pfcands_data = output["pfcands"][:, :n_pfcands].T
    fat_jets_data = output["fat_jets"][:, :n_fat_jets].T
    electrons_data = output["electrons"][:, :n_electrons].T
    muons_data = output["muons"][:, :n_muons].T
    photons_data = output["photons"][:, :n_photons].T
    output_MET = output["MET"]

    # Row 1 of the mapping is used to find how many mapping columns are
    # filled: everything up to (and including) its arg-max is kept.
    pfcands_jets_mapping = output["pfcands_jet_mapping"]
    if n_jets == 0:
        num_mapping = 0
    else:
        num_mapping = np.argmax(pfcands_jets_mapping[1]) + 1
    pfcands_jets_mapping = pfcands_jets_mapping[:, :num_mapping]
    if len(pfcands_jets_mapping[1]):
        assert pfcands_jets_mapping[1].max() < n_pfcands

    matrix_element_gen_particles_data = _gen_table_to_pfcands(
        output["matrix_element_gen_particles"].T
    )

    has_final_gen = "final_gen_particles" in output  # new config
    if has_final_gen:
        final_gen_table = output["final_gen_particles"].T
        final_parton_table = output["final_parton_level_particles"].T
        n_fp = _count_unpadded_rows(final_gen_table[:, 0])
        n_pp = _count_unpadded_rows(final_parton_table[:, 0])
        final_gen_particles_data = _gen_table_to_pfcands(final_gen_table[:n_fp])
        final_parton_level_particles_data = _gen_table_to_pfcands(
            final_parton_table[:n_pp], status=final_parton_table[:n_pp, 5]
        )

    # Electrons, muons and photons get a fixed mass and a 0/1/2 pid code.
    electrons_mass = np.ones_like(electrons_data[:, 0]) * 0.511
    muons_mass = np.ones_like(muons_data[:, 0]) * 105.7
    photons_mass = np.zeros_like(photons_data[:, 0])
    electrons_pid = np.ones_like(electrons_data[:, 0]) * 0
    muons_pid = np.ones_like(muons_data[:, 0]) * 1
    photons_pid = np.ones_like(photons_data[:, 0]) * 2
    photons_charge = np.zeros_like(photons_data[:, 0])
    electrons_data = np.column_stack((
        electrons_data[:, 0], electrons_data[:, 1], electrons_data[:, 2],
        electrons_mass, electrons_data[:, 3], electrons_pid,
    ))
    muons_data = np.column_stack((
        muons_data[:, 0], muons_data[:, 1], muons_data[:, 2],
        muons_mass, muons_data[:, 3], muons_pid,
    ))
    photons_data = np.column_stack((
        photons_data[:, 0], photons_data[:, 1], photons_data[:, 2],
        photons_mass, photons_charge, photons_pid,
    ))
    special_pfcands_data = torch.tensor(
        np.concatenate((electrons_data, muons_data, photons_data), axis=0)
    )

    jets_data = EventJets(*[jets_data[:, i] for i in range(4)])
    genjets_data = EventJets(*[genjets_data[:, i] for i in range(4)])
    fatjets_data = EventJets(*[fat_jets_data[:, i] for i in range(4)])

    # The six candidate columns are followed positionally by the mapping rows.
    pfcands_columns = [pfcands_data[:, i] for i in range(6)]
    pfcands_data = EventPFCands(*(pfcands_columns + list(pfcands_jets_mapping)))
    special_pfcands_data = EventPFCands(
        *[special_pfcands_data[:, i] for i in range(6)],
        pf_cand_jet_idx=-1 * torch.ones_like(special_pfcands_data[:, 0]),
    )
    if not separate_special_pfcands:
        pfcands_data = concat_event_collection([pfcands_data, special_pfcands_data])
        special_pfcands_data = None

    MET_data = EventMetadataAndMET(
        pt=output_MET[0],
        phi=output_MET[1],
        scouting_trig=output_MET[2],
        offline_trig=output_MET[3],
        veto_trig=output_MET[4],
    )

    kwargs = {}
    if has_final_gen:
        kwargs["final_gen_particles"] = final_gen_particles_data
        kwargs["final_parton_level_particles"] = final_parton_level_particles_data
    return Event(
        jets=jets_data,
        genjets=genjets_data,
        pfcands=pfcands_data,
        MET=MET_data,
        fatjets=fatjets_data,
        matrix_element_gen_particles=matrix_element_gen_particles_data,
        special_pfcands=special_pfcands_data,
        **kwargs,
    )


def create_graph(
    output,
    config=None,
    n_noise=0,
):
    """Build the DGL graph and particle ground truth for one event.

    Args:
        output: event record passed on to ``create_inputs_from_table``.
        config: object exposing ``graph_config.get(key, default)``; the keys
            ``only_hits``, ``prediction``, ``hit_chis_track``, ``pos_pxpy``,
            ``ks`` and ``noise`` are read. Despite the ``None`` default it is
            required.
        n_noise: accepted for interface compatibility; not used here.

    Returns:
        tuple: ``([g, y_data_graph], graph_empty)``. When the table yields no
        inputs the pair is ``([0, 0], True)``. When fewer than 10 hits remain
        the graph is returned as built, with ``graph_empty`` True and without
        the track post-processing.
    """
    graph_cfg = config.graph_config
    # Whether to only include hits in the graph
    hits_only = graph_cfg.get("only_hits", False)
    prediction = graph_cfg.get("prediction", False)
    hit_chis = graph_cfg.get("hit_chis_track", False)
    pos_pxpy = graph_cfg.get("pos_pxpy", False)
    is_Ks = graph_cfg.get("ks", False)
    noise_class = graph_cfg.get("noise", False)

    result = create_inputs_from_table(
        output,
        hits_only=hits_only,
        prediction=prediction,
        hit_chis=hit_chis,
        pos_pxpy=pos_pxpy,
        is_Ks=is_Ks,
    )
    if len(result) == 1:
        return [0, 0], True

    # `daughters` is only present (20 entries) when prediction and is_Ks are
    # both set; a fixed 20-name unpack fails on the 19-entry layout.
    (
        y_data_graph,
        p_hits,
        e_hits,
        cluster_id,
        hit_particle_link,
        pos_xyz_hits,
        pos_pxpypz,
        pandora_cluster,
        pandora_cluster_energy,
        pandora_mom,
        pandora_ref_point,
        pandora_pid,
        pandora_pfo_energy,
        pandora_pfo_link,
        hit_type,
        hit_link_modified,
    ) = result[:16]
    daughters = result[16] if len(result) == 20 else None
    chi_squared_tracks, hit_type_one_hot, connections_list = result[-3:]

    if noise_class:
        mask_loopers, mask_particles = create_noise_label(
            e_hits, hit_particle_link, y_data_graph, cluster_id
        )
        hit_particle_link[mask_loopers] = -1
        y_data_graph.mask(mask_particles)
        cluster_id, _ = find_cluster_id(hit_particle_link)

    graph_coordinates = pos_xyz_hits
    g = dgl.graph(([], []))
    g.add_nodes(graph_coordinates.shape[0])
    # The node feature layout is the same with and without tracks.
    g.ndata["h"] = torch.cat(
        (graph_coordinates, hit_type_one_hot, e_hits, p_hits), dim=1
    )
    g.ndata["pos_hits_xyz"] = pos_xyz_hits
    g.ndata["pos_pxpypz"] = pos_pxpypz
    g = calculate_distance_to_boundary(g)
    g.ndata["hit_type"] = hit_type
    # if no tracks this is e and if there are tracks this fills the tracks
    # e values with p
    g.ndata["e_hits"] = e_hits
    if hit_chis:
        g.ndata["chi_squared_tracks"] = chi_squared_tracks
    g.ndata["particle_number"] = cluster_id
    g.ndata["hit_link_modified"] = hit_link_modified
    g.ndata["particle_number_nomap"] = hit_particle_link
    if prediction:
        g.ndata["pandora_cluster"] = pandora_cluster
        g.ndata["pandora_pfo"] = pandora_pfo_link
        g.ndata["pandora_cluster_energy"] = pandora_cluster_energy
        g.ndata["pandora_pfo_energy"] = pandora_pfo_energy
        if is_Ks:
            g.ndata["pandora_momentum"] = pandora_mom
            g.ndata["pandora_reference_point"] = pandora_ref_point
            g.ndata["daughters"] = daughters
            g.ndata["pandora_pid"] = pandora_pid
    # NOTE(review): nesting reconstructed from a whitespace-collapsed source;
    # this call is assumed to run for every non-empty table - confirm.
    y_data_graph.calculate_corrected_E(g, connections_list)

    graph_empty = pos_xyz_hits.shape[0] < 10  # less than 10 hits
    if graph_empty:
        return [g, y_data_graph], graph_empty

    g = store_track_at_vertex_at_track_at_calo(g)
    if noise_class:
        g = make_bad_tracks_noise_tracks(g)
    return [g, y_data_graph], graph_empty


def graph_batch_func(list_graphs):
    """collator function for graph dataloader

    Args:
        list_graphs (list): list of graphs from the iterable dataset; the
            first element of each entry is the graph.

    Returns:
        batch dgl: dgl batch of graphs, and the concatenated particle
        ground truth built from ``list_graphs``.
    """
    list_graphs_g = [el[0] for el in list_graphs]
    ys = concatenate_Particles_GT(list_graphs)
    bg = dgl.batch(list_graphs_g)
    return bg, ys


def _scatter_mean_rows(values, index):
    """Mean of the rows of ``values`` grouped by ``index`` (torch only).

    Stand-in for ``torch_scatter.scatter_mean(values, index, dim=0)``, whose
    import is commented out in this file: the output has ``index.max() + 1``
    rows (0 rows for an empty index) and groups without members are zero.
    """
    index = index.long().view(-1)
    n_groups = int(index.max()) + 1 if index.numel() else 0
    sums = torch.zeros(
        (n_groups,) + tuple(values.shape[1:]),
        dtype=values.dtype,
        device=values.device,
    )
    sums.index_add_(0, index, values)
    counts = torch.zeros(n_groups, dtype=values.dtype, device=values.device)
    counts.index_add_(
        0,
        index,
        torch.ones(index.shape[0], dtype=values.dtype, device=values.device),
    )
    # Clamp so that empty groups divide by one and stay at zero.
    counts = counts.clamp(min=1)
    return sums / counts.view(-1, *([1] * (values.dim() - 1)))


def make_bad_tracks_noise_tracks(g):
    """Relabel tracks that lie far from their cluster as particle 0.

    For every particle the mean position of its hit-type-2 hits is computed.
    A hit-type-1 node whose distance to that mean, divided by 1000, exceeds
    0.21 gets ``particle_number`` 0. The relabelling is skipped when the
    number of cluster means differs from the number of distinct particle
    numbers (e.g. a particle without type-2 hits at the top of the range).

    Returns:
        The same graph ``g`` with ``particle_number`` possibly updated.
    """
    hit_type = g.ndata["hit_type"]
    particle_number = g.ndata["particle_number"]
    positions = g.ndata["pos_hits_xyz"]
    cluster_hits = hit_type == 2
    track_hits = hit_type == 1

    # the other error could come from no hits in the ECAL for a cluster
    mean_pos_cluster = _scatter_mean_rows(
        positions[cluster_hits], particle_number[cluster_hits]
    )
    pos_track = positions[track_hits]
    particle_track = particle_number[track_hits]
    if torch.sum(particle_number == 0) == 0:
        # No particle 0 in the event: drop its (empty) row so that particle 1
        # sits at index 0, and shift the track labels to match.
        mean_pos_cluster = mean_pos_cluster[1:, :]
        particle_track = particle_track - 1

    # Compare counts, not an int against a torch.Size (which is never equal).
    if mean_pos_cluster.shape[0] == torch.unique(particle_number).shape[0]:
        distance_track_cluster = (
            torch.norm(mean_pos_cluster[particle_track.long()] - pos_track, dim=1)
            / 1000
        )
        bad_tracks = distance_track_cluster > 0.21
        index_bad_tracks = track_hits.nonzero().view(-1)[bad_tracks]
        particle_number[index_bad_tracks] = 0
        g.ndata["particle_number"] = particle_number
    return g