---
license: mit
tags:
- robotics
- imitation-learning
- octo
- pytorch
---
# Octo-Base PyTorch Model
This repository contains the octo-base model converted from the original JAX release to PyTorch format.
## Model Description
Octo is a generalist robot policy trained on diverse robot manipulation tasks.
- **Paper**: [Octo: An Open-Source Generalist Robot Policy](https://arxiv.org/pdf/2405.12213)
- **Original JAX Implementation**: [octo-models/octo](https://github.com/octo-models/octo)
- **Original PyTorch Implementation**: [emb-ai/octo-pytorch](https://github.com/emb-ai/octo-pytorch)
- **lil'km Implementation**: [s1lent4gnt/octo-pytorch](https://github.com/s1lent4gnt/octo-pytorch)
- **Model Size**: octo-base
## Usage
### Loading the pretrained model
```python
import torch
from safetensors.torch import load_file
import json
from octo_pytorch.model import OctoModel
from octo_pytorch.model.configuration_octo import OctoConfig
# Load the model configuration
with open('config.json', 'r') as f:
    config_dict = json.load(f)

# Initialize the model from the configuration
config = OctoConfig(model_name=config_dict['model_name'])
model = OctoModel(config)

# Load weights; strict=False because the T5 encoder weights are not in the
# file and are downloaded automatically from the HuggingFace Hub instead
state_dict = load_file('model.safetensors')
model.load_state_dict(state_dict, strict=False)
```
### Alternative: Direct loading from HuggingFace Hub
```python
from octo_pytorch.model import OctoModel
# Load model directly from HuggingFace Hub
model = OctoModel.from_pretrained('lilkm/octo-base-test')
```
**Note**: The T5-base language encoder weights are not included in this upload to save space. They will be automatically downloaded from HuggingFace Hub when you initialize the model.
## Model Architecture
- **Transformer**: 12 layers, 768 dim, 12 heads
- **Vision Encoder**: Custom CNN (SmallStem16)
- **Language Encoder**: T5-Base
- **Action Head**: Diffusion policy with 4 action steps
- **Max Horizon**: 10 timesteps
- **Action Dimension**: 7
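The architecture summary above can be sketched as tensor shapes. Note the field names and the 256x256 image resolution are assumptions for illustration, not the model's actual input schema:

```python
import torch

# Shapes implied by the architecture summary; names and image resolution
# are illustrative assumptions
batch, horizon, action_steps, action_dim = 2, 10, 4, 7

images = torch.zeros(batch, horizon, 3, 256, 256)       # camera observations
actions = torch.zeros(batch, action_steps, action_dim)  # one predicted action chunk

print(actions.shape)
```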
## Files
- `model.safetensors`: Model weights in safetensors format
- `config.json`: Model configuration
- `dataset_statistics.npy`: Dataset statistics used for normalization (if available)
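A dictionary saved as `.npy` must be loaded with `allow_pickle=True` and unwrapped with `.item()`. The sketch below builds a hypothetical statistics dictionary to demonstrate the loading pattern; the real keys inside `dataset_statistics.npy` may differ, so inspect the file first:

```python
import numpy as np

# Hypothetical statistics layout -- the actual keys may differ
stats = {"action": {"mean": np.zeros(7), "std": np.ones(7)}}
np.save("dataset_statistics.npy", stats)

# Dict payloads require allow_pickle=True; .item() unwraps the 0-d array
loaded = np.load("dataset_statistics.npy", allow_pickle=True).item()

# Normalize a sample 7-dim action with the stored statistics
action = np.array([0.5, -0.2, 0.1, 0.0, 0.3, -0.1, 1.0])
normalized = (action - loaded["action"]["mean"]) / loaded["action"]["std"]
print(normalized)
```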
## Citation
If you use this model, please cite:
```bibtex
@article{octo_2023,
  title={Octo: An Open-Source Generalist Robot Policy},
  author={Octo Model Team et al.},
  journal={arXiv preprint arXiv:2405.12213},
  year={2024}
}
```