---
license: mit
tags:
- robotics
- imitation-learning
- octo
- pytorch
---

# Octo-Small PyTorch Model

This is the octo-small model converted to PyTorch format.

## Model Description

Octo is a transformer-based generalist robot policy trained on 800k robot manipulation trajectories from the Open X-Embodiment dataset.

- **Paper**: [Octo: An Open-Source Generalist Robot Policy](https://arxiv.org/pdf/2405.12213)
- **Original JAX Implementation**: [octo-models/octo](https://github.com/octo-models/octo)
- **Original PyTorch Implementation**: [emb-ai/octo-pytorch](https://github.com/emb-ai/octo-pytorch)
- **lil'km Implementation**: [s1lent4gnt/octo-pytorch](https://github.com/s1lent4gnt/octo-pytorch)
- **Model Size**: octo-small

## Usage

### Loading the pretrained model

```python
import json

import torch
from safetensors.torch import load_file

from octo_pytorch.model import OctoModel
from octo_pytorch.model.configuration_octo import OctoConfig

# Load config
with open('config.json', 'r') as f:
    config_dict = json.load(f)

# Initialize model configuration
config = OctoConfig(model_name=config_dict['model_name'])

# Initialize model
model = OctoModel(config)

# Load weights (T5 encoder weights are loaded automatically from the Hugging Face Hub)
state_dict = load_file('model.safetensors')
model.load_state_dict(state_dict, strict=False)  # strict=False because T5 weights are not in the file
```
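
Since `strict=False` would also hide weights that are genuinely missing from `model.safetensors`, it can help to inspect what was skipped. PyTorch's `load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`; the helpers below are a minimal sketch that fails loudly if any missing key falls outside an allowed set of prefixes. The helper names are hypothetical, and so is any prefix you pass in: the actual parameter names of the T5 encoder inside `OctoModel` are not documented here, so read them off `missing_keys` first.

```python
from typing import Iterable, List, Sequence


def unexplained_missing_keys(
    missing_keys: Iterable[str], allowed_prefixes: Sequence[str]
) -> List[str]:
    """Return missing keys that do not start with any allowed prefix.

    Keys matching an allowed prefix (weights that are expected to come from
    elsewhere, such as the T5 encoder) are treated as expected and dropped.
    """
    prefixes = tuple(allowed_prefixes)
    return sorted(k for k in missing_keys if not k.startswith(prefixes))


def assert_only_expected_missing(
    missing_keys: Iterable[str], allowed_prefixes: Sequence[str]
) -> None:
    """Raise if any weight other than the allowed ones failed to load."""
    leftover = unexplained_missing_keys(missing_keys, allowed_prefixes)
    if leftover:
        raise RuntimeError(
            f"{len(leftover)} unexpected missing weights, e.g. {leftover[:5]}"
        )
```

For example, keep the return value with `result = model.load_state_dict(state_dict, strict=False)` and then call `assert_only_expected_missing(result.missing_keys, ("t5_",))`, where `"t5_"` is a hypothetical stand-in for the real T5 parameter prefix.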

### Alternative: Direct loading from the Hugging Face Hub

```python
from octo_pytorch.model import OctoModel

# Load model directly from the Hugging Face Hub
model = OctoModel.from_pretrained('lilkm/octo-small-test')
```

**Note**: The T5-base language encoder weights are not included in this upload to save space. They are downloaded automatically from the Hugging Face Hub when the model is initialized.

## Model Architecture

- **Transformer**: 12 layers, hidden size 384, 6 attention heads
- **Vision Encoder**: Custom CNN (SmallStem16)
- **Language Encoder**: T5-Base
- **Action Head**: Diffusion head predicting chunks of 4 action steps
- **Max Horizon**: 10 timesteps
- **Action Dimension**: 7
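
A few derived quantities follow directly from these numbers: each attention head is 384 / 6 = 64 wide, and one action prediction holds 4 × 7 = 28 values. The dataclass below is a hypothetical summary for illustration only; its field names are not the fields of `OctoConfig`.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class OctoSmallSummary:
    """Hypothetical summary of the octo-small hyperparameters listed above."""

    num_layers: int = 12
    hidden_dim: int = 384
    num_heads: int = 6
    action_horizon: int = 4  # action steps predicted per forward pass
    action_dim: int = 7
    max_horizon: int = 10  # timesteps, as listed above

    @property
    def head_dim(self) -> int:
        # Multi-head attention splits the hidden size evenly across heads.
        if self.hidden_dim % self.num_heads:
            raise ValueError("hidden_dim must be divisible by num_heads")
        return self.hidden_dim // self.num_heads

    @property
    def action_chunk_shape(self) -> Tuple[int, int]:
        # One prediction is a chunk of `action_horizon` actions, each `action_dim` wide.
        return (self.action_horizon, self.action_dim)

    @property
    def values_per_prediction(self) -> int:
        return self.action_horizon * self.action_dim
```

With the defaults, `OctoSmallSummary().head_dim` is 64 and `action_chunk_shape` is `(4, 7)`.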

## Files

- `model.safetensors`: Model weights in safetensors format
- `config.json`: Model configuration
- `dataset_statistics.npy`: Dataset statistics used for normalization (if available)
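
When `dataset_statistics.npy` is present, predicted actions usually have to be mapped back from normalized space before they are sent to a robot. The sketch below assumes the file is a dictionary saved with `np.save`, with an `"action"` entry holding `"mean"` and `"std"` arrays and an optional boolean `"mask"` marking which dimensions were normalized, mirroring the mean/std unnormalization in the original JAX implementation. These keys are an assumption: check the actual contents of this file, and if it is keyed by dataset name, select the per-dataset entry first.

```python
from typing import Any, Dict

import numpy as np


def load_dataset_statistics(path: str) -> Dict[str, Any]:
    """Load a dict saved with np.save (stored as a 0-d object array)."""
    # allow_pickle is required for object arrays; only load files you trust.
    return np.load(path, allow_pickle=True).item()


def unnormalize_actions(actions: np.ndarray, action_stats: Dict[str, Any]) -> np.ndarray:
    """Map normalized actions back to robot units via actions * std + mean.

    Dimensions where the optional boolean mask is False (for example a
    gripper dimension that was never normalized) pass through unchanged.
    """
    mean = np.asarray(action_stats["mean"], dtype=np.float64)
    std = np.asarray(action_stats["std"], dtype=np.float64)
    mask = np.asarray(action_stats.get("mask", np.ones_like(mean, dtype=bool)), dtype=bool)
    actions = np.asarray(actions, dtype=np.float64)
    # Broadcasting applies the per-dimension stats to every step of an action chunk.
    return np.where(mask, actions * std + mean, actions)
```

A typical call would be `unnormalize_actions(chunk, load_dataset_statistics("dataset_statistics.npy")["action"])`, where `chunk` is a hypothetical array of normalized actions with the action dimension last.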

## Citation

If you use this model, please cite:

```bibtex
@article{octo_2023,
  title={Octo: An Open-Source Generalist Robot Policy},
  author={{Octo Model Team} and others},
  journal={arXiv preprint arXiv:2405.12213},
  year={2024}
}
```