import torch
import torch.nn as nn
from torch_scatter import scatter_mean
from xformers.ops.fmha import BlockDiagonalMask

from src.models.transformer.tr_blocks import Transformer


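# Transformer-based per-hit model: a block-diagonal attention mask keeps hits
# with different batch indices from attending to each other, and get_model()
# below builds either the clustering variant or the object-score variant from
# the command-line arguments.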
class TransformerModel(torch.nn.Module):
    """Per-hit transformer with block-diagonal (per-event) attention.

    Each hit is described by `n_scalars` scalar features plus a 3-vector. In the
    default mode the model predicts 3 output coordinates per hit (plus a beta
    column when n_scalars_out > 0); with `obj_score` the per-hit outputs are
    pooled and passed through a small MLP to give one score per batch element.
    """

    def __init__(self, n_scalars, n_scalars_out, n_blocks, n_heads,
                 internal_dim, obj_score, global_features_copy=False):
        super().__init__()
        self.n_scalars = n_scalars
        self.input_dim = n_scalars + 3  # scalar features plus a 3-vector per hit
        if obj_score:
            self.input_dim += 1  # one extra input column in object-score mode
        self.output_dim = 3
        self.obj_score = obj_score
        if n_scalars_out > 0:
            self.output_dim += 1  # extra output column for the beta regression
        if self.obj_score:
            # in object-score mode the transformer emits a 10-dim per-hit
            # embedding that is pooled and fed to final_mlp below
            self.output_dim = 10
        self.global_features_copy = global_features_copy
        self.transformer = Transformer(
            in_channels=self.input_dim,
            out_channels=self.output_dim,
            hidden_channels=internal_dim,
            num_heads=n_heads,
            num_blocks=n_blocks,
        )
        if self.global_features_copy:
            self.transformer_global_features = Transformer(
                in_channels=self.input_dim,
                out_channels=self.output_dim,
                hidden_channels=internal_dim,
                num_heads=n_heads,
                num_blocks=n_blocks,
            )
        self.batch_norm = nn.BatchNorm1d(self.input_dim, momentum=0.1)
        if self.obj_score:
            # with the global-features copy, pooled features from both
            # transformers are concatenated, doubling the MLP input width
            factor = 2 if self.global_features_copy else 1
            self.final_mlp = nn.Sequential(
                nn.Linear(self.output_dim * factor, 10),
                nn.LeakyReLU(),
                nn.Linear(10, 1),
            )

    def forward(self, data, data_events=None, data_events_clusters=None):
        """Run the model on a batch of hits.

        data: instance of EventBatch.
        data_events, data_events_clusters: only used when
            --global-features-obj-score is enabled. data_events holds the
            "unmodified" event batch, whose batch indices are used to build a
            per-event attention mask; data_events_clusters assigns each of
            those hits to a cluster index.
        """
        if self.global_features_copy:
            assert data_events is not None and data_events_clusters is not None
            assert self.obj_score
            inputs_v = data_events.input_vectors.float()
            inputs_scalar = data_events.input_scalars.float()
            assert inputs_scalar.shape[1] == self.n_scalars, "Expected %d, got %d" % (
                self.n_scalars, inputs_scalar.shape[1])
            inputs_transformer_events = torch.cat([inputs_scalar, inputs_v], dim=1).float()
            assert inputs_transformer_events.shape[1] == self.input_dim
            mask_global = self.build_attention_mask(data_events.batch_idx)
            x_global = inputs_transformer_events.unsqueeze(0)
            x_global = self.transformer_global_features(x_global, attention_mask=mask_global)[0]
            assert x_global.shape[1] == self.output_dim, "Expected %d, got %d" % (self.output_dim, x_global.shape[1])
            assert x_global.shape[0] == inputs_transformer_events.shape[0], "Expected %d, got %d" % (
                inputs_transformer_events.shape[0], x_global.shape[0])
            # mean-pool per-hit features into per-cluster features; the +1 / [1:]
            # shifts the cluster indices so that an index of -1 (if present)
            # lands in bin 0, which is then dropped
            m_global = scatter_mean(
                x_global,
                torch.as_tensor(data_events_clusters).to(x_global.device).long() + 1,
                dim=0,
            )[1:]
        inputs_v = data.input_vectors
        inputs_scalar = data.input_scalars
        assert inputs_scalar.shape[1] == self.n_scalars, "Expected %d, got %d" % (self.n_scalars, inputs_scalar.shape[1])
        inputs_transformer = torch.cat([inputs_scalar, inputs_v], dim=1).float()
        assert inputs_transformer.shape[1] == self.input_dim
        # the block-diagonal mask restricts attention to hits sharing a batch index
        mask = self.build_attention_mask(data.batch_idx)
        x = inputs_transformer.unsqueeze(0)
        x = self.transformer(x, attention_mask=mask)[0]
        assert x.shape[1] == self.output_dim, "Expected %d, got %d" % (self.output_dim, x.shape[1])
        assert x.shape[0] == inputs_transformer.shape[0], "Expected %d, got %d" % (inputs_transformer.shape[0], x.shape[0])
        if not self.obj_score:
            # squash the last output column (the beta, when present) to (0, 1)
            x[:, -1] = torch.sigmoid(x[:, -1])
        else:
            extract_from_virtual_nodes = False
            if extract_from_virtual_nodes:
                x = self.final_mlp(x[data.fake_nodes_idx])  # x is the raw logits
            else:
                # mean-pool the per-hit outputs within each batch element
                m = scatter_mean(x, torch.as_tensor(data.batch_idx).long().to(x.device), dim=0)
                assert "fake_nodes_idx" not in data.__dict__
                if self.global_features_copy:
                    m = torch.cat([m, m_global], dim=1)
                x = self.final_mlp(m).flatten()
        return x

    def build_attention_mask(self, batch_numbers):
        """Build a block-diagonal mask so hits only attend within their own block.

        Assumes hits with the same batch index are stored contiguously.
        """
        return BlockDiagonalMask.from_seqlens(
            torch.bincount(batch_numbers.long()).tolist()
        )
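    # Example (illustrative, not taken from the original code): for
    # batch_numbers = tensor([0, 0, 0, 1, 1]), torch.bincount gives [3, 2], so
    # BlockDiagonalMask.from_seqlens([3, 2]) lets the first three hits attend
    # only to each other and the last two likewise.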


def get_model(args, obj_score=False):
    n_scalars_out = 8
    if args.beta_type == "pt":
        n_scalars_out = 0
    elif args.beta_type == "pt+bc":
        n_scalars_out = 1
    n_scalars_in = 12
    if args.no_pid:
        n_scalars_in = 12 - 9  # drop the PID scalars from the inputs
    if obj_score:
        # the object-score model uses a fixed, smaller configuration
        return TransformerModel(
            n_scalars=n_scalars_in,
            n_scalars_out=10,
            n_blocks=5,
            n_heads=args.n_heads,
            internal_dim=64,
            obj_score=obj_score,
            global_features_copy=args.global_features_obj_score,
        )
    return TransformerModel(
        n_scalars=n_scalars_in,
        n_scalars_out=n_scalars_out,
        n_blocks=args.num_blocks,
        n_heads=args.n_heads,
        internal_dim=args.internal_dim,
        obj_score=obj_score,
    )
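

if __name__ == "__main__":
    # Minimal usage sketch. The argument values below are illustrative
    # assumptions, not the project's actual training configuration.
    from types import SimpleNamespace

    args = SimpleNamespace(
        beta_type="pt+bc",                # -> n_scalars_out = 1
        no_pid=False,                     # -> n_scalars_in = 12
        n_heads=8,
        num_blocks=4,
        internal_dim=128,
        global_features_obj_score=False,
    )
    model = get_model(args)                     # clustering model
    scorer = get_model(args, obj_score=True)    # object-score model
    print("clustering parameters:", sum(p.numel() for p in model.parameters()))
    print("object-score parameters:", sum(p.numel() for p in scorer.parameters()))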