SARM Progress Prediction

Stage-aware progress prediction model for robot manipulation tasks

Model Description

SARM predicts:

  • Progress: how far the task has advanced, from 0.0 (start) to 1.0 (completion)
  • Stage: which stage of the task is currently being executed

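The model uses a transformer to process sequences of RGB images and robot states, conditioned on a task index, and produces a stage prediction and a progress estimate for every frame in the sequence.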

Task: clearing_food_from_table_into_fridge
Dataset: IliaLarchenko/behavior_224_rgb

Model Details

Architecture

  • Type: Transformer with dual prediction heads (stage classification + progress regression)
  • Model dimension: 768
  • Attention heads: 12
  • Transformer layers: 8
  • MLP dimension: 512
  • Number of stages: 100
  • Number of tasks: 50
  • State dimension (d_state): 256
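A quick consistency check on these hyperparameters can catch typos when editing the config. The sketch below assumes standard multi-head attention, where the model dimension is split evenly across heads; `ARCH` and `head_dim` are hypothetical names used only for illustration, and the values are copied from this card.

```python
# Hypothetical sanity check for the architecture values listed on this card.
# Assumes standard multi-head attention: d_model is split evenly across heads.
ARCH = {
    "d_model": 768,
    "n_heads": 12,
    "n_layers": 8,
    "d_mlp": 512,
    "num_stages": 100,
    "d_state": 256,
    "num_tasks": 50,
}


def head_dim(d_model: int, n_heads: int) -> int:
    """Return the per-head width, raising if d_model is not divisible by n_heads."""
    if d_model % n_heads != 0:
        raise ValueError(f"d_model={d_model} is not divisible by n_heads={n_heads}")
    return d_model // n_heads


print(head_dim(ARCH["d_model"], ARCH["n_heads"]))  # 64
```

With 768 dimensions and 12 heads, each attention head works on 64 dimensions.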

Training Details

  • Checkpoint: best_model.pt
  • Training step: 4800 (of max_steps 10000)
  • Epoch: unknown
  • Training loss: unknown
  • Validation loss: 1.0866
  • Batch size: 16 (64 effective, with 4 gradient-accumulation steps)
  • Learning rate: 0.0001 (cosine schedule, weight decay 0.0001)
  • Max sequence length: 13
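The training configuration combines a per-step batch of 16 with 4 gradient-accumulation steps and a cosine learning-rate schedule. The sketch below works out the effective batch size and the learning rate at the saved step, under assumptions the card does not confirm: the step counter counts optimizer updates, the schedule spans all 10000 `max_steps`, there is no warmup, and the learning rate decays to 0. The function names are hypothetical.

```python
import math

# Values from the training configuration on this card.
BATCH_SIZE = 16
GRAD_ACCUM_STEPS = 4
BASE_LR = 1e-4
MAX_STEPS = 10_000
BEST_STEP = 4_800  # step at which best_model.pt was saved


def effective_batch_size(batch_size: int, accum_steps: int) -> int:
    """Number of samples contributing to one optimizer update."""
    return batch_size * accum_steps


def cosine_lr(step: int, base_lr: float, max_steps: int, min_lr: float = 0.0) -> float:
    """Plain cosine decay from base_lr to min_lr with no warmup (an assumption)."""
    progress = min(step, max_steps) / max_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


print(effective_batch_size(BATCH_SIZE, GRAD_ACCUM_STEPS))  # 64
print(f"{cosine_lr(BEST_STEP, BASE_LR, MAX_STEPS):.3e}")
```

Under these assumptions the learning rate at step 4800 would be roughly 5.3e-5, about half the base rate, since the run was slightly less than halfway through its schedule.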

Usage

Download and Load Model

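import json

import torch

# hf_model_hub and model are assumed to be local modules from the SARM
# training code (they are not PyPI packages)
from hf_model_hub import download_model_from_hub
from model import SARM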

# Download model and config
files = download_model_from_hub(
    repo_id="YOUR_USERNAME/YOUR_REPO",
    checkpoint_name="best_model.pt",
    output_dir="./downloaded_model"
)

# Load config
with open(files["config"], "r") as f:
    config = json.load(f)

# Create model
model_config = config["model"]
model = SARM(
    d_model=model_config["d_model"],
    n_heads=model_config["n_heads"],
    n_layers=model_config["n_layers"],
    d_mlp=model_config["d_mlp"],
    num_stages=model_config["num_stages"],
    d_state=model_config["d_state"],
    num_tasks=model_config["num_tasks"],
)

# Load checkpoint
# map_location="cpu" lets the checkpoint load on machines without a GPU
checkpoint = torch.load(files["checkpoint"], map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

Run Inference

# Assumes images, states, tasks and padding_mask are already prepared as batched tensors
with torch.no_grad():
    stage_logits, progress = model(images, states, tasks, padding_mask)
    
    # Get predictions for the last frame
    predicted_stage = stage_logits[:, -1].argmax(dim=-1)
    predicted_progress = progress[:, -1]
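To see what the argmax over `stage_logits` means, the plain-Python sketch below turns one frame's stage scores into a predicted stage plus a softmax confidence. It assumes the logits are unnormalized scores over the `num_stages` stages (for example `stage_logits[b, -1].tolist()`); `stage_prediction` and `clamp_progress` are hypothetical helpers, and clamping is only useful if the progress head is not already bounded to [0, 1], which the card does not state.

```python
import math
from typing import List, Tuple


def stage_prediction(logits: List[float]) -> Tuple[int, float]:
    """Return (argmax stage index, softmax probability of that stage)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    best = max(range(len(logits)), key=lambda i: logits[i])
    return best, exps[best] / total


def clamp_progress(p: float) -> float:
    """Clip a progress estimate into [0.0, 1.0] in case the head is unbounded."""
    return min(1.0, max(0.0, p))


stage, confidence = stage_prediction([0.1, 2.3, -1.0, 0.7])
print(stage, round(confidence, 3))
print(clamp_progress(1.07))  # 1.0
```

A low softmax confidence can be a useful signal that the frame sits near a stage boundary, though how well calibrated these probabilities are for SARM is not reported here.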

Training Data

This model was trained on the IliaLarchenko/behavior_224_rgb dataset (task: clearing_food_from_table_into_fridge).

Training episodes: 90 (episode IDs 1 to 90)
Validation episodes: 15 (episode IDs 91 to 105)
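The episode IDs in the training configuration are contiguous, so the split can be reproduced with two ranges. This sketch checks the counts and that no episode appears in both splits; how these IDs map onto indices in a local copy of the dataset is an assumption to verify against your data loader.

```python
# Train/validation split as listed in the training configuration:
# episodes 1-90 for training and 91-105 for validation.
train_episodes = list(range(1, 91))
val_episodes = list(range(91, 106))

assert len(train_episodes) == 90
assert len(val_episodes) == 15
assert not set(train_episodes) & set(val_episodes)  # the splits are disjoint

print(len(train_episodes), len(val_episodes))  # 90 15
```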

Intended Use

  • Progress estimation for robot manipulation tasks
  • Stage classification for multi-step tasks
  • Adaptive window sampling for VLA training
  • Task monitoring and intervention detection
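For task monitoring, one simple way to use per-frame progress estimates is to flag the point where progress stops increasing. The sketch below is a hypothetical stall detector, not part of SARM: `detect_stall`, `window` and `min_gain` are illustrative names and thresholds that would need tuning on real rollouts.

```python
from typing import List, Optional


def detect_stall(
    progress: List[float],
    window: int = 10,
    min_gain: float = 0.02,
) -> Optional[int]:
    """Return the first frame index where progress gained less than `min_gain`
    over the preceding `window` frames, or None if that never happens.

    A drop in progress also counts as a stall, since its gain is negative.
    """
    for t in range(window, len(progress)):
        if progress[t] - progress[t - window] < min_gain:
            return t
    return None


# Steady progress, then a plateau at 0.5 starting at frame 5.
trace = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5]
print(detect_stall(trace, window=3, min_gain=0.05))  # 8
```

The detector fires at frame 8, the first frame whose 3-frame look-back lies entirely on the plateau; a larger window trades detection delay for robustness to noisy predictions.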

Limitations

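  • Trained on a single task (clearing_food_from_table_into_fridge) from the BEHAVIOR dataset
  • Requires 224×224 RGB images and robot state information
  • Input sequences are limited to 13 frames (max_sequence_length); shorter sequences must be padded and masked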
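Because the model sees at most 13 frames, a long episode has to be fed as a window ending at the current frame. The sketch below builds frame indices and a padding mask for such a window. Its conventions are assumptions the card does not confirm: windows are left-padded so the current frame stays at position -1 (matching the `[:, -1]` indexing in the inference snippet), padded slots reuse frame 0 as a placeholder, and `True` marks a padded slot, as in PyTorch's `src_key_padding_mask`. `context_window` is a hypothetical helper.

```python
from typing import List, Tuple

MAX_SEQ_LEN = 13  # "max_sequence_length" from the training configuration


def context_window(t: int, max_len: int = MAX_SEQ_LEN) -> Tuple[List[int], List[bool]]:
    """Frame indices and padding mask for a window ending at frame t.

    Left-pads short windows so frame t is always last; True marks padding.
    """
    start = max(0, t - max_len + 1)
    frames = list(range(start, t + 1))
    n_pad = max_len - len(frames)
    return [0] * n_pad + frames, [True] * n_pad + [False] * len(frames)


idx, mask = context_window(4)
print(idx)  # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4]
print(sum(mask))  # 8
```

Check the mask convention in `SARM.forward` before relying on this; if the model expects `True` for valid frames instead, invert the mask.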

Citation

If you use this model, please cite:

@misc{sarm-model,
  author = {Your Name},
  title = {SARM Progress Prediction},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/YOUR_USERNAME/YOUR_REPO}
}

Training Configuration

{
  "metadata": {
    "model_name": "SARM Progress Prediction",
    "description": "Stage-aware progress prediction model for robot manipulation tasks",
    "task": "clearing_food_from_table_into_fridge",
    "task_number": 25,
    "dataset": "IliaLarchenko/behavior_224_rgb",
    "version": "1.0",
    "author": "Your Name",
    "tags": [
      "robotics",
      "progress-estimation",
      "behavior-cloning"
    ]
  },
  "model": {
    "d_model": 768,
    "n_heads": 12,
    "n_layers": 8,
    "d_mlp": 512,
    "num_stages": 100,
    "d_state": 256,
    "num_tasks": 50
  },
  "training": {
    "max_steps": 10000,
    "learning_rate": 0.0001,
    "weight_decay": 0.0001,
    "batch_size": 16,
    "gradient_accumulation_steps": 4,
    "max_grad_norm": 1.0,
    "scheduler": "cosine",
    "stage_loss_weight": 1.0,
    "progress_loss_weight": 1.0,
    "validation_steps": 100,
    "save_steps": 200
  },
  "data": {
    "max_sequence_length": 13,
    "image_size": 224,
    "num_workers": 10,
    "val_workers": 10,
    "val_samples": 500,
    "train_episodes": [
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
      31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
      46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
      61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
      76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90
    ],
    "val_episodes": [
      91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105
    ],
    "seed": 42
  },
  "logging": {
    "project_name": "sarm-training",
    "run_name": null,
    "log_freq": 10,
    "checkpoint_dir": "checkpoints_sarm_25_2"
  }
}