File size: 10,146 Bytes
a560c26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Attention module library."""

import functools
from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union

from flax import linen as nn
import jax
import jax.numpy as jnp
from invariant_slot_attention.modules import misc

# Static array shape: a tuple of ints of arbitrary length (`Tuple[int]` would
# only describe rank-1 shapes).
Shape = Tuple[int, ...]

# Loose aliases: JAX dtypes and arrays are not given precise static types here.
DType = Any
Array = Any  # jnp.ndarray
# Arbitrarily nested container (iterables / string-keyed mappings) of arrays.
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
ProcessorState = ArrayTree
PRNGKey = Array
NestedDict = Dict[str, Any]


class SlotAttention(nn.Module):
  """Slot Attention module.

  Refines a set of slots over several rounds. Each round projects the
  layer-normalized slots to queries, attends over the (fixed) keys/values
  computed from the inputs using `InvertedDotProductAttention`, feeds the
  result to `misc.GRU` together with the current slots, and optionally applies
  `misc.MLP`.

  Note: This module uses pre-normalization by default.

  Attributes:
    num_iterations: Number of attention/update rounds.
    qkv_size: Total query/key/value feature size, split across heads. Falls
      back to the slot feature size (`slots.shape[-1]`) when unset.
    mlp_size: Hidden size of the optional MLP; the MLP is skipped when None.
    epsilon: Small constant forwarded to the attention module, where it is
      added to the normalizer of the weighted mean.
    num_heads: Number of attention heads.
  """

  num_iterations: int = 1
  qkv_size: Optional[int] = None
  mlp_size: Optional[int] = None
  epsilon: float = 1e-8
  num_heads: int = 1

  @nn.compact
  def __call__(self, slots, inputs,
               padding_mask = None,
               train = False):
    """Slot Attention module forward pass.

    Args:
      slots: Initial slots; the last axis is the slot feature axis.
      inputs: Input features of shape `(..., n_inputs, inputs_size)`.
      padding_mask: Unused.
      train: Unused.

    Returns:
      The slots after `num_iterations` rounds of attention and update.
    """
    del padding_mask, train  # Unused.

    qkv_size = self.qkv_size or slots.shape[-1]
    # Floor division: any remainder of qkv_size / num_heads is dropped.
    head_dim = qkv_size // self.num_heads
    dense = functools.partial(nn.DenseGeneral,
                              axis=-1, features=(self.num_heads, head_dim),
                              use_bias=False)

    # Shared modules (created once, reused in every iteration).
    dense_q = dense(name="general_dense_q_0")
    layernorm_q = nn.LayerNorm()
    # Forward `epsilon` so that the attribute configured on this module is
    # actually honored by the attention normalizer.
    inverted_attention = InvertedDotProductAttention(
        norm_type="mean", multi_head=self.num_heads > 1,
        epsilon=self.epsilon)
    gru = misc.GRU()

    if self.mlp_size is not None:
      mlp = misc.MLP(hidden_size=self.mlp_size, layernorm="pre", residual=True)  # type: ignore

    # inputs.shape = (..., n_inputs, inputs_size).
    inputs = nn.LayerNorm()(inputs)
    # k.shape = (..., n_inputs, num_heads, head_dim).
    k = dense(name="general_dense_k_0")(inputs)
    # v.shape = (..., n_inputs, num_heads, head_dim).
    v = dense(name="general_dense_v_0")(inputs)

    # Multiple rounds of attention.
    for _ in range(self.num_iterations):

      # Inverted dot-product attention.
      slots_n = layernorm_q(slots)
      q = dense_q(slots_n)  # q.shape = (..., n_slots, num_heads, head_dim).
      updates = inverted_attention(query=q, key=k, value=v)

      # Recurrent update.
      slots = gru(slots, updates)

      # Feedforward block with pre-normalization.
      if self.mlp_size is not None:
        slots = mlp(slots)

    return slots


class InvertedDotProductAttention(nn.Module):
  """Dot-product attention whose softmax is taken over the query axis.

  Thin wrapper around `GeneralizedDotProductAttention` with
  `inverted_attn=True` that additionally merges (or removes) the head axis and
  optionally layer-normalizes the result.

  Attributes:
    norm_type: "mean" renormalizes the attention weights over the key axis
      (weighted mean), "layernorm" applies a LayerNorm to the output, any other
      value applies neither.
    multi_head: If True, heads are merged with a dense layer; otherwise the
      head axis is squeezed away (and must therefore have size 1).
    epsilon: Small constant forwarded to the wrapped attention module.
    dtype: Dtype forwarded to the wrapped attention module.
    precision: Precision forwarded to the wrapped attention module.
    return_attn_weights: If True, the attention weights are returned as well.
  """

  norm_type: Optional[str] = "mean"  # mean, layernorm, or None
  multi_head: bool = False
  epsilon: float = 1e-8
  dtype: DType = jnp.float32
  precision: Optional[jax.lax.Precision] = None
  return_attn_weights: bool = False

  @nn.compact
  def __call__(self, query, key, value,
               train = False):
    """Computes inverted dot-product attention.

    Args:
      query: Queries of shape `[batch..., q_num, num_heads, qk_features]`.
      key: Keys of shape `[batch..., kv_num, num_heads, qk_features]`.
      value: Values of shape `[batch..., kv_num, num_heads, v_features]`.
      train: Unused.

    Returns:
      The attention output with the head axis merged or squeezed, i.e. of shape
      `[batch..., q_num, v_features]`. If `return_attn_weights` is set, a tuple
      `(output, attn_weights)` is returned instead.
    """
    del train  # Unused.

    attention = GeneralizedDotProductAttention(
        inverted_attn=True,
        renormalize_keys=self.norm_type == "mean",
        epsilon=self.epsilon,
        dtype=self.dtype,
        precision=self.precision,
        return_attn_weights=True)

    # Apply attention mechanism; the weights are always requested here and only
    # dropped at the very end if the caller does not want them.
    updates, weights = attention(query=query, key=key, value=value)

    if not self.multi_head:
      # Single head: remove the head dimension from output and weights.
      updates = jnp.squeeze(updates, axis=-2)
      weights = jnp.squeeze(weights, axis=-3)
    else:
      # Multi-head aggregation. Equivalent to concat + dense layer.
      merge_heads = nn.DenseGeneral(
          features=updates.shape[-1], axis=(-2, -1))
      updates = merge_heads(updates)

    if self.norm_type == "layernorm":
      updates = nn.LayerNorm()(updates)

    return (updates, weights) if self.return_attn_weights else updates


class GeneralizedDotProductAttention(nn.Module):
  """Multi-head dot-product attention with a selectable softmax axis.

  When not in training mode, the softmax attention weights are stored under
  "attn" in the "intermediates" variable collection.

  Attributes:
    dtype: Dtype of the temperature scale and of the attention weights.
    precision: Precision passed to the einsum contractions.
    epsilon: Small constant added to the normalizer when `renormalize_keys` is
      set.
    inverted_attn: If True, softmax runs over the query axis instead of the key
      axis.
    renormalize_keys: If True, the weights are divided by their sum over the
      key axis after the softmax.
    attn_weights_only: If True, only the attention weights are returned.
    return_attn_weights: If True, `(output, attn_weights)` is returned.
  """

  dtype: DType = jnp.float32
  precision: Optional[jax.lax.Precision] = None
  epsilon: float = 1e-8
  inverted_attn: bool = False
  renormalize_keys: bool = False
  attn_weights_only: bool = False
  return_attn_weights: bool = False

  @nn.compact
  def __call__(self, query, key, value,
               train = False, **kwargs
               ):
    """Computes multi-head dot-product attention given query, key, and value.

    Args:
      query: Queries with shape of `[batch..., q_num, num_heads, qk_features]`.
      key: Keys with shape of `[batch..., kv_num, num_heads, qk_features]`.
      value: Values with shape of `[batch..., kv_num, num_heads, v_features]`.
      train: Indicating whether we're training or evaluating.
      **kwargs: Extra keyword arguments, accepted so that this module can serve
        as the attention function of nn.MultiHeadDotProductAttention. Only
        "bias" and "dropout_rate" are inspected (to reject unsupported
        settings); everything else is ignored.

    Returns:
      The attention weights of shape `[batch..., num_heads, q_num, kv_num]` if
      `attn_weights_only` is set; otherwise the output of shape
      `[batch..., q_num, num_heads, v_features]`, paired with the attention
      weights if `return_attn_weights` is set.

    Raises:
      NotImplementedError: If a non-None "bias" or a positive "dropout_rate" is
        passed via `kwargs`.
    """
    q_shape, k_shape, v_shape = query.shape, key.shape, value.shape

    assert query.ndim == key.ndim == value.ndim, (
        "Queries, keys, and values must have the same rank.")
    assert q_shape[:-3] == k_shape[:-3] == v_shape[:-3], (
        "Query, key, and value batch dimensions must match.")
    assert q_shape[-2] == k_shape[-2] == v_shape[-2], (
        "Query, key, and value num_heads dimensions must match.")
    assert k_shape[-3] == v_shape[-3], (
        "Key and value cardinality dimensions must match.")
    assert q_shape[-1] == k_shape[-1], (
        "Query and key feature dimensions must match.")

    if kwargs.get("bias") is not None:
      raise NotImplementedError(
          "Support for masked attention is not yet implemented.")

    if "dropout_rate" in kwargs and kwargs["dropout_rate"] > 0.:
      raise NotImplementedError("Support for dropout is not yet implemented.")

    # Temperature normalization: scale queries by 1/sqrt(feature size).
    temperature = jnp.sqrt(q_shape[-1]).astype(self.dtype)
    query = query / temperature

    # weights.shape = (batch..., num_heads, q_num, kv_num)
    weights = jnp.einsum("...qhd,...khd->...hqk", query, key,
                         precision=self.precision)

    # Softmax normalization: over the query axis (-2) when inverted, over the
    # key axis (-1) otherwise.
    softmax_axis = -2 if self.inverted_attn else -1
    weights = jax.nn.softmax(weights, axis=softmax_axis).astype(self.dtype)

    # Defines intermediate for logging (recorded before any renormalization).
    if not train:
      self.sow("intermediates", "attn", weights)

    if self.renormalize_keys:
      # Corresponds to value aggregation via weighted mean (as opposed to sum).
      key_mass = jnp.sum(weights, axis=-1, keepdims=True) + self.epsilon
      weights = weights / key_mass

    if self.attn_weights_only:
      return weights

    # Aggregate values using a weighted sum with the attention weights.
    aggregated = jnp.einsum(
        "...hqk,...khd->...qhd", weights, value, precision=self.precision)

    if self.return_attn_weights:
      return aggregated, weights

    return aggregated


class Transformer(nn.Module):
  """Stack of `num_layers` identically configured `TransformerBlock`s.

  Attributes:
    num_heads: Number of attention heads per block.
    qkv_size: Query/key/value feature size per block.
    mlp_size: MLP hidden size per block.
    num_layers: Number of blocks applied in sequence.
    pre_norm: Whether the blocks use pre-normalization.
  """

  num_heads: int
  qkv_size: int
  mlp_size: int
  num_layers: int
  pre_norm: bool = False

  @nn.compact
  def __call__(self, queries, inputs = None,
               padding_mask = None,
               train = False):
    """Passes `queries` through all blocks; other arguments go to each block."""
    # Configuration shared by every block in the stack.
    block_config = dict(
        num_heads=self.num_heads, qkv_size=self.qkv_size,
        mlp_size=self.mlp_size, pre_norm=self.pre_norm)

    hidden = queries
    for layer_idx in range(self.num_layers):
      block = TransformerBlock(
          name=f"TransformerBlock{layer_idx}",
          **block_config)  # pytype: disable=wrong-arg-types
      hidden = block(hidden, inputs, padding_mask, train)
    return hidden


class TransformerBlock(nn.Module):
  """Transformer decoder block.

  Applies self-attention on the queries, cross-attention on `inputs` (only if
  given), and an MLP, each with a residual connection. LayerNorm is applied
  before each sub-layer when `pre_norm` is set and after each residual sum
  otherwise.

  Attributes:
    num_heads: Number of attention heads.
    qkv_size: Query/key/value feature size of the attention layers.
    mlp_size: Hidden size of the MLP.
    pre_norm: Selects pre-normalization (True) or post-normalization (False).
  """

  num_heads: int
  qkv_size: int
  mlp_size: int
  pre_norm: bool = False

  @nn.compact
  def __call__(self, queries, inputs = None,
               padding_mask = None,
               train = False):
    """Runs the block.

    Args:
      queries: Rank-3 array that is self-attended and updated.
      inputs: Optional rank-3 array to cross-attend to.
      padding_mask: Unused.
      train: Passed to the MLP and, negated, as `deterministic` to attention.

    Returns:
      The updated queries.
    """
    del padding_mask  # Unused.
    assert queries.ndim == 3

    # Factory for attention layers; every call creates a fresh layer, so
    # self- and cross-attention do not share parameters.
    make_attention = functools.partial(
        nn.MultiHeadDotProductAttention,
        num_heads=self.num_heads,
        qkv_features=self.qkv_size,
        attention_fn=GeneralizedDotProductAttention())

    feed_forward = misc.MLP(hidden_size=self.mlp_size)  # type: ignore

    deterministic = not train

    if self.pre_norm:
      # Self-attention on queries.
      normed_q = nn.LayerNorm()(queries)
      self_out = make_attention()(
          inputs_q=normed_q, inputs_kv=normed_q, deterministic=deterministic)
      self_out = self_out + queries

      # Cross-attention on inputs (skipped when there are none).
      cross_out = self_out
      if inputs is not None:
        assert inputs.ndim == 3
        normed_self = nn.LayerNorm()(self_out)
        cross_out = make_attention()(
            inputs_q=normed_self, inputs_kv=inputs,
            deterministic=deterministic)
        cross_out = cross_out + self_out

      # MLP
      result = nn.LayerNorm()(cross_out)
      result = feed_forward(result, train)
      result = result + cross_out
    else:
      # Self-attention on queries.
      self_out = make_attention()(
          inputs_q=queries, inputs_kv=queries, deterministic=deterministic)
      self_out = self_out + queries
      self_out = nn.LayerNorm()(self_out)

      # Cross-attention on inputs (skipped when there are none).
      cross_out = self_out
      if inputs is not None:
        assert inputs.ndim == 3
        cross_out = make_attention()(
            inputs_q=self_out, inputs_kv=inputs, deterministic=deterministic)
        cross_out = cross_out + self_out
        cross_out = nn.LayerNorm()(cross_out)

      # MLP.
      result = feed_forward(cross_out, train)
      result = result + cross_out
      result = nn.LayerNorm()(result)
    return result