from lgatr import GATr, SelfAttentionConfig, MLPConfig
from lgatr.interface import embed_vector, extract_scalar, embed_spurions, extract_vector
import torch
import torch.nn as nn
from xformers.ops.fmha import BlockDiagonalMask
from torch_scatter import scatter_mean


class LGATrModel(torch.nn.Module):
    def __init__(
        self,
        n_scalars,
        hidden_mv_channels,
        hidden_s_channels,
        blocks,
        embed_as_vectors,
        n_scalars_out,
        return_scalar_coords,
        obj_score=False,
        global_features_copy=False,
    ):
        super().__init__()
        self.return_scalar_coords = return_scalar_coords
        self.n_scalars = n_scalars
        self.hidden_mv_channels = hidden_mv_channels
        self.hidden_s_channels = hidden_s_channels
        self.blocks = blocks
        self.embed_as_vectors = embed_as_vectors
        self.n_scalars_out = n_scalars_out
        self.obj_score = obj_score
        self.global_features_copy = global_features_copy
        self.gatr = GATr(
            in_mv_channels=3,
            out_mv_channels=1,
            hidden_mv_channels=hidden_mv_channels,
            in_s_channels=n_scalars,
            out_s_channels=n_scalars_out,
            hidden_s_channels=hidden_s_channels,
            num_blocks=blocks,
            attention=SelfAttentionConfig(),  # use default parameters for attention
            mlp=MLPConfig(),  # use default parameters for MLP
        )
        if self.global_features_copy:
            # Second, identically configured GATr that runs on the full event
            # to provide event-level context for the objectness score.
            self.gatr_global_features = GATr(
                in_mv_channels=3,
                out_mv_channels=1,
                hidden_mv_channels=hidden_mv_channels,
                in_s_channels=n_scalars,
                out_s_channels=n_scalars_out,
                hidden_s_channels=hidden_s_channels,
                num_blocks=blocks,
                attention=SelfAttentionConfig(),  # use default parameters for attention
                mlp=MLPConfig(),  # use default parameters for MLP
            )
        if n_scalars_out > 0:
            if obj_score:
                # With the global-features copy, per-object and event-level
                # embeddings are concatenated, doubling the head's input width.
                factor = 2 if self.global_features_copy else 1
                self.beta = nn.Sequential(
                    nn.Linear((n_scalars_out + 1) * factor, 10),
                    nn.LeakyReLU(),
                    nn.Linear(10, 1),
                    # no Sigmoid here: the objectness head returns raw logits
                )
            else:
                self.beta = nn.Linear(n_scalars_out + 1, 1)
        else:
            self.beta = None

    def forward(self, data, data_events=None, data_events_clusters=None, cpu_demo=False):
        # data: instance of EventBatch. data_events / data_events_clusters are
        # only needed when global_features_copy is enabled; cpu_demo=True skips
        # the xformers attention mask (full attention over the batch).
        if self.global_features_copy:
            assert data_events is not None and data_events_clusters is not None
            assert self.obj_score
            inputs_v = data_events.input_vectors
            inputs_scalar = data_events.input_scalars
            assert inputs_scalar.shape[1] == self.n_scalars, (
                "Expected %d scalar features, got %d" % (self.n_scalars, inputs_scalar.shape[1])
            )
            mask_global = self.build_attention_mask(data_events.batch_idx)
            embedded_inputs_events = embed_vector(inputs_v.unsqueeze(0))
            multivectors = embedded_inputs_events.unsqueeze(-2)
            spurions = embed_spurions(
                beam_reference="xyplane",
                add_time_reference=True,
                device=multivectors.device,
                dtype=multivectors.dtype,
            )
            num_points, x = inputs_v.shape
            assert x == 4
            spurions = spurions[None, None, ...].repeat(1, num_points, 1, 1)  # (1, num_points, 2, 16)
            multivectors = torch.cat((multivectors, spurions), dim=-2)
            embedded_outputs, output_scalars = self.gatr_global_features(
                multivectors, scalars=inputs_scalar, attention_mask=mask_global
            )
            original_scalar = extract_scalar(embedded_outputs)
            scalar_embeddings_nodes = torch.cat(
                [original_scalar[0, :, 0, :], output_scalars[0, :, :]], dim=1
            )
            # Shift cluster indices by 1 so that index -1 (unclustered points)
            # lands in bin 0, then drop that bin after averaging.
            cluster_idx = torch.as_tensor(
                data_events_clusters, device=scalar_embeddings_nodes.device
            ).long() + 1
            scalar_embeddings_global = scatter_mean(scalar_embeddings_nodes, cluster_idx, dim=0)[1:]
        inputs_v = data.input_vectors.float()  # four-momenta
        inputs_scalar = data.input_scalars.float()
        assert inputs_scalar.shape[1] == self.n_scalars
        num_points, x = inputs_v.shape
        assert x == 4
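        # Note on the embedding (layout as used by the lgatr interface): each
        # four-momentum is embedded into the vector grade of a 16-component
        # geometric-algebra multivector, and two "spurion" reference
        # multivectors (beam xy-plane and time direction) are appended as
        # extra channels, which is why the network is built with
        # in_mv_channels=3. The spurions give the otherwise Lorentz-equivariant
        # network fixed reference directions for symmetry breaking.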
        inputs_v = inputs_v.unsqueeze(0)
        embedded_inputs = embed_vector(inputs_v)
        if torch.isnan(embedded_inputs).any():
            raise ValueError("NaNs in the input!")
        multivectors = embedded_inputs.unsqueeze(-2)  # (1, num_points, 1, 16)
        spurions = embed_spurions(
            beam_reference="xyplane",
            add_time_reference=True,
            device=multivectors.device,
            dtype=multivectors.dtype,
        )
        spurions = spurions[None, None, ...].repeat(1, num_points, 1, 1)  # (1, num_points, 2, 16)
        # Append the spurions as two extra multivector channels.
        multivectors = torch.cat((multivectors, spurions), dim=-2)  # (1, num_points, 3, 16)
        mask = None if cpu_demo else self.build_attention_mask(data.batch_idx)
        embedded_outputs, output_scalars = self.gatr(
            multivectors, scalars=inputs_scalar, attention_mask=mask
        )
        x_clusters = extract_vector(embedded_outputs)
        original_scalar = extract_scalar(embedded_outputs)
        if self.beta is not None:
            if self.obj_score:
                # If True, read the objectness score from the virtual nodes
                # (data.fake_nodes_idx) instead of averaging over each event.
                extract_from_virtual_nodes = False
                scalar_embeddings = torch.cat(
                    [original_scalar[0, :, 0, :], output_scalars[0, :, :]], dim=1
                )
                if extract_from_virtual_nodes:
                    values = torch.cat(
                        [
                            original_scalar[0, data.fake_nodes_idx, 0, :],
                            output_scalars[0, data.fake_nodes_idx, :],
                        ],
                        dim=1,
                    )
                else:
                    values = scatter_mean(
                        scalar_embeddings,
                        data.batch_idx.to(scalar_embeddings.device).long(),
                        dim=0,
                    )
                if self.global_features_copy:
                    values = torch.cat([values, scalar_embeddings_global], dim=1)
                return self.beta(values)
            vals = torch.cat([original_scalar[0, :, 0, :], output_scalars[0, :, :]], dim=1)
            beta = self.beta(vals)
            if self.return_scalar_coords:
                x = output_scalars[0, :, :3]
                x = torch.cat((x, torch.sigmoid(beta.view(-1, 1))), dim=1)
            else:
                x = torch.cat((x_clusters[0, :, 0, :], torch.sigmoid(beta.view(-1, 1))), dim=1)
        else:
            x = x_clusters[0, :, 0, :]
        if torch.isnan(x).any():
            raise ValueError("NaNs in the output!")
        return x

    def build_attention_mask(self, batch_numbers):
        # One attention block per event: points may only attend to other
        # points from the same event in the packed batch.
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )
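
# Worked example of the mask construction above (an illustration, with shapes
# assumed from usage): for batch_numbers = [0, 0, 1, 1, 1], torch.bincount
# yields [2, 3], so BlockDiagonalMask.from_seqlens([2, 3]) lets the first two
# points attend only to each other and the last three only to each other;
# events packed into the same sequence never mix through attention.
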
def get_model(args, obj_score=False):
    # Number of output scalar channels depends on the beta parametrization.
    n_scalars_out = 8
    if args.beta_type == "pt":
        n_scalars_out = 0
    elif args.beta_type == "pt+bc":
        n_scalars_out = 8
    # 12 input scalars by default; dropping the 9 PID features leaves 3.
    n_scalars_in = 12
    if args.no_pid:
        n_scalars_in = 12 - 9
    if obj_score:
        # Fixed, smaller architecture for the objectness-score model.
        return LGATrModel(
            n_scalars=n_scalars_in,
            hidden_mv_channels=8,
            hidden_s_channels=16,
            blocks=5,
            embed_as_vectors=False,
            n_scalars_out=n_scalars_out,
            return_scalar_coords=args.scalars_oc,
            obj_score=obj_score,
            global_features_copy=args.global_features_obj_score,
        )
    return LGATrModel(
        n_scalars=n_scalars_in,
        hidden_mv_channels=args.hidden_mv_channels,
        hidden_s_channels=args.hidden_s_channels,
        blocks=args.num_blocks,
        embed_as_vectors=args.embed_as_vectors,
        n_scalars_out=n_scalars_out,
        return_scalar_coords=args.scalars_oc,
        obj_score=obj_score,
    )
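
if __name__ == "__main__":
    # Minimal CPU smoke test (a sketch, not part of the training pipeline).
    # It assumes lgatr, xformers, and torch_scatter are importable; the
    # SimpleNamespace below is a hypothetical stand-in for the EventBatch
    # object that forward() expects, carrying only the attributes used there.
    from types import SimpleNamespace

    model = LGATrModel(
        n_scalars=3,
        hidden_mv_channels=8,
        hidden_s_channels=16,
        blocks=2,
        embed_as_vectors=False,
        n_scalars_out=8,
        return_scalar_coords=False,
    )
    data = SimpleNamespace(
        input_vectors=torch.randn(6, 4),  # six particles, four-momenta
        input_scalars=torch.randn(6, 3),  # three scalar features per particle
        batch_idx=torch.tensor([0, 0, 0, 1, 1, 1]),  # two events of three particles
    )
    out = model(data, cpu_demo=True)  # cpu_demo=True skips the xformers mask
    # Expected per-point output: 4 extracted vector components + sigmoid(beta),
    # i.e. shape (6, 5), following the concatenation in forward().
    print(out.shape)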