from gatr import GATr, SelfAttentionConfig, MLPConfig
from gatr.interface import (
    embed_point,
    extract_scalar,
    extract_point,
    embed_scalar,
    embed_translation,
    extract_translation,
)
import torch
import torch.nn as nn
from xformers.ops.fmha import BlockDiagonalMask


class GATrModel(torch.nn.Module):
    def __init__(
        self,
        n_scalars,
        hidden_mv_channels,
        hidden_s_channels,
        blocks,
        embed_as_vectors,
        n_scalars_out,
    ):
        super().__init__()
        self.n_scalars = n_scalars
        self.hidden_mv_channels = hidden_mv_channels
        self.hidden_s_channels = hidden_s_channels
        self.blocks = blocks
        self.embed_as_vectors = embed_as_vectors
        self.input_dim = 3
        self.n_scalars_out = n_scalars_out
        self.gatr = GATr(
            in_mv_channels=1,
            out_mv_channels=1,
            hidden_mv_channels=hidden_mv_channels,
            in_s_channels=n_scalars,
            out_s_channels=n_scalars_out,
            hidden_s_channels=hidden_s_channels,
            num_blocks=blocks,
            attention=SelfAttentionConfig(),  # use default parameters for attention
            mlp=MLPConfig(),  # use default parameters for MLP
        )
        self.batch_norm = nn.BatchNorm1d(self.input_dim, momentum=0.1)
        # self.clustering = nn.Linear(3, self.output_dim - 1, bias=False)
        if n_scalars_out > 0:
            # Maps the extracted multivector scalar plus the output scalar channels
            # to a single per-point beta value.
            self.beta = nn.Linear(n_scalars_out + 1, 1)
        else:
            self.beta = None

    def forward(self, data):
        # data: instance of EventBatch
        inputs_v = data.input_vectors.float()
        inputs_scalar = data.input_scalars.float()
        assert inputs_scalar.shape[1] == self.n_scalars

        # Embed the 3D inputs into the 16-dimensional multivector representation,
        # either as translations or as points.
        if self.embed_as_vectors:
            embedded_inputs = embed_translation(inputs_v)
            if torch.isnan(embedded_inputs).any():
                raise ValueError("NaNs in the input!")
        else:
            embedded_inputs = embed_point(inputs_v)
        embedded_inputs = embedded_inputs.unsqueeze(-2)  # (batch_size*num_points, 1, 16)

        # Block-diagonal mask restricts attention to points within the same event.
        mask = self.build_attention_mask(data.batch_idx)
        embedded_outputs, output_scalars = self.gatr(
            embedded_inputs, scalars=inputs_scalar, attention_mask=mask
        )

        # Extract the clustering coordinates with the same interface used for embedding.
        if self.embed_as_vectors:
            x_clusters = extract_translation(embedded_outputs)
        else:
            x_clusters = extract_point(embedded_outputs)
        original_scalar = extract_scalar(embedded_outputs)

        if self.beta is not None:
            beta = self.beta(
                torch.cat([original_scalar[:, 0, :], output_scalars], dim=1)
            )
            x = torch.cat((x_clusters[:, 0, :], torch.sigmoid(beta.view(-1, 1))), dim=1)
        else:
            x = x_clusters[:, 0, :]
        if torch.isnan(x).any():
            raise ValueError("NaNs in the output!")
        return x

    def build_attention_mask(self, batch_numbers):
        # One attention block per event, sized by the number of points in that event.
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )


def get_model(args, obj_score=False):
    n_scalars_out = 8
    if args.beta_type == "pt":
        n_scalars_out = 0
    elif args.beta_type == "pt+bc":
        n_scalars_out = 8
    n_scalars_in = 12
    if args.no_pid:
        n_scalars_in = 12 - 9  # drop the 9 PID scalars
    return GATrModel(
        n_scalars=n_scalars_in,
        hidden_mv_channels=args.hidden_mv_channels,
        hidden_s_channels=args.hidden_s_channels,
        blocks=args.num_blocks,
        embed_as_vectors=args.embed_as_vectors,
        n_scalars_out=n_scalars_out,
    )
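

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model): shows how the
# per-point batch indices consumed by build_attention_mask become the block
# sizes of the block-diagonal attention mask, and how get_model is wired up.
# The batch_idx values and the args fields below are hypothetical; in
# training, batch indices come from data.batch_idx of an EventBatch and the
# args come from the command-line parser.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Per-point batch indices for 3 events with 3, 2 and 4 points.
    example_batch_idx = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
    seqlens = torch.bincount(example_batch_idx.long()).tolist()
    print(seqlens)  # [3, 2, 4]
    # Attention is then restricted to points that belong to the same event.
    mask = BlockDiagonalMask.from_seqlens(seqlens)
    print(type(mask).__name__)  # BlockDiagonalMask

    # Hypothetical configuration mirroring the fields get_model reads from args.
    args = SimpleNamespace(
        beta_type="pt+bc",
        no_pid=False,
        hidden_mv_channels=16,
        hidden_s_channels=32,
        num_blocks=10,
        embed_as_vectors=False,
    )
    model = get_model(args)
    print(model.n_scalars, model.n_scalars_out)  # 12 8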