File size: 3,878 Bytes
b085dea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from gatr import GATr, SelfAttentionConfig, MLPConfig
from gatr.interface import (
    embed_point,
    extract_scalar,
    extract_point,
    embed_scalar,
    embed_translation,
    extract_translation
)
import torch
import torch.nn as nn
from xformers.ops.fmha import BlockDiagonalMask


class GATrModel(torch.nn.Module):
    """Per-point GATr (Geometric Algebra Transformer) model.

    Every point is fed to GATr as a single multivector channel built from its
    input vector, together with ``n_scalars`` auxiliary scalar features.
    Attention is restricted by a block-diagonal mask derived from
    ``data.batch_idx`` (see :meth:`build_attention_mask`).

    Args:
        n_scalars: number of scalar features per point; :meth:`forward`
            checks that ``data.input_scalars.shape[1]`` equals this value.
        hidden_mv_channels: multivector channels in the hidden GATr layers.
        hidden_s_channels: scalar channels in the hidden GATr layers.
        blocks: number of GATr blocks.
        embed_as_vectors: if True, inputs are embedded with
            ``embed_translation`` and outputs decoded with
            ``extract_translation``; otherwise ``embed_point`` /
            ``extract_point`` are used.
        n_scalars_out: number of scalar output channels of GATr. When > 0 a
            linear ``beta`` head is built on top of them; when 0 no beta
            column is produced.
    """

    def __init__(self, n_scalars, hidden_mv_channels, hidden_s_channels, blocks, embed_as_vectors, n_scalars_out):
        super().__init__()
        self.n_scalars = n_scalars
        self.hidden_mv_channels = hidden_mv_channels
        self.hidden_s_channels = hidden_s_channels
        self.blocks = blocks
        self.embed_as_vectors = embed_as_vectors
        self.input_dim = 3
        self.n_scalars_out = n_scalars_out
        self.gatr = GATr(
            in_mv_channels=1,
            out_mv_channels=1,
            hidden_mv_channels=hidden_mv_channels,
            in_s_channels=n_scalars,
            out_s_channels=n_scalars_out,
            hidden_s_channels=hidden_s_channels,
            num_blocks=blocks,
            attention=SelfAttentionConfig(),  # Use default parameters for attention
            mlp=MLPConfig(),  # Use default parameters for MLP
        )
        # NOTE(review): batch_norm is never used by forward(). It is kept
        # because removing a registered submodule changes the state_dict keys
        # and would break loading of existing checkpoints.
        self.batch_norm = nn.BatchNorm1d(self.input_dim, momentum=0.1)
        if n_scalars_out > 0:
            # Input = GATr's output scalar channels + the scalar component of
            # the output multivector (hence the +1).
            self.beta = nn.Linear(n_scalars_out + 1, 1)
        else:
            self.beta = None

    def forward(self, data):
        """Run GATr over all points of a batch.

        Args:
            data: an ``EventBatch``-like object exposing ``input_vectors``,
                ``input_scalars`` (``n_scalars`` columns) and ``batch_idx``
                (the event index of every point).

        Returns:
            A tensor with one row per point. Its columns are the decoded
            geometric output (``extract_translation`` or ``extract_point``),
            followed by a sigmoid-squashed ``beta`` column when the beta head
            exists.

        Raises:
            AssertionError: if the number of input scalars does not match
                ``n_scalars``.
            ValueError: if NaNs appear in the vector embedding (vector mode
                only) or in the final output.
        """
        inputs_v = data.input_vectors.float()
        inputs_scalar = data.input_scalars.float()
        assert inputs_scalar.shape[1] == self.n_scalars, (
            f"expected {self.n_scalars} input scalars, got {inputs_scalar.shape[1]}"
        )
        if self.embed_as_vectors:
            embedded_inputs = embed_translation(inputs_v)
            if torch.isnan(embedded_inputs).any():
                raise ValueError("NaNs in the input!")
        else:
            embedded_inputs = embed_point(inputs_v)
        # Insert the single multivector-channel axis: (batch_size*num_points, 1, 16).
        embedded_inputs = embedded_inputs.unsqueeze(-2)
        mask = self.build_attention_mask(data.batch_idx)
        embedded_outputs, output_scalars = self.gatr(
            embedded_inputs, scalars=inputs_scalar, attention_mask=mask
        )
        if self.embed_as_vectors:
            x_clusters = extract_translation(embedded_outputs)
        else:
            x_clusters = extract_point(embedded_outputs)
        if self.beta is not None:
            # Only needed for the beta head, so it is extracted only here.
            original_scalar = extract_scalar(embedded_outputs)
            # [:, 0, :] selects the single output multivector channel.
            beta = self.beta(torch.cat([original_scalar[:, 0, :], output_scalars], dim=1))
            x = torch.cat((x_clusters[:, 0, :], torch.sigmoid(beta.view(-1, 1))), dim=1)
        else:
            x = x_clusters[:, 0, :]
        if torch.isnan(x).any():
            raise ValueError("NaNs in the output!")
        return x

    def build_attention_mask(self, batch_numbers):
        """Build a block-diagonal attention mask from per-point batch indices.

        Block ``i`` has length ``torch.bincount(batch_numbers)[i]``, i.e. the
        number of points with batch index ``i``. Any index in
        ``[0, max(batch_numbers)]`` that has no points yields a zero-length
        block.

        NOTE(review): the mask is only correct if ``batch_numbers`` is sorted,
        so that each event's points are contiguous. Confirm against the
        data loader.
        """
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )

def get_model(args, obj_score=False):
    """Construct a ``GATrModel`` configured from parsed command-line options.

    Reads ``args.beta_type``, ``args.no_pid``, ``args.hidden_mv_channels``,
    ``args.hidden_s_channels``, ``args.num_blocks`` and
    ``args.embed_as_vectors``. ``obj_score`` is accepted for interface
    compatibility but is not used here.
    """
    # "pt" disables the beta head (no output scalars). Every other beta_type,
    # "pt+bc" included, uses 8 output scalars.
    scalars_out = 0 if args.beta_type == "pt" else 8
    # 12 input scalars by default. no_pid drops 9 of them (presumably the
    # particle-ID features; verify against the data pipeline).
    scalars_in = 12 - 9 if args.no_pid else 12
    return GATrModel(
        n_scalars=scalars_in,
        hidden_mv_channels=args.hidden_mv_channels,
        hidden_s_channels=args.hidden_s_channels,
        blocks=args.num_blocks,
        embed_as_vectors=args.embed_as_vectors,
        n_scalars_out=scalars_out,
    )