from lgatr import GATr, SelfAttentionConfig, MLPConfig
from lgatr.interface import embed_vector, extract_scalar, embed_spurions, extract_vector
import torch
import torch.nn as nn
from xformers.ops.fmha import BlockDiagonalMask
from torch_scatter import scatter_mean

class LGATrModel(torch.nn.Module):
    def __init__(self, n_scalars, hidden_mv_channels, hidden_s_channels, blocks, embed_as_vectors,
                 n_scalars_out, return_scalar_coords, obj_score=False, global_features_copy=False):
super().__init__()
self.return_scalar_coords = return_scalar_coords
self.n_scalars = n_scalars
self.hidden_mv_channels = hidden_mv_channels
self.hidden_s_channels = hidden_s_channels
self.blocks = blocks
self.embed_as_vectors = embed_as_vectors
self.input_dim = 3
self.n_scalars_out = n_scalars_out
self.obj_score = obj_score
        self.global_features_copy = global_features_copy
self.gatr = GATr(
in_mv_channels=3,
out_mv_channels=1,
hidden_mv_channels=hidden_mv_channels,
in_s_channels=n_scalars,
out_s_channels=n_scalars_out,
hidden_s_channels=hidden_s_channels,
num_blocks=blocks,
attention=SelfAttentionConfig(), # Use default parameters for attention
mlp=MLPConfig(), # Use default parameters for MLP
)
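        # in_mv_channels=3 matches the inputs built in forward(): the embedded
        # four-momentum plus the two spurion channels.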
if self.global_features_copy:
self.gatr_global_features = GATr(
in_mv_channels=3,
out_mv_channels=1,
hidden_mv_channels=hidden_mv_channels,
in_s_channels=n_scalars,
out_s_channels=n_scalars_out,
hidden_s_channels=hidden_s_channels,
num_blocks=blocks,
attention=SelfAttentionConfig(), # Use default parameters for attention
mlp=MLPConfig(), # Use default parameters for MLP
)
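        # The beta head maps the GATr scalar outputs plus the extracted
        # multivector scalar channel (hence the "+ 1") to a single score: an
        # objectness logit when obj_score=True, a per-node beta otherwise.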
if n_scalars_out > 0:
if obj_score:
                factor = 2 if self.global_features_copy else 1
self.beta = nn.Sequential(
nn.Linear((n_scalars_out + 1) * factor, 10),
nn.LeakyReLU(),
                    nn.Linear(10, 1),  # raw logit; no sigmoid applied here
)
else:
self.beta = nn.Linear(n_scalars_out + 1, 1)
else:
self.beta = None

    def forward(self, data, data_events=None, data_events_clusters=None, cpu_demo=False):
# data: instance of EventBatch
if self.global_features_copy:
assert data_events is not None and data_events_clusters is not None
assert self.obj_score
inputs_v = data_events.input_vectors
inputs_scalar = data_events.input_scalars
assert inputs_scalar.shape[1] == self.n_scalars, "Expected %d, got %d" % (
self.n_scalars, inputs_scalar.shape[1])
mask_global = self.build_attention_mask(data_events.batch_idx)
embedded_inputs_events = embed_vector(inputs_v.unsqueeze(0))
multivectors = embedded_inputs_events.unsqueeze(-2)
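            # Spurions are fixed reference multivectors (here the beam xy-plane
            # and the time axis) added as extra channels, giving the otherwise
            # Lorentz-equivariant network access to a lab-frame reference.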
spurions = embed_spurions(beam_reference="xyplane", add_time_reference=True,
device=multivectors.device, dtype=multivectors.dtype)
            num_points, vec_dim = inputs_v.shape
            assert vec_dim == 4
spurions = spurions[None, None, ...].repeat(1, num_points, 1, 1) # (batchsize, num_points, 2, 16)
multivectors = torch.cat((multivectors, spurions), dim=-2)
embedded_outputs, output_scalars = self.gatr_global_features(
multivectors, scalars=inputs_scalar, attention_mask=mask_global
)
original_scalar = extract_scalar(embedded_outputs)
            scalar_embeddings_nodes = torch.cat([original_scalar[0, :, 0, :], output_scalars[0, :, :]], dim=1)
            # Shift cluster indices by 1 so that index -1 (unassigned) lands in
            # bucket 0, then drop that bucket after averaging.
            clusters = torch.as_tensor(data_events_clusters, device=scalar_embeddings_nodes.device).long() + 1
            scalar_embeddings_global = scatter_mean(scalar_embeddings_nodes, clusters, dim=0)[1:]
inputs_v = data.input_vectors.float() # four-momenta
inputs_scalar = data.input_scalars.float()
assert inputs_scalar.shape[1] == self.n_scalars
        num_points, vec_dim = inputs_v.shape
        assert vec_dim == 4
#velocities = embed_vector(inputs_v)
inputs_v = inputs_v.unsqueeze(0)
embedded_inputs = embed_vector(inputs_v)
# if it contains nans, raise an error
if torch.isnan(embedded_inputs).any():
raise ValueError("NaNs in the input!")
        multivectors = embedded_inputs.unsqueeze(-2)  # (1, num_points, 1, 16)
spurions = embed_spurions(beam_reference="xyplane", add_time_reference=True,
device=multivectors.device, dtype=multivectors.dtype)
spurions = spurions[None, None, ...].repeat(1, num_points, 1, 1) # (batchsize, num_points, 2, 16)
multivectors = torch.cat((multivectors, spurions), dim=-2) # (batchsize, num_points, 3, 16) - Just embed the spurions as two extra multivector channels
        # Block-diagonal attention mask per event; skipped in the CPU demo path.
        mask = None if cpu_demo else self.build_attention_mask(data.batch_idx)
embedded_outputs, output_scalars = self.gatr(
multivectors, scalars=inputs_scalar, attention_mask=mask
)
        # NOTE: embed_as_vectors is currently unused; the multivector outputs
        # are always read out with extract_vector.
        x_clusters = extract_vector(embedded_outputs)
original_scalar = extract_scalar(embedded_outputs)
if self.beta is not None:
if self.obj_score:
                # When True, read the objectness score from dedicated virtual
                # ("fake") nodes (data.fake_nodes_idx) instead of averaging
                # the scalar embeddings over all nodes in the event.
                extract_from_virtual_nodes = False
scalar_embeddings = torch.cat([original_scalar[0, :, 0, :], output_scalars[0, :, :]], dim=1)
if extract_from_virtual_nodes:
values = torch.cat([original_scalar[0, data.fake_nodes_idx, 0, :], output_scalars[0, data.fake_nodes_idx, :]], dim=1)
else:
values = scatter_mean(scalar_embeddings, data.batch_idx.to(scalar_embeddings.device).long(), dim=0)
if self.global_features_copy:
values = torch.cat([values, scalar_embeddings_global], dim=1)
                beta = self.beta(values)
return beta
vals = torch.cat([original_scalar[0, :, 0, :], output_scalars[0, :, :]], dim=1)
beta = self.beta(vals)
if self.return_scalar_coords:
                x = output_scalars[0, :, :3]
                x = torch.cat((x, torch.sigmoid(beta.view(-1, 1))), dim=1)
else:
x = torch.cat((x_clusters[0, :, 0, :], torch.sigmoid(beta.view(-1, 1))), dim=1)
else:
x = x_clusters[:, 0, :]
if torch.isnan(x).any():
raise ValueError("NaNs in the output!")
return x

    def build_attention_mask(self, batch_numbers):
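        """Build a block-diagonal mask so attention stays within each event.

        E.g. batch_numbers = [0, 0, 0, 1, 1] yields seqlens [3, 2]: the first
        three particles attend only to each other, and likewise the last two.
        """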
return BlockDiagonalMask.from_seqlens(
torch.bincount(batch_numbers.long()).tolist()
)

def get_model(args, obj_score=False):
n_scalars_out = 8
if args.beta_type == "pt":
n_scalars_out = 0
elif args.beta_type == "pt+bc":
n_scalars_out = 8
n_scalars_in = 12
if args.no_pid:
        n_scalars_in = 12 - 9  # drop the 9 PID features
if obj_score:
return LGATrModel(
n_scalars=n_scalars_in,
hidden_mv_channels=8,
hidden_s_channels=16,
blocks=5,
embed_as_vectors=False,
n_scalars_out=n_scalars_out,
return_scalar_coords=args.scalars_oc,
obj_score=obj_score,
            global_features_copy=args.global_features_obj_score
)
return LGATrModel(
n_scalars=n_scalars_in,
hidden_mv_channels=args.hidden_mv_channels,
hidden_s_channels=args.hidden_s_channels,
blocks=args.num_blocks,
embed_as_vectors=args.embed_as_vectors,
n_scalars_out=n_scalars_out,
return_scalar_coords=args.scalars_oc,
obj_score=obj_score
)
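
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the pipeline). The
# SimpleNamespace below is a hypothetical stand-in for the EventBatch object
# that forward() expects, providing only the attributes read above; assumes
# the lgatr package is importable. cpu_demo=True skips the xformers mask.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    model = LGATrModel(
        n_scalars=3,
        hidden_mv_channels=8,
        hidden_s_channels=16,
        blocks=2,
        embed_as_vectors=False,
        n_scalars_out=8,
        return_scalar_coords=False,
    )
    n = 6  # six particles spread over two events
    batch = SimpleNamespace(
        input_vectors=torch.randn(n, 4),             # four-momenta
        input_scalars=torch.randn(n, 3),             # per-particle scalars
        batch_idx=torch.tensor([0, 0, 0, 1, 1, 1]),  # event index per particle
    )
    out = model(batch, cpu_demo=True)
    print(out.shape)  # expected (6, 5): 4 vector coordinates + sigmoid(beta)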