from src.models.transformer.tr_blocks import Transformer
import torch
import torch.nn as nn
from xformers.ops.fmha import BlockDiagonalMask
from torch_scatter import scatter_mean


class TransformerModel(torch.nn.Module):
    def __init__(
        self,
        n_scalars,
        n_scalars_out,
        n_blocks,
        n_heads,
        internal_dim,
        obj_score,
        global_features_copy=False,
    ):
        super().__init__()
        self.n_scalars = n_scalars
        self.input_dim = n_scalars + 3
        if obj_score:
            self.input_dim += 1
        self.output_dim = 3
        self.obj_score = obj_score
        if n_scalars_out > 0:
            self.output_dim += 1  # extra output channel for the betas regression
        if self.obj_score:
            self.output_dim = 10
        self.global_features_copy = global_features_copy
        self.transformer = Transformer(
            in_channels=self.input_dim,
            out_channels=self.output_dim,
            hidden_channels=internal_dim,
            num_heads=n_heads,
            num_blocks=n_blocks,
        )
        if self.global_features_copy:
            # Second transformer that runs over the unmodified event-level batch.
            self.transformer_global_features = Transformer(
                in_channels=self.input_dim,
                out_channels=self.output_dim,
                hidden_channels=internal_dim,
                num_heads=n_heads,
                num_blocks=n_blocks,
            )
        self.batch_norm = nn.BatchNorm1d(self.input_dim, momentum=0.1)  # currently not applied in forward()
        if self.obj_score:
            factor = 2 if self.global_features_copy else 1
            self.final_mlp = nn.Sequential(
                nn.Linear(self.output_dim * factor, 10),
                nn.LeakyReLU(),
                nn.Linear(10, 1),
            )

    def forward(self, data, data_events=None, data_events_clusters=None):
        # data: instance of EventBatch.
        # data_events & data_events_clusters: only used when --global-features-obj-score is on.
        # data_events is the "unmodified" event-level batch; data_events_clusters assigns each of
        # its entries to a cluster (entries shifted into bucket 0 below are discarded).
        if self.global_features_copy:
            assert data_events is not None and data_events_clusters is not None
            assert self.obj_score
            inputs_v = data_events.input_vectors.float()
            inputs_scalar = data_events.input_scalars.float()
            assert inputs_scalar.shape[1] == self.n_scalars, "Expected %d, got %d" % (
                self.n_scalars,
                inputs_scalar.shape[1],
            )
            inputs_transformer_events = torch.cat([inputs_scalar, inputs_v], dim=1).float()
            assert inputs_transformer_events.shape[1] == self.input_dim
            mask_global = self.build_attention_mask(data_events.batch_idx)
            x_global = inputs_transformer_events.unsqueeze(0)
            x_global = self.transformer_global_features(x_global, attention_mask=mask_global)[0]
            assert x_global.shape[1] == self.output_dim, "Expected %d, got %d" % (
                self.output_dim,
                x_global.shape[1],
            )
            assert x_global.shape[0] == inputs_transformer_events.shape[0], "Expected %d, got %d" % (
                inputs_transformer_events.shape[0],
                x_global.shape[0],
            )
            # Per-cluster mean of the event-level features; the +1 shift plus [1:] drops bucket 0.
            m_global = scatter_mean(
                x_global,
                torch.as_tensor(data_events_clusters).to(x_global.device) + 1,
                dim=0,
            )[1:]
        inputs_v = data.input_vectors
        inputs_scalar = data.input_scalars
        assert inputs_scalar.shape[1] == self.n_scalars, "Expected %d, got %d" % (
            self.n_scalars,
            inputs_scalar.shape[1],
        )
        inputs_transformer = torch.cat([inputs_scalar, inputs_v], dim=1).float()
        assert inputs_transformer.shape[1] == self.input_dim
        mask = self.build_attention_mask(data.batch_idx)
        x = inputs_transformer.unsqueeze(0)
        x = self.transformer(x, attention_mask=mask)[0]
        assert x.shape[1] == self.output_dim, "Expected %d, got %d" % (self.output_dim, x.shape[1])
        assert x.shape[0] == inputs_transformer.shape[0], "Expected %d, got %d" % (
            inputs_transformer.shape[0],
            x.shape[0],
        )
        if not self.obj_score:
            x[:, -1] = torch.sigmoid(x[:, -1])
        else:
            extract_from_virtual_nodes = False
            if extract_from_virtual_nodes:
                x = self.final_mlp(x[data.fake_nodes_idx])  # x is the raw logits
            else:
                m = scatter_mean(x, torch.as_tensor(data.batch_idx).long().to(x.device), dim=0)
                assert "fake_nodes_idx" not in data.__dict__
                if self.global_features_copy:
                    m = torch.cat([m, m_global], dim=1)
                x = self.final_mlp(m).flatten()
        return x

    def build_attention_mask(self, batch_numbers):
        # Block-diagonal mask so that attention stays within each event.
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )


def get_model(args, obj_score=False):
    n_scalars_out = 8
    if args.beta_type == "pt":
        n_scalars_out = 0
    elif args.beta_type == "pt+bc":
        n_scalars_out = 1
    n_scalars_in = 12
    if args.no_pid:
        n_scalars_in = 12 - 9  # only 3 scalar features when PID information is excluded
    if obj_score:
        return TransformerModel(
            n_scalars=n_scalars_in,
            n_scalars_out=10,
            n_blocks=5,
            n_heads=args.n_heads,
            internal_dim=64,
            obj_score=obj_score,
            global_features_copy=args.global_features_obj_score,
        )
    return TransformerModel(
        n_scalars=n_scalars_in,
        n_scalars_out=n_scalars_out,
        n_blocks=args.num_blocks,
        n_heads=args.n_heads,
        internal_dim=args.internal_dim,
        obj_score=obj_score,
    )
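

# --- Hypothetical usage sketch (assumptions, not part of the original API) ---
# A minimal smoke test, assuming the repo's Transformer block, xformers and torch_scatter
# are importable, and that an EventBatch simply exposes the attributes read in forward()
# above (input_scalars, input_vectors, batch_idx). The args fields mirror those accessed
# in get_model(); everything else here (shapes, values) is illustrative only.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        beta_type="pt+bc",             # one extra regressed scalar, see get_model()
        no_pid=False,                  # keep all 12 input scalars
        n_heads=4,
        num_blocks=4,
        internal_dim=64,
        global_features_obj_score=False,
    )
    model = get_model(args, obj_score=False)

    # Two events with 5 and 3 hits; batch_idx is what build_attention_mask uses
    # to carve the concatenated hits into per-event attention blocks.
    batch = SimpleNamespace(
        input_scalars=torch.randn(8, 12),
        input_vectors=torch.randn(8, 3),
        batch_idx=torch.tensor([0, 0, 0, 0, 0, 1, 1, 1]),
    )
    out = model(batch)
    print(out.shape)                   # expected: (8, 4), last column passed through sigmoid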