import torch
import torch.nn as nn
from torch_scatter import scatter_mean
from xformers.ops.fmha import BlockDiagonalMask

from src.models.transformer.tr_blocks import Transformer

class TransformerModel(torch.nn.Module):
    def __init__(self, n_scalars, n_scalars_out, n_blocks, n_heads, internal_dim,
                 obj_score, global_features_copy=False):
        super().__init__()
        self.n_scalars = n_scalars
        self.input_dim = n_scalars + 3  # n_scalars scalar features + 3-component input vector per node
        if obj_score:
            self.input_dim += 1
        self.output_dim = 3
        self.obj_score = obj_score
        if n_scalars_out > 0:
            self.output_dim += 1  # extra channel for the beta regression
        if self.obj_score:
            self.output_dim = 10
        self.global_features_copy = global_features_copy
        self.transformer = Transformer(
            in_channels=self.input_dim,
            out_channels=self.output_dim,
            hidden_channels=internal_dim,
            num_heads=n_heads,
            num_blocks=n_blocks,
        )
        if self.global_features_copy:
            # Second transformer that runs on the unmodified events and provides
            # global (per-cluster) features for the objectness score.
            self.transformer_global_features = Transformer(
                in_channels=self.input_dim,
                out_channels=self.output_dim,
                hidden_channels=internal_dim,
                num_heads=n_heads,
                num_blocks=n_blocks,
            )
        # Note: defined here but not applied in forward().
        self.batch_norm = nn.BatchNorm1d(self.input_dim, momentum=0.1)
        if self.obj_score:
            factor = 2 if self.global_features_copy else 1
            self.final_mlp = nn.Sequential(
                nn.Linear(self.output_dim * factor, 10),
                nn.LeakyReLU(),
                nn.Linear(10, 1),
            )
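
    # Dimension bookkeeping (illustrative; numbers follow the defaults used in
    # get_model() below): with n_scalars=12 and obj_score=True the model sees
    # input_dim = 12 + 3 + 1 = 16 channels per node and emits output_dim = 10
    # channels per node, which final_mlp then reduces to a single objectness logit.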
    def forward(self, data, data_events=None, data_events_clusters=None):
        # data: instance of EventBatch.
        # data_events & data_events_clusters: only used when
        # --global-features-obj-score is on. data_events is the "unmodified"
        # batch (with its own batch indices) and data_events_clusters assigns
        # each of its nodes to a cluster (index -1 is dropped in the pooling below).
        if self.global_features_copy:
            assert data_events is not None and data_events_clusters is not None
            assert self.obj_score
            inputs_v = data_events.input_vectors.float()
            inputs_scalar = data_events.input_scalars.float()
            assert inputs_scalar.shape[1] == self.n_scalars, "Expected %d, got %d" % (
                self.n_scalars, inputs_scalar.shape[1])
            inputs_transformer_events = torch.cat([inputs_scalar, inputs_v], dim=1).float()
            assert inputs_transformer_events.shape[1] == self.input_dim
            mask_global = self.build_attention_mask(data_events.batch_idx)
            x_global = inputs_transformer_events.unsqueeze(0)
            x_global = self.transformer_global_features(x_global, attention_mask=mask_global)[0]
            assert x_global.shape[1] == self.output_dim, "Expected %d, got %d" % (
                self.output_dim, x_global.shape[1])
            assert x_global.shape[0] == inputs_transformer_events.shape[0], "Expected %d, got %d" % (
                inputs_transformer_events.shape[0], x_global.shape[0])
            # Shift cluster indices by +1 so that label -1 (unclustered nodes) lands in
            # bucket 0, then drop that bucket after the per-cluster mean.
            cluster_idx = torch.as_tensor(data_events_clusters, device=x_global.device).long() + 1
            m_global = scatter_mean(x_global, cluster_idx, dim=0)[1:]
        inputs_v = data.input_vectors
        inputs_scalar = data.input_scalars
        assert inputs_scalar.shape[1] == self.n_scalars, "Expected %d, got %d" % (
            self.n_scalars, inputs_scalar.shape[1])
        inputs_transformer = torch.cat([inputs_scalar, inputs_v], dim=1).float()
        assert inputs_transformer.shape[1] == self.input_dim
        mask = self.build_attention_mask(data.batch_idx)
        x = inputs_transformer.unsqueeze(0)
        x = self.transformer(x, attention_mask=mask)[0]
        assert x.shape[1] == self.output_dim, "Expected %d, got %d" % (self.output_dim, x.shape[1])
        assert x.shape[0] == inputs_transformer.shape[0], "Expected %d, got %d" % (
            inputs_transformer.shape[0], x.shape[0])
        if not self.obj_score:
            # Last channel is the beta regression output: squash it to (0, 1).
            x[:, -1] = torch.sigmoid(x[:, -1])
        else:
            extract_from_virtual_nodes = False
            if extract_from_virtual_nodes:
                x = self.final_mlp(x[data.fake_nodes_idx])  # x is the raw logits
            else:
                # Mean-pool the per-node outputs over each event, then score the pooled vector.
                m = scatter_mean(x, data.batch_idx.long().to(x.device), dim=0)
                assert "fake_nodes_idx" not in data.__dict__
                if self.global_features_copy:
                    m = torch.cat([m, m_global], dim=1)
                x = self.final_mlp(m).flatten()
        return x
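
    # Illustrative sketch of the per-event pooling above (not part of the model):
    # with batch_idx = [0, 0, 1] and per-node outputs
    #   x = [[1., 2.], [3., 4.], [5., 6.]],
    # scatter_mean(x, batch_idx, dim=0) yields one row per event,
    #   [[2., 3.], [5., 6.]],
    # which final_mlp then turns into one objectness logit per event.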
    def build_attention_mask(self, batch_numbers):
        # Block-diagonal mask so that attention never crosses event boundaries.
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )
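
    # For example (illustrative): batch_numbers = [0, 0, 0, 1, 1] gives
    # bincount -> [3, 2], and BlockDiagonalMask.from_seqlens([3, 2]) builds a
    # mask in which the first three nodes attend only to each other and the
    # last two attend only to each other.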


def get_model(args, obj_score=False):
    n_scalars_out = 8
    if args.beta_type == "pt":
        n_scalars_out = 0
    elif args.beta_type == "pt+bc":
        n_scalars_out = 1
    n_scalars_in = 12
    if args.no_pid:
        n_scalars_in = 12 - 9  # drop the PID-related scalar inputs
    if obj_score:
        # Objectness-score model: fixed-size transformer, with the global-feature
        # copy controlled by --global-features-obj-score.
        return TransformerModel(
            n_scalars=n_scalars_in,
            n_scalars_out=10,
            n_blocks=5,
            n_heads=args.n_heads,
            internal_dim=64,
            obj_score=obj_score,
            global_features_copy=args.global_features_obj_score,
        )
    return TransformerModel(
        n_scalars=n_scalars_in,
        n_scalars_out=n_scalars_out,
        n_blocks=args.num_blocks,
        n_heads=args.n_heads,
        internal_dim=args.internal_dim,
        obj_score=obj_score,
    )
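

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the training code). It assumes the
    # repo's Transformer block is importable and that xformers' memory-efficient
    # attention can run in this environment; the EventBatch stand-in below only
    # mimics the attributes that forward() actually reads, and the argument
    # values are illustrative.
    from types import SimpleNamespace

    args = SimpleNamespace(
        beta_type="pt+bc",
        no_pid=False,
        n_heads=4,
        num_blocks=4,
        internal_dim=64,
    )
    model = get_model(args)

    n_nodes, n_scalars = 7, 12
    batch = SimpleNamespace(
        input_scalars=torch.randn(n_nodes, n_scalars),
        input_vectors=torch.randn(n_nodes, 3),
        batch_idx=torch.tensor([0, 0, 0, 0, 1, 1, 1]),
    )
    out = model(batch)
    print(out.shape)  # expected: (n_nodes, output_dim)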