SARM Progress Prediction
Stage-aware progress prediction model for robot manipulation tasks
Model Description
SARM predicts:
- Progress: How far through the task (0.0 to 1.0)
- Stage: Which stage of the task is being executed
The model uses a transformer architecture to process sequences of RGB images and robot states.
Task: clearing_food_from_table_into_fridge
Dataset: IliaLarchenko/behavior_224_rgb
Model Details
Architecture
- Type: Transformer with dual prediction heads (stage classification + progress regression)
- Model dimension: 768
- Attention heads: 12
- Transformer layers: 8
- MLP dimension: 512
- Number of stages: 100
- Number of tasks: 50
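As a quick sanity check on the hyperparameters above, the per-head dimension follows from the model dimension and the number of attention heads. The sketch below assumes the standard multi-head attention convention in which the model dimension is split evenly across heads; `ArchConfig` is a hypothetical helper for illustration and is not part of the SARM code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ArchConfig:
    """Hypothetical container for the architecture values listed above."""
    d_model: int = 768
    n_heads: int = 12
    n_layers: int = 8
    d_mlp: int = 512
    num_stages: int = 100
    num_tasks: int = 50

    def head_dim(self) -> int:
        # Standard multi-head attention splits d_model evenly across heads.
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        return self.d_model // self.n_heads


cfg = ArchConfig()
print(cfg.head_dim())  # 64
```

With these values each of the 12 heads would operate on 64-dimensional projections, under the stated assumption.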
Training Details
- Checkpoint: best_model.pt
- Training step: 4800
- Epoch: unknown
- Training loss: unknown
- Validation loss: 1.0865614609792829
- Batch size: 16
- Learning rate: 0.0001
- Max sequence length: 13
Usage
Download and Load Model
import json

import torch

from hf_model_hub import download_model_from_hub
from model import SARM

# Download model and config
files = download_model_from_hub(
    repo_id="YOUR_USERNAME/YOUR_REPO",
    checkpoint_name="best_model.pt",
    output_dir="./downloaded_model",
)

# Load config
with open(files["config"], "r") as f:
    config = json.load(f)

# Create model from the "model" section of the config
model_config = config["model"]
model = SARM(
    d_model=model_config["d_model"],
    n_heads=model_config["n_heads"],
    n_layers=model_config["n_layers"],
    d_mlp=model_config["d_mlp"],
    num_stages=model_config["num_stages"],
    d_state=model_config["d_state"],
    num_tasks=model_config["num_tasks"],
)

# Load checkpoint (map to CPU so this also works on machines without a GPU)
checkpoint = torch.load(files["checkpoint"], map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
Run Inference
# Assumes images, states, tasks, and padding_mask are already prepared as tensors
with torch.no_grad():
    stage_logits, progress = model(images, states, tasks, padding_mask)

# Get predictions for the last frame
predicted_stage = stage_logits[:, -1].argmax(dim=-1)
predicted_progress = progress[:, -1]
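The last-frame outputs can also be turned into a stage confidence and a bounded progress value. The following is a minimal sketch in plain Python, assuming the stage logits for one frame have been converted to a list (for example with `.tolist()`) and that clipping progress to the documented 0.0 to 1.0 range is acceptable. `summarize_last_frame` is a hypothetical helper and not part of the SARM code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def summarize_last_frame(last_stage_logits, last_progress):
    """Return (stage index, stage confidence, progress clipped to [0, 1])."""
    probs = softmax(last_stage_logits)
    stage = max(range(len(probs)), key=probs.__getitem__)
    # Assumption: the regression head may slightly overshoot the [0, 1] range.
    progress = min(1.0, max(0.0, last_progress))
    return stage, probs[stage], progress


# Toy values standing in for stage_logits[0, -1].tolist() and progress[0, -1].item()
stage, confidence, progress = summarize_last_frame([0.1, 2.0, -1.0], 1.03)
print(stage, round(confidence, 3), progress)
```

The confidence value is only the softmax probability of the predicted stage; it is not a calibrated uncertainty estimate.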
Training Data
This model was trained on the IliaLarchenko/behavior_224_rgb dataset for robot manipulation tasks.
- Training episodes: 90
- Validation episodes: 15
Intended Use
- Progress estimation for robot manipulation tasks
- Stage classification for multi-step tasks
- Adaptive window sampling for VLA training
- Task monitoring and intervention detection
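To illustrate the task monitoring use case, the sketch below flags a possible stall when predicted progress has barely increased over a recent window of predictions. This is one simple heuristic under assumed thresholds, not the method used by the authors. The function name and the `window`, `min_gain`, and `done_threshold` parameters are hypothetical.

```python
def detect_stall(progress_history, window=5, min_gain=0.01, done_threshold=0.95):
    """Return True if progress gained less than min_gain over the last `window` predictions.

    progress_history is a list of per-step progress predictions in [0, 1].
    A task that is already near completion is never reported as stalled.
    """
    if window < 2:
        raise ValueError("window must be at least 2")
    if len(progress_history) < window:
        return False  # not enough evidence yet
    recent = progress_history[-window:]
    if recent[-1] >= done_threshold:
        return False  # task looks finished, so flat progress is expected
    return recent[-1] - recent[0] < min_gain


print(detect_stall([0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3]))  # True
print(detect_stall([0.1, 0.2, 0.3, 0.4, 0.5]))            # False
```

In practice the thresholds would need tuning against real episodes, since the noise level of the progress predictions is not documented here.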
Limitations
- This checkpoint was trained on a single task (clearing_food_from_table_into_fridge) from the BEHAVIOR dataset
- Requires RGB images (224x224) and robot state information
- Fixed-length input window (maximum sequence length of 13 frames)
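Because the input window is bounded (the training details list a maximum sequence length of 13), frames from a longer or shorter episode have to be selected and padded before inference. The sketch below shows one way to do this. It assumes left-padding, so that the current frame stays at the last position as in the inference snippet, and a mask convention where `True` marks a padded slot. Both are assumptions about SARM's expected input, and `build_window` is a hypothetical helper.

```python
MAX_SEQ_LEN = 13  # "Max sequence length" from the training details


def build_window(current_frame, stride=1, max_len=MAX_SEQ_LEN):
    """Pick up to max_len frame indices ending at current_frame.

    Returns (indices, padding_mask), both of length max_len. Missing slots
    are filled on the left with the earliest selected frame and marked True
    in the mask (assumed convention: True means "ignore this position").
    """
    if current_frame < 0 or stride < 1:
        raise ValueError("current_frame must be >= 0 and stride must be >= 1")
    # Walk backwards from the current frame, keep at most max_len, restore order.
    indices = list(range(current_frame, -1, -stride))[:max_len][::-1]
    n_pad = max_len - len(indices)
    padded = [indices[0]] * n_pad + indices
    mask = [True] * n_pad + [False] * len(indices)
    return padded, mask


indices, mask = build_window(current_frame=4)
print(indices)  # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4]
print(mask.count(True))  # 8
```

A larger `stride` covers more of the episode history with the same 13 slots, at the cost of temporal resolution.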
Citation
If you use this model, please cite:
@misc{sarm-model,
  author = {Your Name},
  title = {SARM Progress Prediction},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/YOUR_USERNAME/YOUR_REPO}
}
Training Configuration
{
  "metadata": {
    "model_name": "SARM Progress Prediction",
    "description": "Stage-aware progress prediction model for robot manipulation tasks",
    "task": "clearing_food_from_table_into_fridge",
    "task_number": 25,
    "dataset": "IliaLarchenko/behavior_224_rgb",
    "version": "1.0",
    "author": "Your Name",
    "tags": ["robotics", "progress-estimation", "behavior-cloning"]
  },
  "model": {
    "d_model": 768,
    "n_heads": 12,
    "n_layers": 8,
    "d_mlp": 512,
    "num_stages": 100,
    "d_state": 256,
    "num_tasks": 50
  },
  "training": {
    "max_steps": 10000,
    "learning_rate": 0.0001,
    "weight_decay": 0.0001,
    "batch_size": 16,
    "gradient_accumulation_steps": 4,
    "max_grad_norm": 1.0,
    "scheduler": "cosine",
    "stage_loss_weight": 1.0,
    "progress_loss_weight": 1.0,
    "validation_steps": 100,
    "save_steps": 200
  },
  "data": {
    "max_sequence_length": 13,
    "image_size": 224,
    "num_workers": 10,
    "val_workers": 10,
    "val_samples": 500,
    "train_episodes": [
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
      31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
      46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
      61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
      76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90
    ],
    "val_episodes": [
      91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105
    ],
    "seed": 42
  },
  "logging": {
    "project_name": "sarm-training",
    "run_name": null,
    "log_freq": 10,
    "checkpoint_dir": "checkpoints_sarm_25_2"
  }
}
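A few useful quantities can be derived from this configuration. The sketch below computes the effective batch size and checks that the train and validation episode lists do not overlap. It assumes gradient accumulation has its usual meaning (gradients from several batches are summed before one optimizer update) and single-device training, neither of which is confirmed by the configuration itself. `summarize_config` is a hypothetical helper.

```python
def summarize_config(config):
    """Derive a few quantities from a training configuration dictionary."""
    train = config["training"]
    data = config["data"]
    train_eps = set(data["train_episodes"])
    val_eps = set(data["val_episodes"])
    return {
        # Samples contributing to each optimizer update (single device assumed).
        "effective_batch_size": train["batch_size"] * train["gradient_accumulation_steps"],
        "num_train_episodes": len(train_eps),
        "num_val_episodes": len(val_eps),
        # Episodes present in both splits would leak validation data.
        "overlapping_episodes": sorted(train_eps & val_eps),
    }


# Relevant values copied from the configuration; episode lists written as ranges.
config = {
    "training": {"batch_size": 16, "gradient_accumulation_steps": 4},
    "data": {
        "train_episodes": list(range(1, 91)),
        "val_episodes": list(range(91, 106)),
    },
}
summary = summarize_config(config)
print(summary["effective_batch_size"])   # 64
print(summary["overlapping_episodes"])   # []
```

In real use the dictionary would come from `json.load` on the downloaded config file rather than being written out by hand.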