# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Run this command to interactively debug:
PYTHONPATH=. python cosmos_predict1/diffusion/training/datasets/dataset_multiview.py

Adapted from:
https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py
"""

import os
import pickle
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import torch
from decord import VideoReader, cpu
from torch.utils.data import Dataset
from torchvision import transforms as T
from tqdm import tqdm

from cosmos_predict1.diffusion.training.datasets.dataset_utils import Resize_Preprocess, ToTensorVideo


class Dataset(Dataset):
    def __init__(
        self,
        dataset_dir,
        sequence_interval,
        num_frames,
        view_keys,
        video_size,
        start_frame_interval=1,
    ):
        """Dataset class for loading multiview image-text-to-video generation data.

        Expects the following layout under `dataset_dir`:
            videos/<view_key>/*.mp4
            t5_xxl/<view_key>/*.pickle
            cache/prefix_t5_embeddings_<view_key>.pickle

        Args:
            dataset_dir (str): Base path to the dataset directory
            sequence_interval (int): Interval between sampled frames in a sequence
            num_frames (int): Number of frames to load per sequence
            view_keys (list): Camera view names; one subdirectory of videos and embeddings per view
            video_size (list): Target size [H,W] for video frames
            start_frame_interval (int): Stride between the starting frames of consecutive sequences

        Returns dict with:
            - video: RGB frames tensor [C,V*T,H,W], views concatenated along the frame axis
            - video_name: Dict with episode/frame metadata
            - t5_text_embeddings: T5 embeddings tensor [V*512,1024]
        """
        super().__init__()
        self.dataset_dir = dataset_dir
        self.start_frame_interval = start_frame_interval
        self.sequence_interval = sequence_interval
        self.sequence_length = num_frames
        self.view_keys = view_keys

        video_dir = os.path.join(self.dataset_dir, "videos")
        self.video_paths = [
            os.path.join(video_dir, view_keys[0], f) for f in os.listdir(os.path.join(video_dir, view_keys[0]))
        ]
        print(f"{len(self.video_paths)} videos in total")

        self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl")
        self.samples = self._init_samples(self.video_paths)
        self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0]))
        print(f"{len(self.samples)} samples in total")
        self.wrong_number = 0
        self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))])

        # Per-view prefix embeddings that get prepended to every caption embedding.
        cache_dir = os.path.join(self.dataset_dir, "cache")
        self.prefix_t5_embeddings = {}
        for view_key in view_keys:
            with open(os.path.join(cache_dir, f"prefix_t5_embeddings_{view_key}.pickle"), "rb") as f:
                self.prefix_t5_embeddings[view_key] = pickle.load(f)[0]

    def __str__(self):
        return f"{len(self.video_paths)} videos from {self.dataset_dir}"

    def _init_samples(self, video_paths):
        samples = []
        with ThreadPoolExecutor(32) as executor:
            future_to_video_path = {
                executor.submit(self._load_and_process_video_path, video_path): video_path
                for video_path in video_paths
            }
            for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)):
                samples.extend(future.result())
        return samples

    def _load_and_process_video_path(self, video_path):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        n_frames = len(vr)

        # Slide a window over the video: each sample is sequence_length frame ids
        # spaced sequence_interval apart, starting every start_frame_interval frames.
        samples = []
        for frame_i in range(0, n_frames, self.start_frame_interval):
            sample = dict()
            sample["video_path"] = video_path
            sample["t5_embedding_path"] = os.path.join(
                self.t5_dir,
                os.path.basename(os.path.dirname(video_path)),
                os.path.basename(video_path).replace(".mp4", ".pickle"),
            )
            sample["frame_ids"] = []
            curr_frame_i = frame_i
            while True:
                if curr_frame_i > (n_frames - 1):
                    break
                sample["frame_ids"].append(curr_frame_i)
                if len(sample["frame_ids"]) == self.sequence_length:
                    break
                curr_frame_i += self.sequence_interval
            # Keep only complete windows with exactly sequence_length frames.
            if len(sample["frame_ids"]) == self.sequence_length:
                samples.append(sample)

        return samples

    def __len__(self):
        return len(self.samples)

    def _load_video(self, video_path, frame_ids):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        assert (np.array(frame_ids) < len(vr)).all()
        assert (np.array(frame_ids) >= 0).all()
        vr.seek(0)
        frame_data = vr.get_batch(frame_ids).asnumpy()
        try:
            fps = vr.get_avg_fps()
        except Exception:  # failed to read FPS, fall back to a default
            fps = 24
        return frame_data, fps

    def _get_frames(self, video_path, frame_ids):
        frames, fps = self._load_video(video_path, frame_ids)
        frames = frames.astype(np.uint8)
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2)  # (l, c, h, w)
        # Resize to the target size, then convert back to uint8 in [0, 255].
        frames = self.preprocess(frames)
        frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8)
        return frames, fps

    def __getitem__(self, index):
        try:
            sample = self.samples[index]
            video_path = sample["video_path"]
            frame_ids = sample["frame_ids"]
            t5_embedding_path = sample["t5_embedding_path"]

            data = dict()

            videos = []
            t5_embeddings = []

            for view_key in self.view_keys:
                # Swap the view directory into the stored path to load the same
                # frame window from every camera.
                video, fps = self._get_frames(
                    os.path.join(os.path.dirname(os.path.dirname(video_path)), view_key, os.path.basename(video_path)),
                    frame_ids,
                )
                video = video.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
                videos.append(video)

                with open(
                    os.path.join(
                        os.path.dirname(os.path.dirname(t5_embedding_path)),
                        view_key,
                        os.path.basename(t5_embedding_path),
                    ),
                    "rb",
                ) as f:
                    t5_embedding = pickle.load(f)[0]
                # Prepend the per-view prefix embedding and pad to 512 tokens.
                t5_embedding = np.concatenate([self.prefix_t5_embeddings[view_key], t5_embedding], axis=0)
                t5_embedding = torch.from_numpy(t5_embedding)
                if t5_embedding.shape[0] < 512:
                    t5_embedding = torch.cat([t5_embedding, torch.zeros(512 - t5_embedding.shape[0], 1024)], dim=0)
                t5_embeddings.append(t5_embedding)

            video = torch.cat(videos, dim=1)  # [C, V*T, H, W]
            t5_embedding = torch.cat(t5_embeddings, dim=0)  # [V*512, 1024]

            data["video"] = video
            data["video_name"] = {
                "video_path": video_path,
                "t5_embedding_path": t5_embedding_path,
                "start_frame_id": str(frame_ids[0]),
            }
            data["t5_text_embeddings"] = t5_embedding
            data["t5_text_mask"] = torch.ones(512 * len(self.view_keys), dtype=torch.int64)
            data["fps"] = fps
            data["image_size"] = torch.tensor([704, 1280, 704, 1280])
            data["num_frames"] = self.sequence_length
            data["padding_mask"] = torch.zeros(1, 704, 1280)

            return data
        except Exception:
            warnings.warn(
                f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped "
                "(by randomly sampling another sample in the same dataset)."
            )
            warnings.warn("FULL TRACEBACK:")
            warnings.warn(traceback.format_exc())
            self.wrong_number += 1
            print(f"Invalid samples so far: {self.wrong_number}")
            # Fall back to a random sample so a bad file does not stop training.
            return self[np.random.randint(len(self.samples))]


if __name__ == "__main__":
    dataset = Dataset(
        dataset_dir="datasets/waymo/",
        sequence_interval=1,
        num_frames=57,
        view_keys=[
            "pinhole_front_left",
            "pinhole_front",
            "pinhole_front_right",
            "pinhole_side_left",
            "pinhole_side_right",
        ],
        video_size=[240, 360],
    )

    indices = [0, 13, 200, -1]
    for idx in indices:
        data = dataset[idx]
        print(
            (
                f"{idx=} "
                f"{data['video'].sum()=}\n"
                f"{data['video'].shape=}\n"
                f"{data['video_name']=}\n"
                f"{data['t5_text_embeddings'].shape=}\n"
                "---"
            )
        )
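
    # A minimal sketch of batched loading with a standard PyTorch DataLoader;
    # the batch_size/num_workers values below are illustrative, not from the repo.
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    batch = next(iter(loader))
    print(f"{batch['video'].shape=}")  # [B, C, V*T, H, W]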