# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Run this command to interactively debug:
PYTHONPATH=. python cosmos_predict1/diffusion/posttrain/datasets/dataset_3D.py

Adapted from:
https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py
"""

import json
import os
import pickle
import random
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed

import imageio
import numpy as np
import torch
from decord import VideoReader, cpu
from einops import rearrange
from torch.utils.data import Dataset
from torchvision import transforms as T
from tqdm import tqdm

from cosmos_predict1.diffusion.training.datasets.dataset_utils import (
    Resize_Preprocess,
    ToTensorVideo,
    euler2rotm,
    rotm2euler,
)


class Dataset_3D(Dataset):
    def __init__(
        self,
        train_annotation_path,
        val_annotation_path,
        test_annotation_path,
        video_path,
        sequence_interval,
        num_frames,
        cam_ids,
        accumulate_action,
        video_size,
        val_start_frame_interval,
        debug=False,
        normalize=False,
        pre_encode=False,
        do_evaluate=False,
        load_t5_embeddings=False,
        load_action=True,
        mode="train",
    ):
"""Dataset class for loading 3D robot action-conditional data. | |
This dataset loads robot trajectories consisting of RGB video frames, robot states (arm positions and gripper states), | |
and computes relative actions between consecutive frames. | |
Args: | |
train_annotation_path (str): Path to training annotation files | |
val_annotation_path (str): Path to validation annotation files | |
test_annotation_path (str): Path to test annotation files | |
video_path (str): Base path to video files | |
sequence_interval (int): Interval between sampled frames in a sequence | |
num_frames (int): Number of frames to load per sequence | |
cam_ids (list): List of camera IDs to sample from | |
accumulate_action (bool): Whether to accumulate actions relative to first frame | |
video_size (list): Target size [H,W] for video frames | |
val_start_frame_interval (int): Frame sampling interval for validation/test | |
debug (bool, optional): If True, only loads subset of data. Defaults to False. | |
normalize (bool, optional): Whether to normalize video frames. Defaults to False. | |
pre_encode (bool, optional): Whether to pre-encode video frames. Defaults to False. | |
do_evaluate (bool, optional): Whether in evaluation mode. Defaults to False. | |
load_t5_embeddings (bool, optional): Whether to load T5 embeddings. Defaults to False. | |
load_action (bool, optional): Whether to load actions. Defaults to True. | |
mode (str, optional): Dataset mode - 'train', 'val' or 'test'. Defaults to 'train'. | |
The dataset loads robot trajectories and computes: | |
- RGB video frames from specified camera views | |
- Robot arm states (xyz position + euler angles) | |
- Gripper states (binary open/closed) | |
- Relative actions between consecutive frames | |
Actions are computed as relative transforms between frames: | |
- Translation: xyz offset in previous frame's coordinate frame | |
- Rotation: euler angles of relative rotation | |
- Gripper: binary gripper state | |
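
        For example, with accumulate_action=False the per-step action mirrors the math
        in _get_actions below (schematic):

            prev_rotm = euler2rotm(prev_rpy)
            rel_xyz = prev_rotm.T @ (curr_xyz - prev_xyz)             # translation in the previous frame
            rel_rpy = rotm2euler(prev_rotm.T @ euler2rotm(curr_rpy))  # relative rotation as euler angles
            action = [*rel_xyz, *rel_rpy, curr_gripper]               # 7-dim action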

        Returns dict with:
            - video: RGB frames tensor [C, T, H, W] (uint8)
            - action: Action tensor [T-1, 7] (if load_action=True)
            - annotation_file: Path to the source annotation file
            - __key__: Unique sample identifier (episode id or original path)
        """
        super().__init__()
        if mode == "train":
            self.data_path = train_annotation_path
            self.start_frame_interval = 1
        elif mode == "val":
            self.data_path = val_annotation_path
            self.start_frame_interval = val_start_frame_interval
        elif mode == "test":
            self.data_path = test_annotation_path
            self.start_frame_interval = val_start_frame_interval
        else:
            raise ValueError(f"Unknown mode: {mode}. Expected 'train', 'val' or 'test'.")
        self.video_path = video_path
        self.sequence_interval = sequence_interval
        self.mode = mode
        self.sequence_length = num_frames
        self.normalize = normalize
        self.pre_encode = pre_encode
        self.load_t5_embeddings = load_t5_embeddings
        self.load_action = load_action
        self.cam_ids = cam_ids
        self.accumulate_action = accumulate_action
        self.action_dim = 7  # ee xyz (3) + ee euler (3) + gripper (1)
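        # Per-channel action scaling: translation/rotation deltas are scaled up 20x while the
        # gripper channel is left as-is, presumably to balance channel magnitudes (the factors
        # themselves come from the original code).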
        self.c_act_scaler = [20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 1.0]
        self.c_act_scaler = np.array(self.c_act_scaler, dtype=float)
        self.ann_files = self._init_anns(self.data_path)
        self.samples = self._init_sequences(self.ann_files)
        self.samples = sorted(self.samples, key=lambda x: (x["ann_file"], x["frame_ids"][0]))
        if debug and not do_evaluate:
            self.samples = self.samples[0:10]
        self.wrong_number = 0
        self.transform = T.Compose([T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)])
        self.training = False
        self.preprocess = T.Compose(
            [
                ToTensorVideo(),
                Resize_Preprocess(tuple(video_size)),  # e.g. (288, 512)
                T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        self.not_norm_preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))])

    def __str__(self):
        return f"{len(self.ann_files)} annotation files from {self.data_path}"

    def _init_anns(self, data_dir):
        ann_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".json")]
        return ann_files

    def _init_sequences(self, ann_files):
        samples = []
        with ThreadPoolExecutor(32) as executor:
            future_to_ann_file = {
                executor.submit(self._load_and_process_ann_file, ann_file): ann_file for ann_file in ann_files
            }
            for future in tqdm(as_completed(future_to_ann_file), total=len(ann_files)):
                samples.extend(future.result())
        return samples

    def _load_and_process_ann_file(self, ann_file):
        samples = []
        with open(ann_file, "r") as f:
            ann = json.load(f)
        n_frames = len(ann["state"])
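        # Build fixed-length frame windows: starting from every start_frame_interval-th frame,
        # collect sequence_length frame indices spaced sequence_interval apart. For example,
        # n_frames=10, sequence_length=3, sequence_interval=2 yields [0, 2, 4], [1, 3, 5], etc.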
        for frame_i in range(0, n_frames, self.start_frame_interval):
            sample = dict()
            sample["ann_file"] = ann_file
            sample["frame_ids"] = []
            curr_frame_i = frame_i
            while True:
                if curr_frame_i > (n_frames - 1):
                    break
                sample["frame_ids"].append(curr_frame_i)
                if len(sample["frame_ids"]) == self.sequence_length:
                    break
                curr_frame_i += self.sequence_interval
            # keep only windows that contain exactly sequence_length frames
            if len(sample["frame_ids"]) == self.sequence_length:
                samples.append(sample)
        return samples

    def __len__(self):
        return len(self.samples)

    def _load_video(self, video_path, frame_ids):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        assert (np.array(frame_ids) < len(vr)).all()
        assert (np.array(frame_ids) >= 0).all()
        vr.seek(0)
        frame_data = vr.get_batch(frame_ids).asnumpy()
        return frame_data

    def _get_frames(self, label, frame_ids, cam_id, pre_encode):
        if pre_encode:
            raise NotImplementedError("Pre-encoded videos are not supported for this dataset.")
        else:
            video_path = label["videos"][cam_id]["video_path"]
            video_path = os.path.join(self.video_path, video_path)
            frames = self._load_video(video_path, frame_ids)
            frames = frames.astype(np.uint8)
            frames = torch.from_numpy(frames).permute(0, 3, 1, 2)  # (l, c, h, w)

            def printvideo(videos, filename):
                # Debug helper (not called in normal operation): dump frames to a video file.
                t_videos = rearrange(videos, "f c h w -> f h w c")
                t_videos = (
                    ((t_videos / 2.0 + 0.5).clamp(0, 1) * 255).detach().to(dtype=torch.uint8).cpu().contiguous().numpy()
                )
                print(t_videos.shape)
                writer = imageio.get_writer(filename, fps=4)  # fps is the frame rate
                for frame in t_videos:
                    writer.append_data(frame)
                writer.close()
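        # Normalized output is float in [-1, 1]; the non-normalized path resizes,
        # rescales back to [0, 255], and keeps uint8.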
        if self.normalize:
            frames = self.preprocess(frames)
        else:
            frames = self.not_norm_preprocess(frames)
            frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8)
        return frames

    def _get_obs(self, label, frame_ids, cam_id, pre_encode):
        if cam_id is None:
            temp_cam_id = random.choice(self.cam_ids)
        else:
            temp_cam_id = cam_id
        frames = self._get_frames(label, frame_ids, cam_id=temp_cam_id, pre_encode=pre_encode)
        return frames, temp_cam_id

    def _get_robot_states(self, label, frame_ids):
        all_states = np.array(label["state"])
        all_cont_gripper_states = np.array(label["continuous_gripper_state"])
        states = all_states[frame_ids]
        cont_gripper_states = all_cont_gripper_states[frame_ids]
        arm_states = states[:, :6]
        assert arm_states.shape[0] == self.sequence_length
        assert cont_gripper_states.shape[0] == self.sequence_length
        return arm_states, cont_gripper_states

    def _get_all_robot_states(self, label, frame_ids):
        all_states = np.array(label["state"])
        all_cont_gripper_states = np.array(label["continuous_gripper_state"])
        states = all_states[frame_ids]
        cont_gripper_states = all_cont_gripper_states[frame_ids]
        arm_states = states[:, :6]
        return arm_states, cont_gripper_states

    def _get_all_actions(self, arm_states, gripper_states, accumulate_action):
        action_num = arm_states.shape[0] - 1
        action = np.zeros((action_num, self.action_dim))
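        # accumulate_action=True expresses every action relative to the FIRST frame's
        # end-effector pose; otherwise each action is relative to the PREVIOUS frame.
        # The same branching logic appears in _get_actions below.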
        if accumulate_action:
            first_xyz = arm_states[0, 0:3]
            first_rpy = arm_states[0, 3:6]
            first_rotm = euler2rotm(first_rpy)
            for k in range(1, action_num + 1):
                curr_xyz = arm_states[k, 0:3]
                curr_rpy = arm_states[k, 3:6]
                curr_gripper = gripper_states[k]
                curr_rotm = euler2rotm(curr_rpy)
                rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz)
                rel_rotm = first_rotm.T @ curr_rotm
                rel_rpy = rotm2euler(rel_rotm)
                action[k - 1, 0:3] = rel_xyz
                action[k - 1, 3:6] = rel_rpy
                action[k - 1, 6] = curr_gripper
        else:
            for k in range(1, action_num + 1):
                prev_xyz = arm_states[k - 1, 0:3]
                prev_rpy = arm_states[k - 1, 3:6]
                prev_rotm = euler2rotm(prev_rpy)
                curr_xyz = arm_states[k, 0:3]
                curr_rpy = arm_states[k, 3:6]
                curr_gripper = gripper_states[k]
                curr_rotm = euler2rotm(curr_rpy)
                rel_xyz = np.dot(prev_rotm.T, curr_xyz - prev_xyz)
                rel_rotm = prev_rotm.T @ curr_rotm
                rel_rpy = rotm2euler(rel_rotm)
                action[k - 1, 0:3] = rel_xyz
                action[k - 1, 3:6] = rel_rpy
                action[k - 1, 6] = curr_gripper
        return torch.from_numpy(action)  # (l - 1, act_dim)

    def _get_actions(self, arm_states, gripper_states, accumulate_action):
        action = np.zeros((self.sequence_length - 1, self.action_dim))
        if accumulate_action:
            first_xyz = arm_states[0, 0:3]
            first_rpy = arm_states[0, 3:6]
            first_rotm = euler2rotm(first_rpy)
            for k in range(1, self.sequence_length):
                curr_xyz = arm_states[k, 0:3]
                curr_rpy = arm_states[k, 3:6]
                curr_gripper = gripper_states[k]
                curr_rotm = euler2rotm(curr_rpy)
                rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz)
                rel_rotm = first_rotm.T @ curr_rotm
                rel_rpy = rotm2euler(rel_rotm)
                action[k - 1, 0:3] = rel_xyz
                action[k - 1, 3:6] = rel_rpy
                action[k - 1, 6] = curr_gripper
        else:
            for k in range(1, self.sequence_length):
                prev_xyz = arm_states[k - 1, 0:3]
                prev_rpy = arm_states[k - 1, 3:6]
                prev_rotm = euler2rotm(prev_rpy)
                curr_xyz = arm_states[k, 0:3]
                curr_rpy = arm_states[k, 3:6]
                curr_gripper = gripper_states[k]
                curr_rotm = euler2rotm(curr_rpy)
                rel_xyz = np.dot(prev_rotm.T, curr_xyz - prev_xyz)
                rel_rotm = prev_rotm.T @ curr_rotm
                rel_rpy = rotm2euler(rel_rotm)
                action[k - 1, 0:3] = rel_xyz
                action[k - 1, 3:6] = rel_rpy
                action[k - 1, 6] = curr_gripper
        return torch.from_numpy(action)  # (l - 1, act_dim)

    def __getitem__(self, index, cam_id=None, return_video=False):
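        # Seed per index so validation/test sampling (e.g. the random camera choice) is deterministic.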
        if self.mode != "train":
            np.random.seed(index)
            random.seed(index)
        try:
            sample = self.samples[index]
            ann_file = sample["ann_file"]
            frame_ids = sample["frame_ids"]
            with open(ann_file, "r") as f:
                label = json.load(f)
            arm_states, gripper_states = self._get_robot_states(label, frame_ids)
            actions = self._get_actions(arm_states, gripper_states, self.accumulate_action)
            actions *= self.c_act_scaler

            data = dict()
            if self.load_action:
                data["action"] = actions.float()
            if self.pre_encode:
                raise NotImplementedError("Pre-encoded videos are not supported for this dataset.")
            else:
                video, cam_id = self._get_obs(label, frame_ids, cam_id, pre_encode=False)
                video = video.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
                data["video"] = video.to(dtype=torch.uint8)

            data["annotation_file"] = ann_file

            # NOTE: __key__ is used to uniquely identify the sample, required for callback functions
            if "episode_id" in label:
                data["__key__"] = label["episode_id"]
            else:
                data["__key__"] = label["original_path"]

            # Just add these to fit the interface
            if self.load_t5_embeddings:
                t5_embedding_path = ann_file.replace(".json", ".pickle")
                with open(t5_embedding_path, "rb") as f:
                    data["t5_text_embeddings"] = torch.from_numpy(pickle.load(f)[0])
            else:
                data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16)
            data["t5_text_mask"] = torch.ones(512, dtype=torch.int64)
            data["fps"] = 4
            data["image_size"] = 256 * torch.ones(4)  # TODO: Does this matter?
            data["num_frames"] = self.sequence_length
            data["padding_mask"] = torch.zeros(1, 256, 256)

            return data
        except Exception:
            warnings.warn(
                f"Invalid data encountered: {self.samples[index]['ann_file']}. Skipped "
                f"(by randomly sampling another sample in the same dataset)."
            )
            warnings.warn("FULL TRACEBACK:")
            warnings.warn(traceback.format_exc())
            self.wrong_number += 1
            print(f"Number of invalid samples skipped so far: {self.wrong_number}")
            return self[np.random.randint(len(self.samples))]
if __name__ == "__main__": | |
dataset = Dataset_3D( | |
train_annotation_path="datasets/bridge/annotation/train", | |
val_annotation_path="datasets/bridge/annotation/val", | |
test_annotation_path="datasets/bridge/annotation/test", | |
video_path="datasets/bridge/", | |
sequence_interval=1, | |
num_frames=2, | |
cam_ids=[0], | |
accumulate_action=False, | |
video_size=[256, 320], | |
val_start_frame_interval=1, | |
mode="train", | |
load_t5_embeddings=True, | |
) | |
    indices = [0, 13, 200, -1]
    for idx in indices:
        print(
            (
                f"{idx=} "
                f"{dataset[idx]['video'].sum()=}\n"
                f"{dataset[idx]['video'].shape=}\n"
                f"{dataset[idx]['__key__']=}\n"
                f"{dataset[idx]['action'].sum()=}\n"
                "---"
            )
        )

    from IPython import embed

    embed()
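
    # A minimal batched-loading sketch (not part of the original smoke test; assumes the
    # default collate can stack this dict, which holds since videos/actions have fixed shapes):
    # from torch.utils.data import DataLoader
    # loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
    # batch = next(iter(loader))
    # print(batch["video"].shape)  # (B, C, T, H, W)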