# roll-ai's picture
# Upload 381 files
# b6af722 verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run this command to interactively debug:
PYTHONPATH=. python cosmos_predict1/diffusion/posttrain/datasets/dataset_3D.py
Adapted from:
https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py
"""
import json
import os
import pickle
import random
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
import imageio
import numpy as np
import torch
from decord import VideoReader, cpu
from einops import rearrange
from torch.utils.data import Dataset
from torchvision import transforms as T
from tqdm import tqdm
from cosmos_predict1.diffusion.training.datasets.dataset_utils import (
Resize_Preprocess,
ToTensorVideo,
euler2rotm,
rotm2euler,
)
class Dataset_3D(Dataset):
    """Map-style dataset of robot video clips paired with relative end-effector actions.

    Every ``*.json`` file found in the annotation directory selected by ``mode``
    is expanded into fixed-length windows of frame indices; each window is one
    sample.

    Actions are relative transforms between frames:
        - Translation: xyz offset expressed in the reference frame's coordinates
        - Rotation: euler angles of the relative rotation
        - Gripper: value read from ``continuous_gripper_state`` at the current frame
    The reference frame is the first frame of the window when
    ``accumulate_action`` is true, otherwise the previous frame.

    ``__getitem__`` returns a dict with:
        - action: float tensor [T-1, 7], scaled by ``c_act_scaler`` (only if ``load_action``)
        - video: uint8 tensor [C, T, H, W]
        - annotation_file: path of the annotation JSON the sample came from
        - __key__: ``episode_id`` of the annotation, or ``original_path`` if absent
        - t5_text_embeddings / t5_text_mask: loaded or placeholder text conditioning
        - fps, image_size, num_frames, padding_mask: fixed interface fields
    """

    def __init__(
        self,
        train_annotation_path,
        val_annotation_path,
        test_annotation_path,
        video_path,
        sequence_interval,
        num_frames,
        cam_ids,
        accumulate_action,
        video_size,
        val_start_frame_interval,
        debug=False,
        normalize=False,
        pre_encode=False,
        do_evaluate=False,
        load_t5_embeddings=False,
        load_action=True,
        mode="train",
    ):
        """Index all annotation files of the chosen split into frame windows.

        Args:
            train_annotation_path (str): Directory with training annotation JSON files.
            val_annotation_path (str): Directory with validation annotation JSON files.
            test_annotation_path (str): Directory with test annotation JSON files.
            video_path (str): Base path that annotation-relative video paths are joined to.
            sequence_interval (int): Stride between consecutive frames of one window.
            num_frames (int): Number of frames per window.
            cam_ids (list): Camera IDs to sample from when none is requested.
            accumulate_action (bool): Express actions relative to the first frame
                of the window instead of the previous frame.
            video_size (list): Target size passed to ``Resize_Preprocess``.
            val_start_frame_interval (int): Stride between window start frames for
                the 'val' and 'test' modes (training always uses 1).
            debug (bool, optional): Keep only the first 10 samples unless
                ``do_evaluate`` is set. Defaults to False.
            normalize (bool, optional): Apply the normalizing preprocess pipeline
                instead of the uint8 one. Defaults to False.
            pre_encode (bool, optional): Pre-encoded videos; not supported, any
                attempt to load a sample with it raises. Defaults to False.
            do_evaluate (bool, optional): Disables the ``debug`` truncation. Defaults to False.
            load_t5_embeddings (bool, optional): Load T5 embeddings from the
                ``.pickle`` next to each annotation file. Defaults to False.
            load_action (bool, optional): Include ``action`` in samples. Defaults to True.
            mode (str, optional): 'train', 'val' or 'test'. Defaults to 'train'.

        Raises:
            ValueError: If ``mode`` is not one of 'train', 'val' or 'test'.
        """
        super().__init__()
        if mode == "train":
            self.data_path = train_annotation_path
            self.start_frame_interval = 1
        elif mode == "val":
            self.data_path = val_annotation_path
            self.start_frame_interval = val_start_frame_interval
        elif mode == "test":
            self.data_path = test_annotation_path
            self.start_frame_interval = val_start_frame_interval
        else:
            # Previously an unknown mode left data_path unset and surfaced later
            # as an AttributeError; fail early with a clear message instead.
            raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")

        self.video_path = video_path
        self.sequence_interval = sequence_interval
        self.mode = mode
        self.sequence_length = num_frames
        self.normalize = normalize
        self.pre_encode = pre_encode
        self.load_t5_embeddings = load_t5_embeddings
        self.load_action = load_action

        self.cam_ids = cam_ids
        self.accumulate_action = accumulate_action

        self.action_dim = 7  # ee xyz (3) + ee euler (3) + gripper(1)
        # Per-dimension multiplier applied to the computed actions in __getitem__.
        self.c_act_scaler = np.array([20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 1.0], dtype=float)

        self.ann_files = self._init_anns(self.data_path)
        self.samples = self._init_sequences(self.ann_files)
        # Worker threads finish in arbitrary order; sort for a deterministic index.
        self.samples = sorted(self.samples, key=lambda x: (x["ann_file"], x["frame_ids"][0]))
        if debug and not do_evaluate:
            self.samples = self.samples[0:10]

        # Count of samples that failed to load and were replaced by another one.
        self.wrong_number = 0
        # `transform` and `training` are not read inside this class; they are kept
        # because external code may access them.
        self.transform = T.Compose([T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)])
        self.training = False
        self.preprocess = T.Compose(
            [
                ToTensorVideo(),
                Resize_Preprocess(tuple(video_size)),
                T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        self.not_norm_preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))])

    def __str__(self):
        return f"{len(self.ann_files)} samples from {self.data_path}"

    def _init_anns(self, data_dir):
        """Return the paths of all ``.json`` files directly inside ``data_dir``."""
        return [os.path.join(data_dir, name) for name in os.listdir(data_dir) if name.endswith(".json")]

    def _init_sequences(self, ann_files):
        """Expand every annotation file into frame windows using a thread pool.

        Results are collected in completion order, so the returned list is not
        ordered; the caller sorts it.
        """
        samples = []
        with ThreadPoolExecutor(32) as executor:
            future_to_ann_file = {
                executor.submit(self._load_and_process_ann_file, ann_file): ann_file for ann_file in ann_files
            }
            for future in tqdm(as_completed(future_to_ann_file), total=len(ann_files)):
                samples.extend(future.result())
        return samples

    def _load_and_process_ann_file(self, ann_file):
        """Build all full-length frame windows for one annotation file.

        A window starts every ``start_frame_interval`` frames and contains
        ``sequence_length`` indices spaced ``sequence_interval`` apart. Windows
        that would run past the last frame are dropped.

        Returns:
            list[dict]: One ``{"ann_file": path, "frame_ids": [...]}`` per window.
        """
        if self.sequence_length < 1:
            # No window can hold fewer than one frame. (The original loop never
            # emitted a sample here and could spin forever with a zero interval.)
            return []

        with open(ann_file, "r") as f:
            ann = json.load(f)
        n_frames = len(ann["state"])

        samples = []
        for start in range(0, n_frames, self.start_frame_interval):
            frame_ids = []
            for offset in range(self.sequence_length):
                frame_i = start + offset * self.sequence_interval
                if frame_i > n_frames - 1:
                    break
                frame_ids.append(frame_i)
            # Keep only windows that contain the full number of frames.
            if len(frame_ids) == self.sequence_length:
                samples.append({"ann_file": ann_file, "frame_ids": frame_ids})
        return samples

    def __len__(self):
        return len(self.samples)

    def _load_video(self, video_path, frame_ids):
        """Decode the requested frames of a video into a numpy array."""
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        ids = np.array(frame_ids)
        assert (ids < len(vr)).all(), "frame id beyond the end of the video"
        assert (ids >= 0).all(), "negative frame id"
        vr.seek(0)
        return vr.get_batch(frame_ids).asnumpy()

    def _get_frames(self, label, frame_ids, cam_id, pre_encode):
        """Load and preprocess the frames of one camera for the given indices.

        Returns the output of ``self.preprocess`` when ``self.normalize`` is
        set; otherwise a uint8 tensor obtained by scaling the un-normalized
        pipeline output by 255.

        Raises:
            NotImplementedError: If ``pre_encode`` is true.
        """
        if pre_encode:
            raise NotImplementedError("Pre-encoded videos are not supported for this dataset.")

        video_path = os.path.join(self.video_path, label["videos"][cam_id]["video_path"])
        frames = self._load_video(video_path, frame_ids).astype(np.uint8)
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2)  # (l, c, h, w)

        if self.normalize:
            return self.preprocess(frames)

        # The *255 below implies ToTensorVideo yields values in [0, 1] —
        # TODO confirm against dataset_utils.
        frames = self.not_norm_preprocess(frames)
        return torch.clamp(frames * 255.0, 0, 255).to(torch.uint8)

    def _get_obs(self, label, frame_ids, cam_id, pre_encode):
        """Return ``(frames, cam_id)``, picking a random camera when ``cam_id`` is None."""
        temp_cam_id = random.choice(self.cam_ids) if cam_id is None else cam_id
        frames = self._get_frames(label, frame_ids, cam_id=temp_cam_id, pre_encode=pre_encode)
        return frames, temp_cam_id

    def _get_robot_states(self, label, frame_ids):
        """Like ``_get_all_robot_states`` but checks for exactly ``sequence_length`` frames."""
        arm_states, cont_gripper_states = self._get_all_robot_states(label, frame_ids)
        assert arm_states.shape[0] == self.sequence_length
        assert cont_gripper_states.shape[0] == self.sequence_length
        return arm_states, cont_gripper_states

    def _get_all_robot_states(self, label, frame_ids):
        """Select arm states (first 6 state columns) and gripper states at ``frame_ids``."""
        all_states = np.array(label["state"])
        all_cont_gripper_states = np.array(label["continuous_gripper_state"])
        arm_states = all_states[frame_ids][:, :6]
        cont_gripper_states = all_cont_gripper_states[frame_ids]
        return arm_states, cont_gripper_states

    def _relative_actions(self, arm_states, gripper_states, num_actions, accumulate_action):
        """Shared implementation of ``_get_actions`` / ``_get_all_actions``.

        Row ``k - 1`` of the result describes frame ``k`` relative to a reference
        frame: frame 0 when ``accumulate_action`` is true, frame ``k - 1``
        otherwise.

        Returns:
            torch.Tensor: ``(num_actions, action_dim)`` tensor of
            [rel_xyz (3), rel_rpy (3), gripper (1)] rows.
        """
        action = np.zeros((num_actions, self.action_dim))
        if num_actions < 1:
            return torch.from_numpy(action)

        ref_xyz = arm_states[0, 0:3]
        ref_rotm = euler2rotm(arm_states[0, 3:6])
        for k in range(1, num_actions + 1):
            curr_xyz = arm_states[k, 0:3]
            curr_rotm = euler2rotm(arm_states[k, 3:6])
            action[k - 1, 0:3] = np.dot(ref_rotm.T, curr_xyz - ref_xyz)
            action[k - 1, 3:6] = rotm2euler(ref_rotm.T @ curr_rotm)
            action[k - 1, 6] = gripper_states[k]
            if not accumulate_action:
                # Frame k is the reference for frame k + 1; reuse its rotation
                # matrix instead of converting the same euler angles again.
                ref_xyz, ref_rotm = curr_xyz, curr_rotm
        return torch.from_numpy(action)  # (l - 1, act_dim)

    def _get_all_actions(self, arm_states, gripper_states, accumulate_action):
        """Compute relative actions for however many states are passed in."""
        return self._relative_actions(arm_states, gripper_states, arm_states.shape[0] - 1, accumulate_action)

    def _get_actions(self, arm_states, gripper_states, accumulate_action):
        """Compute relative actions for the first ``sequence_length`` states."""
        return self._relative_actions(arm_states, gripper_states, self.sequence_length - 1, accumulate_action)

    def __getitem__(self, index, cam_id=None, return_video=False):
        """Load one sample; on failure warn and return a randomly chosen other sample.

        Args:
            index (int): Sample index (negative indices are allowed).
            cam_id: Camera to load; a random entry of ``cam_ids`` when None.
            return_video (bool): Unused; kept for interface compatibility.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        if self.mode != "train":
            # Deterministic camera choice / fallback outside training.
            # np.random.seed rejects negative values, so wrap them into range.
            seed = int(index) % (2**32)
            np.random.seed(seed)
            random.seed(seed)

        # Looked up outside the try so an out-of-range index is reported as such
        # rather than being treated as corrupt data.
        sample = self.samples[index]
        ann_file = sample["ann_file"]

        try:
            frame_ids = sample["frame_ids"]
            with open(ann_file, "r") as f:
                label = json.load(f)

            arm_states, gripper_states = self._get_robot_states(label, frame_ids)
            actions = self._get_actions(arm_states, gripper_states, self.accumulate_action)
            # Convert the scaler explicitly instead of multiplying a tensor by a
            # numpy array in place.
            actions = actions * torch.as_tensor(self.c_act_scaler, dtype=actions.dtype)

            data = dict()
            if self.load_action:
                data["action"] = actions.float()

            if self.pre_encode:
                raise NotImplementedError("Pre-encoded videos are not supported for this dataset.")

            video, _used_cam_id = self._get_obs(label, frame_ids, cam_id, pre_encode=False)
            video = video.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
            # NOTE(review): with normalize=True the frames come from the
            # normalizing pipeline, and this uint8 cast looks lossy — confirm intent.
            data["video"] = video.to(dtype=torch.uint8)

            data["annotation_file"] = ann_file

            # NOTE: __key__ is used to uniquely identify the sample, required for callback functions
            if "episode_id" in label:
                data["__key__"] = label["episode_id"]
            else:
                data["__key__"] = label["original_path"]

            # Just add these to fit the interface
            if self.load_t5_embeddings:
                # Swap only the extension; str.replace would also rewrite any
                # ".json" occurring in a directory name.
                t5_embedding_path = os.path.splitext(ann_file)[0] + ".pickle"
                # NOTE: pickle executes arbitrary code on load; only use with trusted files.
                with open(t5_embedding_path, "rb") as f:
                    data["t5_text_embeddings"] = torch.from_numpy(pickle.load(f)[0])
            else:
                data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16)
            data["t5_text_mask"] = torch.ones(512, dtype=torch.int64)
            data["fps"] = 4
            data["image_size"] = 256 * torch.ones(4)  # TODO: Does this matter?
            data["num_frames"] = self.sequence_length
            data["padding_mask"] = torch.zeros(1, 256, 256)

            return data
        except Exception:
            warnings.warn(
                f"Invalid data encountered: {ann_file}. Skipped "
                f"(by randomly sampling another sample in the same dataset)."
            )
            warnings.warn("FULL TRACEBACK:")
            warnings.warn(traceback.format_exc())
            self.wrong_number += 1
            print(self.wrong_number)
            # Forward the caller's arguments so a requested camera is honoured on
            # the fallback sample too (the original retry dropped cam_id).
            # NOTE: the retry is unbounded; if every candidate fails this ends in
            # a RecursionError.
            return self.__getitem__(
                np.random.randint(len(self.samples)), cam_id=cam_id, return_video=return_video
            )
if __name__ == "__main__":
    # Interactive smoke test: build the training split of the bridge dataset,
    # print a summary of a few samples, then drop into an IPython shell.
    dataset = Dataset_3D(
        train_annotation_path="datasets/bridge/annotation/train",
        val_annotation_path="datasets/bridge/annotation/val",
        test_annotation_path="datasets/bridge/annotation/test",
        video_path="datasets/bridge/",
        sequence_interval=1,
        num_frames=2,
        cam_ids=[0],
        accumulate_action=False,
        video_size=[256, 320],
        val_start_frame_interval=1,
        mode="train",
        load_t5_embeddings=True,
    )

    indices = [0, 13, 200, -1]
    for idx in indices:
        # Load once per index: each dataset[idx] re-reads the annotation and
        # re-decodes the video.
        sample = dataset[idx]
        print(
            (
                f"{idx=} "
                f"{sample['video'].sum()=}\n"
                f"{sample['video'].shape=}\n"
                # Samples carry no 'video_name' key (reading it raised KeyError);
                # '__key__' is the identifier the dataset actually provides.
                f"{sample['__key__']=}\n"
                f"{sample['action'].sum()=}\n"
                "---"
            )
        )

    from IPython import embed

    embed()