import functools
from typing import List, Sequence

import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.linen.attention import MultiHeadDotProductAttention
from flax.linen.initializers import constant, orthogonal

from kinetix.models.actor_critic import GeneralActorCriticRNN, ScannedRNN
from kinetix.render.renderer_symbolic_entity import EntityObservation
class Gating(nn.Module):
    # GRU-type gating from the GTrXL paper, adapted from
    # https://github.com/dhruvramani/Transformers-RL/blob/master/layers.py
    d_input: int
    bg: float = 0.0

    @nn.compact
    def __call__(self, x, y):
        # Reset gate: r = sigmoid(W_r y + U_r x).
        r = jax.nn.sigmoid(nn.Dense(self.d_input, use_bias=False)(y) + nn.Dense(self.d_input, use_bias=False)(x))
        # Update gate: z = sigmoid(W_z y + U_z x - b_g). A positive bias b_g keeps
        # the gate close to the identity (skip) path early in training.
        z = jax.nn.sigmoid(
            nn.Dense(self.d_input, use_bias=False)(y)
            + nn.Dense(self.d_input, use_bias=False)(x)
            - self.param("gating_bias", constant(self.bg), (self.d_input,))
        )
        # Candidate activation: h = tanh(W_g y + U_g (r * x)).
        h = jnp.tanh(nn.Dense(self.d_input, use_bias=False)(y) + nn.Dense(self.d_input, use_bias=False)(r * x))
        # Convex combination of the skip input x and the candidate h.
        g = (1 - z) * x + z * h
        return g
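
# Usage sketch for Gating (illustrative, not from the repo); x is the residual
# (skip) input and y the transformed input, both with trailing dim d_input:
#
#   gate = Gating(d_input=64)
#   x = y = jnp.ones((2, 5, 64))
#   params = gate.init(jax.random.PRNGKey(0), x, y)
#   out = gate.apply(params, x, y)  # -> shape (2, 5, 64)
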
class transformer_layer(nn.Module):
    num_heads: int
    out_features: int
    qkv_features: int
    gating: bool = False
    gating_bias: float = 0.0

    def setup(self):
        self.attention1 = MultiHeadDotProductAttention(
            num_heads=self.num_heads, qkv_features=self.qkv_features, out_features=self.out_features
        )
        self.ln1 = nn.LayerNorm()
        self.dense1 = nn.Dense(self.out_features)
        self.ln2 = nn.LayerNorm()
        if self.gating:
            self.gate1 = Gating(self.out_features, self.gating_bias)
            self.gate2 = Gating(self.out_features, self.gating_bias)

    def __call__(self, queries: jnp.ndarray, mask: jnp.ndarray):
        # Pre-LN ordering as in the GTrXL paper: layernorm first, then self-attention.
        queries_n = self.ln1(queries)
        y = self.attention1(queries_n, mask=mask)
        if self.gating:
            # Gate the attention output against the ungated residual stream.
            y = self.gate1(queries, jax.nn.relu(y))
        else:
            y = queries + y
        # Feed-forward after the second layernorm; crucially, no relu on its output.
        e = self.dense1(self.ln2(y))
        if self.gating:
            # Gate again (the argument order here may be the wrong way around).
            e = self.gate2(y, jax.nn.relu(e))
        else:
            e = y + e
        return e
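
# Usage sketch for transformer_layer (illustrative, not from the repo); the mask
# must broadcast to (..., num_heads, num_queries, num_keys):
#
#   layer = transformer_layer(num_heads=4, out_features=64, qkv_features=64, gating=True)
#   q = jnp.ones((2, 7, 64))                   # (batch, entities, features)
#   mask = jnp.ones((2, 4, 7, 7), dtype=bool)  # (batch, heads, query, key)
#   params = layer.init(jax.random.PRNGKey(0), q, mask)
#   out = layer.apply(params, q, mask)         # -> shape (2, 7, 64)
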
class Transformer(nn.Module):
    encoder_size: int
    num_heads: int
    qkv_features: int
    num_layers: int
    gating: bool = False
    gating_bias: float = 0.0

    def setup(self):
        self.tf_layers = [
            transformer_layer(
                num_heads=self.num_heads,
                qkv_features=self.qkv_features,
                out_features=self.encoder_size,
                gating=self.gating,
                gating_bias=self.gating_bias,
            )
            for _ in range(self.num_layers)
        ]
        self.joint_layers = [nn.Dense(self.encoder_size) for _ in range(self.num_layers)]
        self.thruster_layers = [nn.Dense(self.encoder_size) for _ in range(self.num_layers)]

    def __call__(
        self,
        shape_embeddings: jnp.ndarray,
        shape_attention_mask,
        joint_embeddings,
        joint_mask,
        joint_indexes,
        thruster_embeddings,
        thruster_mask,
        thruster_indexes,
    ):
        # shape_embeddings is (T, B, N, K): time, batch, number of shape entities, features.
        # Both helpers are vmapped over the leading time and batch axes.
        def do_index2(to_ind, ind):
            # Gather rows along the entity axis: (N, K) indexed by (J,) -> (J, K).
            return jax.vmap(jax.vmap(lambda a, i: a[i]))(to_ind, ind)

        def add2(addee, index, adder):
            # Scatter-add updates back onto the entity axis.
            return jax.vmap(jax.vmap(lambda a, i, u: a.at[i].add(u)))(addee, index, adder)

        for tf_layer, joint_layer, thruster_layer in zip(self.tf_layers, self.joint_layers, self.thruster_layers):
            # Self-attention over the shape entities.
            shape_embeddings = tf_layer(shape_embeddings, shape_attention_mask)
            # Joints: concatenate the embeddings of the two connected shapes with the
            # joint's own embedding, giving (T, B, J, 2K + JE).
            joint_shape_embeddings = jnp.concatenate(
                [
                    do_index2(shape_embeddings, joint_indexes[..., 0]),
                    do_index2(shape_embeddings, joint_indexes[..., 1]),
                    joint_embeddings,
                ],
                axis=-1,
            )
            shape_joint_entity_delta = joint_layer(joint_shape_embeddings) * joint_mask[..., None]
            # Thrusters: same idea, but each thruster attaches to a single shape.
            thruster_shape_embeddings = jnp.concatenate(
                [
                    do_index2(shape_embeddings, thruster_indexes),
                    thruster_embeddings,
                ],
                axis=-1,
            )
            shape_thruster_entity_delta = thruster_layer(thruster_shape_embeddings) * thruster_mask[..., None]
            # Write the joint/thruster messages back onto the attached shapes (for
            # joints, only the first of the two connected shapes is updated).
            shape_embeddings = add2(shape_embeddings, joint_indexes[..., 0], shape_joint_entity_delta)
            shape_embeddings = add2(shape_embeddings, thruster_indexes, shape_thruster_entity_delta)
        return shape_embeddings
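
# Usage sketch for Transformer (illustrative shapes, not from the repo):
# T=1 timestep, B=2 batch, N=5 shapes, J=3 joints, 2 thrusters, K=32 features.
#
#   tf = Transformer(encoder_size=32, num_heads=4, qkv_features=32, num_layers=2)
#   se = jnp.ones((1, 2, 5, 32))                # shape embeddings
#   am = jnp.ones((1, 2, 4, 5, 5), dtype=bool)  # per-head attention mask
#   je, jm = jnp.ones((1, 2, 3, 32)), jnp.ones((1, 2, 3), dtype=bool)
#   ji = jnp.zeros((1, 2, 3, 2), dtype=int)     # two attached shapes per joint
#   te, tm = jnp.ones((1, 2, 2, 32)), jnp.ones((1, 2, 2), dtype=bool)
#   ti = jnp.zeros((1, 2, 2), dtype=int)        # one attached shape per thruster
#   params = tf.init(jax.random.PRNGKey(0), se, am, je, jm, ji, te, tm, ti)
#   out = tf.apply(params, se, am, je, jm, ji, te, tm, ti)  # -> (1, 2, 5, 32)
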
class ActorCriticTransformer(nn.Module):
    action_dim: Sequence[int]
    fc_layer_width: int
    action_mode: str
    hybrid_action_continuous_dim: int
    multi_discrete_number_of_dims_per_distribution: List[int]
    transformer_size: int
    transformer_encoder_size: int
    transformer_depth: int
    fc_layer_depth: int
    num_heads: int
    activation: str
    aggregate_mode: str  # "dummy", "mean" or "dummy_and_mean"
    full_attention_mask: bool  # if true, only mask out inactive entities and let everything attend to everything else
    add_generator_embedding: bool = False
    generator_embedding_number_of_timesteps: int = 10
    recurrent: bool = True

    @nn.compact
    def __call__(self, hidden, x):
        activation = nn.relu if self.activation == "relu" else nn.tanh
        og_obs, dones = x
        if self.add_generator_embedding:
            obs = og_obs.obs
        else:
            obs = og_obs
        obs: EntityObservation
        # Each observation field is (T, B, N, L):
        # T - time, B - batch size, N - number of entities, L - unembedded entity size.
        def _single_encoder(features, entity_id, concat=True):
            # Shape entities (circles/polygons) reserve one feature for a type-id
            # channel, so their dense encoder outputs one fewer feature.
            num_to_remove = 1 if concat else 0
            embedding = activation(
                nn.Dense(
                    self.transformer_encoder_size - num_to_remove,
                    kernel_init=orthogonal(np.sqrt(2)),
                    bias_init=constant(0.0),
                )(features)
            )
            if concat:
                # Append a single channel tagging the entity type (0 = circle, 1 = polygon).
                id_channel = jnp.full((*embedding.shape[:3], 1), entity_id, dtype=embedding.dtype)
                return jnp.concatenate([embedding, id_channel], axis=-1)
            else:
                return embedding

        circle_encodings = _single_encoder(obs.circles, 0)
        polygon_encodings = _single_encoder(obs.polygons, 1)
        joint_encodings = _single_encoder(obs.joints, -1, False)
        thruster_encodings = _single_encoder(obs.thrusters, -1, False)
        # (T, B, M, K): time, batch, total number of shapes, embedding size.
        shape_encodings = jnp.concatenate([polygon_encodings, circle_encodings], axis=2)
        # (T, B, M)
        shape_mask = jnp.concatenate([obs.polygon_mask, obs.circle_mask], axis=2)

        def mask_out_inactives(flat_active_mask, matrix_attention_mask):
            # Remove attention to and from inactive entities.
            return matrix_attention_mask & flat_active_mask[:, None] & flat_active_mask[None, :]

        joint_indexes = obs.joint_indexes
        thruster_indexes = obs.thruster_indexes
        if self.aggregate_mode in ("dummy", "dummy_and_mean"):
            # Prepend an always-active dummy entity whose final embedding can be
            # used as a summary of the whole scene.
            T, B, _, K = circle_encodings.shape
            dummy = jnp.ones((T, B, 1, K))
            shape_encodings = jnp.concatenate([dummy, shape_encodings], axis=2)
            shape_mask = jnp.concatenate(
                [jnp.ones((T, B, 1), dtype=bool), shape_mask],
                axis=2,
            )
            N = obs.attention_mask.shape[-1]
            overall_mask = (
                jnp.ones((T, B, obs.attention_mask.shape[2], N + 1, N + 1), dtype=bool)
                .at[:, :, :, 1:, 1:]
                .set(obs.attention_mask)
            )
            overall_mask = jax.vmap(jax.vmap(mask_out_inactives))(shape_mask, overall_mask)
            # Shift the indexes to account for the prepended dummy entity.
            joint_indexes = joint_indexes + 1
            thruster_indexes = thruster_indexes + 1
        else:
            overall_mask = obs.attention_mask
        if self.full_attention_mask:
            # Only mask out inactive entities; all active ones attend to each other.
            overall_mask = jnp.ones(overall_mask.shape, dtype=bool)
            overall_mask = jax.vmap(jax.vmap(mask_out_inactives))(shape_mask, overall_mask)
        # Run the entity transformer; the mask is repeated along axis 2 so there is
        # one copy per attention head (num_heads must be divisible by the mask's
        # existing head dimension).
        embedding = Transformer(
            num_layers=self.transformer_depth,
            num_heads=self.num_heads,
            qkv_features=self.transformer_size,
            encoder_size=self.transformer_encoder_size,
            gating=True,
            gating_bias=0.0,
        )(
            shape_encodings,
            jnp.repeat(overall_mask, repeats=self.num_heads // overall_mask.shape[2], axis=2),
            joint_encodings,
            obs.joint_mask,
            joint_indexes,
            thruster_encodings,
            obs.thruster_mask,
            thruster_indexes,
        )
        if self.aggregate_mode in ("mean", "dummy_and_mean"):
            # Masked mean over the entity axis.
            embedding = jnp.mean(embedding, axis=2, where=shape_mask[..., None])
        else:
            # Take the dummy entity's embedding as the summary of the entire scene.
            embedding = embedding[:, :, 0]
        return GeneralActorCriticRNN(
            action_dim=self.action_dim,
            fc_layer_depth=self.fc_layer_depth,
            fc_layer_width=self.fc_layer_width,
            action_mode=self.action_mode,
            hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
            multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
            add_generator_embedding=self.add_generator_embedding,
            generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
            recurrent=self.recurrent,
        )(hidden, og_obs, embedding, dones, activation)
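
if __name__ == "__main__":
    # Hedged smoke test (not part of the original module): constructs the model
    # with illustrative placeholder hyperparameters, not values from the Kinetix
    # configs. Running a forward pass additionally needs a real EntityObservation
    # and an RNN hidden state, which are not built here.
    model = ActorCriticTransformer(
        action_dim=4,
        fc_layer_width=128,
        action_mode="discrete",
        hybrid_action_continuous_dim=0,
        multi_discrete_number_of_dims_per_distribution=[],
        transformer_size=64,
        transformer_encoder_size=64,
        transformer_depth=2,
        fc_layer_depth=2,
        num_heads=4,
        activation="relu",
        aggregate_mode="dummy_and_mean",
        full_attention_mask=False,
    )
    print(model)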