Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from risk_biased.models.nn_blocks import ( | |
| SequenceEncoderLSTM, | |
| SequenceEncoderMLP, | |
| SequenceEncoderMaskedLSTM, | |
| ) | |
| from risk_biased.models.cvae_params import CVAEParams | |
| from risk_biased.models.mlp import MLP | |
class MapEncoderNN(nn.Module):
    """Encoder that turns sequences of map-object features into per-object features.

    This is a thin wrapper around a ``SequenceEncoderMLP`` configured from ``params``.

    Args:
        params: dataclass defining the necessary parameters. The fields read are
            ``map_state_dim``, ``hidden_dim``, ``num_hidden_layers``,
            ``max_size_lane`` and ``is_mlp_residual``. They are forwarded
            positionally, in that order, to ``SequenceEncoderMLP``.
    """

    def __init__(self, params: CVAEParams) -> None:
        super().__init__()
        # Keep the positional order expected by SequenceEncoderMLP's constructor.
        encoder_args = (
            params.map_state_dim,
            params.hidden_dim,
            params.num_hidden_layers,
            params.max_size_lane,
            params.is_mlp_residual,
        )
        self._encoder = SequenceEncoderMLP(*encoder_args)

    # NOTE: the parameter name ``map`` shadows the builtin, but it is kept because
    # callers may pass it by keyword.
    def forward(self, map, mask_map):
        """Encode map object sequences of features into object features.

        Args:
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim)
                tensor of map objects.
            mask_map: (batch_size, num_objects, object_sequence_length) tensor of
                bool mask.

        Returns:
            Whatever the wrapped ``SequenceEncoderMLP`` returns for
            ``(map, mask_map)``. Presumably this is one feature vector per
            object; confirm the exact shape against ``SequenceEncoderMLP``.
        """
        return self._encoder(map, mask_map)