gregorkrzmanc committed on
Commit b085dea · 1 Parent(s): 08310aa
Dockerfile CHANGED
@@ -3,7 +3,7 @@
 FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
 
 WORKDIR /app
-
+RUN ls .
 COPY . /app
 
 SHELL ["/bin/bash", "-c"]
@@ -11,12 +11,10 @@ SHELL ["/bin/bash", "-c"]
 USER root
 
 RUN ls /app
-RUN echo "---"
 RUN ls /app/src
-RUN echo "----"
-RUN ls /app/src/models/
-RUN echo "----"
-RUN ls /app/src/models/lgatr
+RUN ls /app/src/1models/
+RUN ls /app/src/1models/LGATr
+
 RUN apt update && \
     DEBIAN_FRONTEND=noninteractive apt install --yes --no-install-recommends \
     build-essential \
src/1models/GATr/Gatr.py ADDED
@@ -0,0 +1,104 @@
1
+ from gatr import GATr, SelfAttentionConfig, MLPConfig
2
+ from gatr.interface import (
3
+ embed_point,
4
+ extract_scalar,
5
+ extract_point,
6
+ embed_scalar,
7
+ embed_translation,
8
+ extract_translation
9
+ )
10
+ import torch
11
+ import torch.nn as nn
12
+ from xformers.ops.fmha import BlockDiagonalMask
13
+
14
+
15
+ class GATrModel(torch.nn.Module):
16
+ def __init__(self, n_scalars, hidden_mv_channels, hidden_s_channels, blocks, embed_as_vectors, n_scalars_out):
17
+ super().__init__()
18
+ self.n_scalars = n_scalars
19
+ self.hidden_mv_channels = hidden_mv_channels
20
+ self.hidden_s_channels = hidden_s_channels
21
+ self.blocks = blocks
22
+ self.embed_as_vectors = embed_as_vectors
23
+ self.input_dim = 3
24
+ self.n_scalars_out = n_scalars_out
25
+ self.gatr = GATr(
26
+ in_mv_channels=1,
27
+ out_mv_channels=1,
28
+ hidden_mv_channels=hidden_mv_channels,
29
+ in_s_channels=n_scalars,
30
+ out_s_channels=n_scalars_out,
31
+ hidden_s_channels=hidden_s_channels,
32
+ num_blocks=blocks,
33
+ attention=SelfAttentionConfig(), # Use default parameters for attention
34
+ mlp=MLPConfig(), # Use default parameters for MLP
35
+ )
36
+ self.batch_norm = nn.BatchNorm1d(self.input_dim, momentum=0.1)
37
+ #self.clustering = nn.Linear(3, self.output_dim - 1, bias=False)
38
+ if n_scalars_out > 0:
39
+ self.beta = nn.Linear(n_scalars_out + 1, 1)
40
+ else:
41
+ self.beta = None
42
+
43
+ def forward(self, data):
44
+ # data: instance of EventBatch
45
+ inputs_v = data.input_vectors.float()
46
+ inputs_scalar = data.input_scalars.float()
47
+ assert inputs_scalar.shape[1] == self.n_scalars
48
+ if self.embed_as_vectors:
49
+ velocities = embed_translation(inputs_v)
50
+ embedded_inputs = (
51
+ velocities
52
+ )
53
+ # if it contains nans, raise an error
54
+ if torch.isnan(embedded_inputs).any():
55
+ raise ValueError("NaNs in the input!")
56
+ else:
57
+ inputs = inputs_v
58
+ embedded_inputs = embed_point(inputs)
59
+ embedded_inputs = embedded_inputs.unsqueeze(-2) # (batch_size*num_points, 1, 16)
60
+ mask = self.build_attention_mask(data.batch_idx)
61
+ embedded_outputs, output_scalars = self.gatr(
62
+ embedded_inputs, scalars=inputs_scalar, attention_mask=mask
63
+ )
64
+ #if self.embed_as_vectors:
65
+ # x_clusters = extract_translation(embedded_outputs)
66
+ #else:
67
+ # x_clusters = extract_point(embedded_outputs)
68
+ if self.embed_as_vectors:
69
+ x_clusters = extract_translation(embedded_outputs)
70
+ else:
71
+ x_clusters = extract_point(embedded_outputs)
72
+ original_scalar = extract_scalar(embedded_outputs)
73
+ if self.beta is not None:
74
+ beta = self.beta(torch.cat([original_scalar[:, 0, :], output_scalars], dim=1))
75
+ x = torch.cat((x_clusters[:, 0, :], torch.sigmoid(beta.view(-1, 1))), dim=1)
76
+ else:
77
+ x = x_clusters[:, 0, :]
78
+ if torch.isnan(x).any():
79
+ raise ValueError("NaNs in the output!")
80
+ #print(x[:5])
81
+ return x
82
+
83
+ def build_attention_mask(self, batch_numbers):
84
+ return BlockDiagonalMask.from_seqlens(
85
+ torch.bincount(batch_numbers.long()).tolist()
86
+ )
87
+
88
+ def get_model(args, obj_score=False):
89
+ n_scalars_out = 8
90
+ if args.beta_type == "pt":
91
+ n_scalars_out = 0
92
+ elif args.beta_type == "pt+bc":
93
+ n_scalars_out = 8
94
+ n_scalars_in = 12
95
+ if args.no_pid:
96
+ n_scalars_in = 12-9
97
+ return GATrModel(
98
+ n_scalars=n_scalars_in,
99
+ hidden_mv_channels=args.hidden_mv_channels,
100
+ hidden_s_channels=args.hidden_s_channels,
101
+ blocks=args.num_blocks,
102
+ embed_as_vectors=args.embed_as_vectors,
103
+ n_scalars_out=n_scalars_out
104
+ )
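Both GATr-based models restrict attention to nodes of the same event via build_attention_mask, which turns per-node batch indices into a block-diagonal attention bias. A minimal sketch of that step in isolation (the batch_idx values below are made up for illustration):

import torch
from xformers.ops.fmha import BlockDiagonalMask

# Three events with 2, 3 and 1 nodes; nodes of one event are assumed contiguous.
batch_idx = torch.tensor([0, 0, 1, 1, 1, 2])

# Same construction as GATrModel.build_attention_mask: per-event node counts
# become the block sizes, so attention never crosses event boundaries.
seqlens = torch.bincount(batch_idx.long()).tolist()   # [2, 3, 1]
mask = BlockDiagonalMask.from_seqlens(seqlens)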
src/1models/LGATr/lgatr.py ADDED
@@ -0,0 +1,196 @@
1
+ from lgatr import GATr, SelfAttentionConfig, MLPConfig
2
+ from lgatr.interface import embed_vector, extract_scalar, embed_spurions, extract_vector
3
+ import torch
4
+ import torch.nn as nn
5
+ from xformers.ops.fmha import BlockDiagonalMask
6
+ from torch_scatter import scatter_sum, scatter_max, scatter_mean
7
+
8
+
9
+ class LGATrModel(torch.nn.Module):
10
+ def __init__(self, n_scalars, hidden_mv_channels, hidden_s_channels, blocks, embed_as_vectors, n_scalars_out, return_scalar_coords, obj_score=False, global_features_copy=False):
11
+ super().__init__()
12
+ self.return_scalar_coords = return_scalar_coords
13
+ self.n_scalars = n_scalars
14
+ self.hidden_mv_channels = hidden_mv_channels
15
+ self.hidden_s_channels = hidden_s_channels
16
+ self.blocks = blocks
17
+ self.embed_as_vectors = embed_as_vectors
18
+ self.input_dim = 3
19
+ self.n_scalars_out = n_scalars_out
20
+ self.obj_score = obj_score
21
+ self.global_features_copy = global_features_copy
22
+ self.gatr = GATr(
23
+ in_mv_channels=3,
24
+ out_mv_channels=1,
25
+ hidden_mv_channels=hidden_mv_channels,
26
+ in_s_channels=n_scalars,
27
+ out_s_channels=n_scalars_out,
28
+ hidden_s_channels=hidden_s_channels,
29
+ num_blocks=blocks,
30
+ attention=SelfAttentionConfig(), # Use default parameters for attention
31
+ mlp=MLPConfig(), # Use default parameters for MLP
32
+ )
33
+ if self.global_features_copy:
34
+ self.gatr_global_features = GATr(
35
+ in_mv_channels=3,
36
+ out_mv_channels=1,
37
+ hidden_mv_channels=hidden_mv_channels,
38
+ in_s_channels=n_scalars,
39
+ out_s_channels=n_scalars_out,
40
+ hidden_s_channels=hidden_s_channels,
41
+ num_blocks=blocks,
42
+ attention=SelfAttentionConfig(), # Use default parameters for attention
43
+ mlp=MLPConfig(), # Use default parameters for MLP
44
+ )
45
+ #self.batch_norm = nn.BatchNorm1d(self.input_dim, momentum=0.1)
46
+ #self.clustering = nn.Linear(3, self.output_dim - 1, bias=False)
47
+ if n_scalars_out > 0:
48
+ if obj_score:
49
+ factor = 1
50
+ if self.global_features_copy: factor = 2
51
+ self.beta = nn.Sequential(
52
+ nn.Linear((n_scalars_out + 1) * factor, 10),
53
+ nn.LeakyReLU(),
54
+ nn.Linear(10, 1),
55
+ #nn.Sigmoid()
56
+ )
57
+ else:
58
+ self.beta = nn.Linear(n_scalars_out + 1, 1)
59
+ else:
60
+ self.beta = None
61
+
62
+ def forward(self, data, data_events=None, data_events_clusters=None, cpu_demo=False):
63
+ # data: instance of EventBatch
64
+ if self.global_features_copy:
65
+ assert data_events is not None and data_events_clusters is not None
66
+ assert self.obj_score
67
+ inputs_v = data_events.input_vectors
68
+ inputs_scalar = data_events.input_scalars
69
+ assert inputs_scalar.shape[1] == self.n_scalars, "Expected %d, got %d" % (
70
+ self.n_scalars, inputs_scalar.shape[1])
71
+ mask_global = self.build_attention_mask(data_events.batch_idx)
72
+ embedded_inputs_events = embed_vector(inputs_v.unsqueeze(0))
73
+ multivectors = embedded_inputs_events.unsqueeze(-2)
74
+ spurions = embed_spurions(beam_reference="xyplane", add_time_reference=True,
75
+ device=multivectors.device, dtype=multivectors.dtype)
76
+
77
+ num_points, x = inputs_v.shape
78
+ assert x == 4
79
+ spurions = spurions[None, None, ...].repeat(1, num_points, 1, 1) # (batchsize, num_points, 2, 16)
80
+ multivectors = torch.cat((multivectors, spurions), dim=-2)
81
+ embedded_outputs, output_scalars = self.gatr_global_features(
82
+ multivectors, scalars=inputs_scalar, attention_mask=mask_global
83
+ )
84
+ original_scalar = extract_scalar(embedded_outputs)
85
+ scalar_embeddings_nodes = torch.cat([original_scalar[0, :, 0, :], output_scalars[0, :, :]], dim=1)
86
+ scalar_embeddings_global = scatter_mean(scalar_embeddings_nodes, torch.tensor(data_events_clusters).to(scalar_embeddings_nodes.device)+1, dim=0)[1:]
87
+
88
+ inputs_v = data.input_vectors.float() # four-momenta
89
+ inputs_scalar = data.input_scalars.float()
90
+ assert inputs_scalar.shape[1] == self.n_scalars
91
+ num_points, x = inputs_v.shape
92
+ assert x == 4
93
+ #velocities = embed_vector(inputs_v)
94
+
95
+ inputs_v = inputs_v.unsqueeze(0)
96
+ embedded_inputs = embed_vector(inputs_v)
97
+ # if it contains nans, raise an error
98
+ if torch.isnan(embedded_inputs).any():
99
+ raise ValueError("NaNs in the input!")
100
+ multivectors = embedded_inputs.unsqueeze(-2) # (batch_size*num_points, 1, 16)
101
+ # for spurions, duplicate each unique batch_idx. e.g. [0,0,1,1,2,2] etc.
102
+ #spurions_batch_idx = torch.repeat_interleave(data.batch_idx.unique(), 2)
103
+ #batch_idx = torch.cat([data.batch_idx, spurions_batch_idx])
104
+ spurions = embed_spurions(beam_reference="xyplane", add_time_reference=True,
105
+ device=multivectors.device, dtype=multivectors.dtype)
106
+ spurions = spurions[None, None, ...].repeat(1, num_points, 1, 1) # (batchsize, num_points, 2, 16)
107
+ multivectors = torch.cat((multivectors, spurions), dim=-2) # (batchsize, num_points, 3, 16) - Just embed the spurions as two extra multivector channels
108
+ mask = self.build_attention_mask(data.batch_idx)
109
+ if cpu_demo:
110
+ mask = None
111
+ embedded_outputs, output_scalars = self.gatr(
112
+ multivectors, scalars=inputs_scalar, attention_mask=mask
113
+ )
114
+
115
+ #if self.embed_as_vectors:
116
+ # x_clusters = extract_translation(embedded_outputs)
117
+ #else:
118
+ # x_clusters = extract_point(embedded_outputs)
119
+ x_clusters = extract_vector(embedded_outputs)
120
+ original_scalar = extract_scalar(embedded_outputs)
121
+ if self.beta is not None:
122
+ if self.obj_score:
123
+ extract_from_virtual_nodes = False
124
+ # assert that data has fake_nodes_idx from which we read the objectness score
125
+ #assert "fake_nodes_idx" in data.__dict__
126
+ # print batch number 3 and 4 inputs
127
+ #for nbatch in [3, 4]:
128
+ # print("#### batch no. ", nbatch , "#######")
129
+ # print(" -> scalar inputs", inputs_scalar[data.batch_idx==nbatch].shape, inputs_scalar[data.batch_idx == nbatch])
130
+ # print(" -> vector inputs", data.input_vectors[data.batch_idx==nbatch].shape, data.input_vectors[data.batch_idx == nbatch])
131
+ # print("############")
132
+ scalar_embeddings = torch.cat([original_scalar[0, :, 0, :], output_scalars[0, :, :]], dim=1)
133
+ if extract_from_virtual_nodes:
134
+ values = torch.cat([original_scalar[0, data.fake_nodes_idx, 0, :], output_scalars[0, data.fake_nodes_idx, :]], dim=1)
135
+ else:
136
+ values = scatter_mean(scalar_embeddings, data.batch_idx.to(scalar_embeddings.device).long(), dim=0)
137
+ if self.global_features_copy:
138
+ values = torch.cat([values, scalar_embeddings_global], dim=1)
139
+ beta = self.beta(values)
140
+ #beta = self.beta(values)
141
+ return beta
142
+ vals = torch.cat([original_scalar[0, :, 0, :], output_scalars[0, :, :]], dim=1)
143
+ beta = self.beta(vals)
144
+ if self.return_scalar_coords:
145
+ x = output_scalars[0, :, :3]
146
+ #print(x.shape)
147
+ #print(x[:5])
148
+ x = torch.cat((x, torch.sigmoid(beta.view(-1, 1))), dim=1)
149
+ else:
150
+ x = torch.cat((x_clusters[0, :, 0, :], torch.sigmoid(beta.view(-1, 1))), dim=1)
151
+ else:
152
+ x = x_clusters[:, 0, :]
153
+ if torch.isnan(x).any():
154
+ raise ValueError("NaNs in the output!")
155
+ #print(x[:5])
156
+ print("LGATr x shape:", x.shape)
157
+ return x
158
+
159
+ def build_attention_mask(self, batch_numbers):
160
+ return BlockDiagonalMask.from_seqlens(
161
+ torch.bincount(batch_numbers.long()).tolist()
162
+ )
163
+
164
+ def get_model(args, obj_score=False):
165
+ n_scalars_out = 8
166
+ if args.beta_type == "pt":
167
+ n_scalars_out = 0
168
+ elif args.beta_type == "pt+bc":
169
+ n_scalars_out = 8
170
+ n_scalars_in = 12
171
+ if args.no_pid:
172
+ n_scalars_in = 12 - 9
173
+ if obj_score:
174
+ return LGATrModel(
175
+ n_scalars=n_scalars_in,
176
+ hidden_mv_channels=8,
177
+ hidden_s_channels=16,
178
+ blocks=5,
179
+ embed_as_vectors=False,
180
+ n_scalars_out=n_scalars_out,
181
+ return_scalar_coords=args.scalars_oc,
182
+ obj_score=obj_score,
183
+ global_features_copy=args.global_features_obj_score
184
+ )
185
+
186
+ return LGATrModel(
187
+ n_scalars=n_scalars_in,
188
+ hidden_mv_channels=args.hidden_mv_channels,
189
+ hidden_s_channels=args.hidden_s_channels,
190
+ blocks=args.num_blocks,
191
+ embed_as_vectors=args.embed_as_vectors,
192
+ n_scalars_out=n_scalars_out,
193
+ return_scalar_coords=args.scalars_oc,
194
+ obj_score=obj_score
195
+ )
196
+
src/1models/identity.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+
3
+ class IdentityModel(torch.nn.Module):
4
+ def __init__(self, n_out_coords=3):
5
+ super().__init__()
6
+ self.n_out_coords = n_out_coords
7
+
8
+ def forward(self, data):
9
+ # data: instance of EventBatch
10
+ inputs_v = data.input_vectors # four-momenta
11
+ betas = torch.ones(data.input_vectors.shape[0]).to(inputs_v.device)
12
+ norm_inputs_v = torch.norm(inputs_v, dim=1).unsqueeze(1)
13
+ #print("inputs_v.shape", inputs_v.shape)
14
+ #print("betas.shape", betas.shape)
15
+ #print("norm_inputs_v.shape", norm_inputs_v.shape)
16
+ #print("betas unsqueezed shape", betas.unsqueeze(1).shape)
17
+ x = torch.cat([inputs_v / norm_inputs_v, betas.unsqueeze(1)], dim=1)
18
+ return x
19
+
20
+
21
+ def get_model(args):
22
+ return IdentityModel()
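For reference, the baseline above only rescales each input vector to unit norm and appends a constant beta of 1; a sketch of the same computation on dummy four-momenta:

import torch

inputs_v = torch.randn(5, 4)                      # (num_nodes, 4) four-momenta
betas = torch.ones(inputs_v.shape[0], 1)          # constant beta per node
norm = torch.norm(inputs_v, dim=1, keepdim=True)  # (num_nodes, 1)
x = torch.cat([inputs_v / norm, betas], dim=1)    # (num_nodes, 5)
assert x.shape == (5, 5) and torch.allclose(x[:, -1], torch.ones(5))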
src/1models/transformer/tr_blocks.py ADDED
@@ -0,0 +1,531 @@
1
+ # File copied from https://raw.githubusercontent.com/heidelberg-hepml/lorentz-gatr/refs/heads/main/experiments/baselines/transformer.py
2
+ from functools import partial
3
+ from typing import Optional, Tuple
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import nn
7
+ from torch.utils.checkpoint import checkpoint
8
+
9
+ from lgatr.layers import ApplyRotaryPositionalEncoding
10
+ from lgatr.primitives.attention import scaled_dot_product_attention
11
+
12
+
13
+ def to_nd(tensor, d):
14
+ """Make tensor n-dimensional, group extra dimensions in first."""
15
+ return tensor.view(
16
+ -1, *(1,) * (max(0, d - 1 - tensor.dim())), *tensor.shape[-(d - 1) :]
17
+ )
18
+
19
+ class BaselineLayerNorm(nn.Module):
20
+ """Baseline layer norm over all dimensions except the first."""
21
+
22
+ @staticmethod
23
+ def forward(inputs: torch.Tensor) -> torch.Tensor:
24
+ """Forward pass.
25
+
26
+ Parameters
27
+ ----------
28
+ inputs : Tensor
29
+ Input data
30
+
31
+ Returns
32
+ -------
33
+ outputs : Tensor
34
+ Normalized inputs.
35
+ """
36
+ return torch.nn.functional.layer_norm(
37
+ inputs, normalized_shape=inputs.shape[-1:]
38
+ )
39
+
40
+
41
+ class MultiHeadQKVLinear(nn.Module):
42
+ """Compute queries, keys, and values via multi-head attention.
43
+
44
+ Parameters
45
+ ----------
46
+ in_channels : int
47
+ Number of input channels.
48
+ hidden_channels : int
49
+ Number of hidden channels = size of query, key, and value.
50
+ num_heads : int
51
+ Number of attention heads.
52
+ """
53
+
54
+ def __init__(self, in_channels, hidden_channels, num_heads):
55
+ super().__init__()
56
+ self.num_heads = num_heads
57
+ self.linear = nn.Linear(in_channels, 3 * hidden_channels * num_heads)
58
+
59
+ def forward(self, inputs):
60
+ """Forward pass.
61
+
62
+ Returns
63
+ -------
64
+ q : Tensor
65
+ Queries
66
+ k : Tensor
67
+ Keys
68
+ v : Tensor
69
+ Values
70
+ """
71
+ qkv = self.linear(inputs) # (..., num_items, 3 * hidden_channels * num_heads)
72
+ q, k, v = rearrange(
73
+ qkv,
74
+ "... items (qkv hidden_channels num_heads) -> qkv ... num_heads items hidden_channels",
75
+ num_heads=self.num_heads,
76
+ qkv=3,
77
+ )
78
+ return q, k, v
79
+
80
+
81
+ class MultiQueryQKVLinear(nn.Module):
82
+ """Compute queries, keys, and values via multi-query attention.
83
+
84
+ Parameters
85
+ ----------
86
+ in_channels : int
87
+ Number of input channels.
88
+ hidden_channels : int
89
+ Number of hidden channels = size of query, key, and value.
90
+ num_heads : int
91
+ Number of attention heads.
92
+ """
93
+
94
+ def __init__(self, in_channels, hidden_channels, num_heads):
95
+ super().__init__()
96
+ self.num_heads = num_heads
97
+ self.q_linear = nn.Linear(in_channels, hidden_channels * num_heads)
98
+ self.k_linear = nn.Linear(in_channels, hidden_channels)
99
+ self.v_linear = nn.Linear(in_channels, hidden_channels)
100
+
101
+ def forward(self, inputs):
102
+ """Forward pass.
103
+
104
+ Parameters
105
+ ----------
106
+ inputs : Tensor
107
+ Input data
108
+
109
+ Returns
110
+ -------
111
+ q : Tensor
112
+ Queries
113
+ k : Tensor
114
+ Keys
115
+ v : Tensor
116
+ Values
117
+ """
118
+ q = rearrange(
119
+ self.q_linear(inputs),
120
+ "... items (hidden_channels num_heads) -> ... num_heads items hidden_channels",
121
+ num_heads=self.num_heads,
122
+ )
123
+ k = self.k_linear(inputs)[
124
+ ..., None, :, :
125
+ ] # (..., head=1, item, hidden_channels)
126
+ v = self.v_linear(inputs)[..., None, :, :]
127
+ return q, k, v
128
+
129
+
130
+ class BaselineSelfAttention(nn.Module):
131
+ """Baseline self-attention layer.
132
+
133
+ Parameters
134
+ ----------
135
+ in_channels : int
136
+ Number of input channels.
137
+ out_channels : int
138
+ Number of input channels.
139
+ hidden_channels : int
140
+ Number of hidden channels = size of query, key, and value.
141
+ num_heads : int
142
+ Number of attention heads.
143
+ pos_encoding : bool
144
+ Whether to apply rotary positional embeddings along the item dimension to the scalar keys
145
+ and queries.
146
+ pos_enc_base : int
147
+ Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
148
+ multi_query : bool
149
+ Use multi-query attention instead of multi-head attention.
150
+ """
151
+
152
+ def __init__(
153
+ self,
154
+ in_channels: int,
155
+ out_channels: int,
156
+ hidden_channels: int,
157
+ num_heads: int = 8,
158
+ pos_encoding: bool = False,
159
+ pos_enc_base: int = 4096,
160
+ multi_query: bool = True,
161
+ dropout_prob=None,
162
+ ) -> None:
163
+ super().__init__()
164
+
165
+ # Store settings
166
+ self.num_heads = num_heads
167
+ self.hidden_channels = hidden_channels
168
+
169
+ # Linear maps
170
+ qkv_class = MultiQueryQKVLinear if multi_query else MultiHeadQKVLinear
171
+ self.qkv_linear = qkv_class(in_channels, hidden_channels, num_heads)
172
+ self.out_linear = nn.Linear(hidden_channels * num_heads, out_channels)
173
+
174
+ # Optional positional encoding
175
+ if pos_encoding:
176
+ self.pos_encoding = ApplyRotaryPositionalEncoding(
177
+ hidden_channels, item_dim=-2, base=pos_enc_base
178
+ )
179
+ else:
180
+ self.pos_encoding = None
181
+
182
+ if dropout_prob is not None:
183
+ self.dropout = nn.Dropout(dropout_prob)
184
+ else:
185
+ self.dropout = None
186
+
187
+ def forward(
188
+ self,
189
+ inputs: torch.Tensor,
190
+ attention_mask: Optional[torch.Tensor] = None,
191
+ is_causal: bool = False,
192
+ ) -> torch.Tensor:
193
+ """Forward pass.
194
+
195
+ Parameters
196
+ ----------
197
+ inputs : Tensor
198
+ Input data
199
+ attention_mask : None or Tensor or xformers.ops.AttentionBias
200
+ Optional attention mask
201
+
202
+ Returns
203
+ -------
204
+ outputs : Tensor
205
+ Outputs
206
+ """
207
+ q, k, v = self.qkv_linear(
208
+ inputs
209
+ ) # each: (..., num_heads, num_items, num_channels, 16)
210
+ # Rotary positional encoding
211
+ if self.pos_encoding is not None:
212
+ q = self.pos_encoding(q)
213
+ k = self.pos_encoding(k)
214
+
215
+ # Attention layer
216
+ h = self._attend(q, k, v, attention_mask, is_causal=is_causal)
217
+
218
+ # Concatenate heads and transform linearly
219
+ h = rearrange(
220
+ h,
221
+ "... num_heads num_items hidden_channels -> ... num_items (num_heads hidden_channels)",
222
+ )
223
+ outputs = self.out_linear(h) # (..., num_items, out_channels)
224
+
225
+ if self.dropout is not None:
226
+ outputs = self.dropout(outputs)
227
+
228
+ return outputs
229
+
230
+ @staticmethod
231
+ def _attend(q, k, v, attention_mask=None, is_causal=False):
232
+ """Scaled dot-product attention."""
233
+
234
+ # Add batch dimension if needed
235
+ bh_shape = q.shape[:-2]
236
+ q = to_nd(q, 4)
237
+ k = to_nd(k, 4)
238
+ v = to_nd(v, 4)
239
+
240
+ # SDPA
241
+ outputs = scaled_dot_product_attention(
242
+ q.contiguous(),
243
+ k.expand_as(q).contiguous(),
244
+ v.expand_as(q).contiguous(),
245
+ attn_mask=attention_mask,
246
+ is_causal=is_causal,
247
+ )
248
+
249
+ # Return batch dimensions to inputs
250
+ outputs = outputs.view(*bh_shape, *outputs.shape[-2:])
251
+
252
+ return outputs
253
+
254
+
255
+ class BaselineTransformerBlock(nn.Module):
256
+ """Baseline transformer block.
257
+
258
+ Inputs are first processed by a block consisting of LayerNorm, multi-head self-attention, and
259
+ residual connection. Then the data is processed by a block consisting of another LayerNorm, an
260
+ item-wise two-layer MLP with GeLU activations, and another residual connection.
261
+
262
+ Parameters
263
+ ----------
264
+ channels : int
265
+ Number of input and output channels.
266
+ num_heads : int
267
+ Number of attention heads.
268
+ pos_encoding : bool
269
+ Whether to apply rotary positional embeddings along the item dimension to the scalar keys
270
+ and queries.
271
+ pos_encoding_base : int
272
+ Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
273
+ increase_hidden_channels : int
274
+ Factor by which the key, query, and value size is increased over the default value of
275
+ hidden_channels / num_heads.
276
+ multi_query : bool
277
+ Use multi-query attention instead of multi-head attention.
278
+ """
279
+
280
+ def __init__(
281
+ self,
282
+ channels,
283
+ num_heads: int = 8,
284
+ pos_encoding: bool = False,
285
+ pos_encoding_base: int = 4096,
286
+ increase_hidden_channels=1,
287
+ multi_query: bool = True,
288
+ dropout_prob=None,
289
+ ) -> None:
290
+ super().__init__()
291
+
292
+ self.norm = BaselineLayerNorm()
293
+
294
+ # When using positional encoding, the number of scalar hidden channels needs to be even.
295
+ # It also should not be too small.
296
+ hidden_channels = channels // num_heads * increase_hidden_channels
297
+ if pos_encoding:
298
+ hidden_channels = (hidden_channels + 1) // 2 * 2
299
+ hidden_channels = max(hidden_channels, 16)
300
+
301
+ self.attention = BaselineSelfAttention(
302
+ channels,
303
+ channels,
304
+ hidden_channels,
305
+ num_heads=num_heads,
306
+ pos_encoding=pos_encoding,
307
+ pos_enc_base=pos_encoding_base,
308
+ multi_query=multi_query,
309
+ dropout_prob=dropout_prob,
310
+ )
311
+
312
+ self.mlp = nn.Sequential(
313
+ nn.Linear(channels, 2 * channels),
314
+ nn.Dropout(dropout_prob) if dropout_prob is not None else nn.Identity(),
315
+ nn.GELU(),
316
+ nn.Linear(2 * channels, channels),
317
+ nn.Dropout(dropout_prob) if dropout_prob is not None else nn.Identity(),
318
+ )
319
+
320
+ def forward(
321
+ self, inputs: torch.Tensor, attention_mask=None, is_causal=False
322
+ ) -> torch.Tensor:
323
+ """Forward pass.
324
+
325
+ Parameters
326
+ ----------
327
+ inputs : Tensor
328
+ Input data
329
+ attention_mask : None or Tensor or xformers.ops.AttentionBias
330
+ Optional attention mask
331
+
332
+ Returns
333
+ -------
334
+ outputs : Tensor
335
+ Outputs
336
+ """
337
+
338
+ # Residual attention
339
+ h = self.norm(inputs)
340
+ h = self.attention(h, attention_mask=attention_mask, is_causal=is_causal)
341
+ outputs = inputs + h
342
+
343
+ # Residual MLP
344
+ h = self.norm(outputs)
345
+ h = self.mlp(h)
346
+ outputs = outputs + h
347
+
348
+ return outputs
349
+
350
+
351
+ class Transformer(nn.Module):
352
+ """Baseline transformer.
353
+
354
+ Combines num_blocks transformer blocks, each consisting of multi-head self-attention layers, an
355
+ MLP, residual connections, and normalization layers.
356
+
357
+ Parameters
358
+ ----------
359
+ in_channels : int
360
+ Number of input channels.
361
+ out_channels : int
362
+ Number of output channels.
363
+ hidden_channels : int
364
+ Number of hidden channels.
365
+ num_blocks : int
366
+ Number of transformer blocks.
367
+ num_heads : int
368
+ Number of attention heads.
369
+ pos_encoding : bool
370
+ Whether to apply rotary positional embeddings along the item dimension to the scalar keys
371
+ and queries.
372
+ pos_encoding_base : int
373
+ Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
374
+ increase_hidden_channels : int
375
+ Factor by which the key, query, and value size is increased over the default value of
376
+ hidden_channels / num_heads.
377
+ multi_query : bool
378
+ Use multi-query attention instead of multi-head attention.
379
+ """
380
+
381
+ def __init__(
382
+ self,
383
+ in_channels: int,
384
+ out_channels: int,
385
+ hidden_channels: int,
386
+ num_blocks: int = 10,
387
+ num_heads: int = 8,
388
+ pos_encoding: bool = False,
389
+ pos_encoding_base: int = 4096,
390
+ checkpoint_blocks: bool = False,
391
+ increase_hidden_channels=1,
392
+ multi_query: bool = False,
393
+ dropout_prob=None,
394
+ ) -> None:
395
+ super().__init__()
396
+ self.checkpoint_blocks = checkpoint_blocks
397
+ self.linear_in = nn.Linear(in_channels, hidden_channels)
398
+ self.blocks = nn.ModuleList(
399
+ [
400
+ BaselineTransformerBlock(
401
+ hidden_channels,
402
+ num_heads=num_heads,
403
+ pos_encoding=pos_encoding,
404
+ pos_encoding_base=pos_encoding_base,
405
+ increase_hidden_channels=increase_hidden_channels,
406
+ multi_query=multi_query,
407
+ dropout_prob=dropout_prob,
408
+ )
409
+ for _ in range(num_blocks)
410
+ ]
411
+ )
412
+ self.linear_out = nn.Linear(hidden_channels, out_channels)
413
+
414
+ def forward(
415
+ self, inputs: torch.Tensor, attention_mask=None, is_causal=False
416
+ ) -> torch.Tensor:
417
+ """Forward pass.
418
+
419
+ Parameters
420
+ ----------
421
+ inputs : Tensor with shape (..., num_items, num_channels)
422
+ Input data
423
+ attention_mask : None or Tensor or xformers.ops.AttentionBias
424
+ Optional attention mask
425
+ is_causal: bool
426
+
427
+ Returns
428
+ -------
429
+ outputs : Tensor with shape (..., num_items, num_channels)
430
+ Outputs
431
+ """
432
+ h = self.linear_in(inputs)
433
+ for block in self.blocks:
434
+ if self.checkpoint_blocks:
435
+ fn = partial(block, attention_mask=attention_mask, is_causal=is_causal)
436
+ h = checkpoint(fn, h)
437
+ else:
438
+ h = block(h, attention_mask=attention_mask, is_causal=is_causal)
439
+ outputs = self.linear_out(h)
440
+ return outputs
441
+
442
+
443
+ class AxialTransformer(nn.Module):
444
+ """Baseline axial transformer for data with two token dimensions.
445
+
446
+ Combines num_blocks transformer blocks, each consisting of multi-head self-attention layers, an
447
+ MLP, residual connections, and normalization layers.
448
+
449
+ Assumes input data with shape `(..., num_items_1, num_items_2, num_channels, [16])`.
450
+
451
+ The first, third, fifth, ... block computes attention over the `items_2` axis. The other blocks
452
+ compute attention over the `items_1` axis. Positional encoding can be specified separately for
453
+ both axes.
454
+
455
+ Parameters
456
+ ----------
457
+ in_channels : int
458
+ Number of input channels.
459
+ out_channels : int
460
+ Number of output channels.
461
+ hidden_channels : int
462
+ Number of hidden channels.
463
+ num_blocks : int
464
+ Number of transformer blocks.
465
+ num_heads : int
466
+ Number of attention heads.
467
+ pos_encodings : tuple of bool
468
+ Whether to apply rotary positional embeddings along the item dimensions to the scalar keys
469
+ and queries.
470
+ pos_encoding_base : int
471
+ Maximum frequency used in positional encodings. (The minimum frequency is always 1.)
472
+ """
473
+
474
+ def __init__(
475
+ self,
476
+ in_channels: int,
477
+ out_channels: int,
478
+ hidden_channels: int,
479
+ num_blocks: int = 20,
480
+ num_heads: int = 8,
481
+ pos_encodings: Tuple[bool, bool] = (False, False),
482
+ pos_encoding_base: int = 4096,
483
+ ) -> None:
484
+ super().__init__()
485
+ self.linear_in = nn.Linear(in_channels, hidden_channels)
486
+ self.blocks = nn.ModuleList(
487
+ [
488
+ BaselineTransformerBlock(
489
+ hidden_channels,
490
+ num_heads=num_heads,
491
+ pos_encoding=pos_encodings[(block + 1) % 2],
492
+ pos_encoding_base=pos_encoding_base,
493
+ )
494
+ for block in range(num_blocks)
495
+ ]
496
+ )
497
+ self.linear_out = nn.Linear(hidden_channels, out_channels)
498
+
499
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
500
+ """Forward pass.
501
+
502
+ Parameters
503
+ ----------
504
+ inputs : Tensor with shape (..., num_items1, num_items2, num_channels)
505
+ Input data
506
+
507
+ Returns
508
+ -------
509
+ outputs : Tensor with shape (..., num_items1, num_items2, num_channels)
510
+ Outputs
511
+ """
512
+
513
+ rearrange_pattern = "... i j c -> ... j i c"
514
+
515
+ h = self.linear_in(inputs)
516
+
517
+ for i, block in enumerate(self.blocks):
518
+ # For first, third, ... block, we want to perform attention over the first token
519
+ # dimension. We implement this by transposing the two item dimensions.
520
+ if i % 2 == 1:
521
+ h = rearrange(h, rearrange_pattern)
522
+
523
+ h = block(h)
524
+
525
+ # Transposing back to standard axis order
526
+ if i % 2 == 1:
527
+ h = rearrange(h, rearrange_pattern)
528
+
529
+ outputs = self.linear_out(h)
530
+
531
+ return outputs
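In AxialTransformer, every second block runs attention over the other item axis simply by transposing the two item dimensions with the einops pattern used above, then transposing back afterwards; a minimal sketch on a dummy activation tensor:

import torch
from einops import rearrange

h = torch.randn(2, 6, 9, 32)                       # (batch, items_1, items_2, channels)
h_t = rearrange(h, "... i j c -> ... j i c")       # swap item axes before the block
assert h_t.shape == (2, 9, 6, 32)
h_back = rearrange(h_t, "... i j c -> ... j i c")  # swap back after the block
assert torch.equal(h_back, h)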
src/1models/transformer/transformer.py ADDED
@@ -0,0 +1,141 @@
1
+ from src.models.transformer.tr_blocks import Transformer
2
+ import torch
3
+ import torch.nn as nn
4
+ from xformers.ops.fmha import BlockDiagonalMask
5
+ from torch_scatter import scatter_max, scatter_add, scatter_mean
6
+ import numpy as np
7
+
8
+
9
+ class TransformerModel(torch.nn.Module):
10
+ def __init__(self, n_scalars, n_scalars_out, n_blocks, n_heads, internal_dim, obj_score, global_features_copy=False):
11
+ super().__init__()
12
+ self.n_scalars = n_scalars
13
+ self.input_dim = n_scalars + 3
14
+ if obj_score:
15
+ self.input_dim += 1
16
+ self.output_dim = 3
17
+ self.obj_score = obj_score
18
+ #internal_dim = 128
19
+ #self.custom_decoder = nn.Linear(internal_dim, self.output_dim)
20
+ #n_heads = 4
21
+ #self.transformer = nn.TransformerEncoder(
22
+ # nn.TransformerEncoderLayer(
23
+ # d_model=n_heads*self.input_dim,
24
+ # nhead=n_heads,
25
+ # dim_feedforward=internal_dim,
26
+ # dropout=0.1,
27
+ # activation="gelu",
28
+ # ),
29
+ # num_layers=4,
30
+ #)
31
+ if n_scalars_out > 0:
32
+ self.output_dim += 1 # betas regression
33
+ if self.obj_score:
34
+ self.output_dim = 10
35
+ self.global_features_copy = global_features_copy
36
+ self.transformer = Transformer(
37
+ in_channels=self.input_dim,
38
+ out_channels=self.output_dim,
39
+ hidden_channels=internal_dim,
40
+ num_heads=n_heads,
41
+ num_blocks=n_blocks,
42
+ )
43
+ if self.global_features_copy:
44
+ self.transformer_global_features = Transformer(
45
+ in_channels=self.input_dim,
46
+ out_channels=self.output_dim,
47
+ hidden_channels=internal_dim,
48
+ num_heads=n_heads,
49
+ num_blocks=n_blocks,
50
+ )
51
+ self.batch_norm = nn.BatchNorm1d(self.input_dim, momentum=0.1)
52
+ if self.obj_score:
53
+ factor = 1
54
+ if self.global_features_copy: factor = 2
55
+ self.final_mlp = nn.Sequential(
56
+ nn.Linear(self.output_dim*factor, 10),
57
+ nn.LeakyReLU(),
58
+ nn.Linear(10, 1),
59
+ )
60
+ #self.clustering = nn.Linear(3, self.output_dim - 1, bias=False)
61
+
62
+ def forward(self, data, data_events=None, data_events_clusters=None):
63
+ # data: instance of EventBatch
64
+ # data_events & data_events_clusters: Only relevant if --global-features-obj-score is on: data_events contains
65
+ # the "unmodified" batch where the batch indices are
66
+ if self.global_features_copy:
67
+ assert data_events is not None and data_events_clusters is not None
68
+ assert self.obj_score
69
+ inputs_v = data_events.input_vectors.float()
70
+ inputs_scalar = data_events.input_scalars.float()
71
+ assert inputs_scalar.shape[1] == self.n_scalars, "Expected %d, got %d" % (
72
+ self.n_scalars, inputs_scalar.shape[1])
73
+ inputs_transformer_events = torch.cat([inputs_scalar, inputs_v], dim=1)
74
+ inputs_transformer_events = inputs_transformer_events.float()
75
+ assert inputs_transformer_events.shape[1] == self.input_dim
76
+ mask_global = self.build_attention_mask(data_events.batch_idx)
77
+ x_global = inputs_transformer_events.unsqueeze(0)
78
+ x_global = self.transformer_global_features(x_global, attention_mask=mask_global)[0]
79
+ assert x_global.shape[1] == self.output_dim, "Expected %d, got %d" % (self.output_dim, x_global.shape[1])
80
+ assert x_global.shape[0] == inputs_transformer_events.shape[0], "Expected %d, got %d" % (
81
+ inputs_transformer_events.shape[0], x_global.shape[0])
82
+ m_global = scatter_mean(x_global, torch.tensor(data_events_clusters).to(x_global.device)+1, dim=0)[1:]
83
+ inputs_v = data.input_vectors
84
+ inputs_scalar = data.input_scalars
85
+ assert inputs_scalar.shape[1] == self.n_scalars, "Expected %d, got %d" % (self.n_scalars, inputs_scalar.shape[1])
86
+ inputs_transformer = torch.cat([inputs_scalar, inputs_v], dim=1)
87
+ inputs_transformer = inputs_transformer.float()
88
+ print("input_dim", self.input_dim, inputs_transformer.shape)
89
+ assert inputs_transformer.shape[1] == self.input_dim
90
+ mask = self.build_attention_mask(data.batch_idx)
91
+ x = inputs_transformer.unsqueeze(0)
92
+ x = self.transformer(x, attention_mask=mask)[0]
93
+ assert x.shape[1] == self.output_dim, "Expected %d, got %d" % (self.output_dim, x.shape[1])
94
+ assert x.shape[0] == inputs_transformer.shape[0], "Expected %d, got %d" % (inputs_transformer.shape[0], x.shape[0])
95
+ if not self.obj_score:
96
+ x[:, -1] = torch.sigmoid(x[:, -1])
97
+ else:
98
+ extract_from_virtual_nodes = False
99
+ if extract_from_virtual_nodes:
100
+ x = self.final_mlp(x[data.fake_nodes_idx]) # x is the raw logits
101
+ else:
102
+ m = scatter_mean(x, torch.tensor(data.batch_idx).long().to(x.device), dim=0)
103
+ assert not "fake_nodes_idx" in data.__dict__
104
+ if self.global_features_copy:
105
+ m = torch.cat([m, m_global], dim=1)
106
+ x = self.final_mlp(m).flatten()
107
+ return x
108
+
109
+ def build_attention_mask(self, batch_numbers):
110
+ return BlockDiagonalMask.from_seqlens(
111
+ torch.bincount(batch_numbers.long()).tolist()
112
+ )
113
+
114
+ def get_model(args, obj_score=False):
115
+ n_scalars_out = 8
116
+ if args.beta_type == "pt":
117
+ n_scalars_out = 0
118
+ elif args.beta_type == "pt+bc":
119
+ n_scalars_out = 1
120
+ n_scalars_in = 12
121
+ if args.no_pid:
122
+ n_scalars_in = 12-9
123
+ if obj_score:
124
+ return TransformerModel(
125
+ n_scalars=n_scalars_in,
126
+ n_scalars_out=10,
127
+ n_blocks=5,
128
+ n_heads=args.n_heads,
129
+ internal_dim=64,
130
+ obj_score=obj_score,
131
+ global_features_copy=args.global_features_obj_score
132
+ )
133
+ return TransformerModel(
134
+ n_scalars=n_scalars_in,
135
+ n_scalars_out=n_scalars_out,
136
+ n_blocks=args.num_blocks,
137
+ n_heads=args.n_heads,
138
+ internal_dim=args.internal_dim,
139
+ obj_score=obj_score
140
+ )
141
+
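When obj_score is set, the per-node transformer outputs are pooled into one embedding per event with scatter_mean before the final MLP; a sketch of that pooling step with made-up shapes:

import torch
from torch_scatter import scatter_mean

x = torch.randn(6, 10)                        # (num_nodes, output_dim)
batch_idx = torch.tensor([0, 0, 1, 1, 1, 2])  # 3 events
m = scatter_mean(x, batch_idx, dim=0)         # (3, 10): one pooled row per event
assert m.shape == (3, 10)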
src/model_wrapper_gradio.py CHANGED
@@ -41,7 +41,7 @@ def inference(loss_str, train_dataset_str, input_text, input_text_quarks):
     args.spatial_part_only = True # LGATr
     args.load_model_weights = model_path
     args.aug_soft = True # LGATr_GP etc.
-    args.network_config = "src/models/LGATr/lgatr.py"
+    args.network_config = "src/1models/LGATr/lgatr.py"
     args.beta_type = "pt+bc"
     args.embed_as_vectors = False
     args.debug = False