Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| from mmcv import Config | |
| class CVAEParams: | |
| """ | |
| state_dim: Dimension of the state at each time step. | |
| map_state_dim: Dimension of the map point features at each position. | |
| num_steps: Number of time steps in the past trajectory input. | |
| num_steps_future: Number of time steps in the future trajectory output. | |
| latent_dim: Dimension of the latent space | |
| hidden_dim: Dimension of the hidden layers | |
| num_hidden_layers: Number of layers for each model, (encoder, decoder) | |
| is_mlp_residual: Set to True to add linear transformation of the input to output of the MLP | |
| interaction_type: Wether to use MCG, MAB, or MHB to handle interactions | |
| num_attention_heads: Number of attention heads to use in MHA blocks | |
| mcg_dim_expansion: Dimension expansion factor for the MCG global interaction space | |
| mcg_num_layers: Number of layers for the MLP MCG blocks | |
| num_blocks: Number of interaction blocks to use | |
| sequence_encoder_type: Type of sequence encoder maskedLSTM, LSTM, or MLP | |
| sequence_decoder_type: Type of sequence decoder maskedLSTM, LSTM, or MLP | |
| condition_on_ego_future: Wether to condition the biasing with the ego future or only the ego past | |
| latent_regularization: Weight of the latent regularization loss | |
| """ | |
| dt: float | |
| state_dim: int | |
| dynamic_state_dim: int | |
| map_state_dim: int | |
| max_size_lane: int | |
| num_steps: int | |
| num_steps_future: int | |
| latent_dim: int | |
| hidden_dim: int | |
| num_hidden_layers: int | |
| is_mlp_residual: bool | |
| interaction_type: int | |
| num_attention_heads: int | |
| mcg_dim_expansion: int | |
| mcg_num_layers: int | |
| num_blocks: int | |
| sequence_encoder_type: str | |
| sequence_decoder_type: str | |
| condition_on_ego_future: bool | |
| latent_regularization: float | |
| risk_assymetry_factor: float | |
| num_vq: int | |
| latent_distribution: str | |
| def from_config(cfg: Config): | |
| return CVAEParams( | |
| dt=cfg.dt, | |
| state_dim=cfg.state_dim, | |
| dynamic_state_dim=cfg.dynamic_state_dim, | |
| map_state_dim=cfg.map_state_dim, | |
| max_size_lane=cfg.max_size_lane, | |
| num_steps=cfg.num_steps, | |
| num_steps_future=cfg.num_steps_future, | |
| latent_dim=cfg.latent_dim, | |
| hidden_dim=cfg.hidden_dim, | |
| num_hidden_layers=cfg.num_hidden_layers, | |
| is_mlp_residual=cfg.is_mlp_residual, | |
| interaction_type=cfg.interaction_type, | |
| mcg_dim_expansion=cfg.mcg_dim_expansion, | |
| mcg_num_layers=cfg.mcg_num_layers, | |
| num_blocks=cfg.num_blocks, | |
| num_attention_heads=cfg.num_attention_heads, | |
| sequence_encoder_type=cfg.sequence_encoder_type, | |
| sequence_decoder_type=cfg.sequence_decoder_type, | |
| condition_on_ego_future=cfg.condition_on_ego_future, | |
| latent_regularization=cfg.latent_regularization, | |
| risk_assymetry_factor=cfg.risk_assymetry_factor, | |
| num_vq=cfg.num_vq, | |
| latent_distribution=cfg.latent_distribution, | |
| ) | |