# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run this command to interactively debug:
PYTHONPATH=. python cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py
"""

import os
import pickle
import traceback
import warnings

import numpy as np
import torch
from decord import VideoReader, cpu
from torch.utils.data import Dataset

from cosmos_transfer1.diffusion.datasets.augmentor_provider import AUGMENTOR_OPTIONS
from cosmos_transfer1.diffusion.datasets.augmentors.control_input import VIDEO_RES_SIZE_INFO
from cosmos_transfer1.diffusion.inference.inference_utils import detect_aspect_ratio
from cosmos_transfer1.utils.lazy_config import instantiate
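
# Expected dataset layout (inferred from the loaders below):
#   <dataset_dir>/
#     videos/                RGB clips (*.mp4); per-view subfolders for AV data, e.g. videos/pinhole_front/
#     depth/ lidar/ hdmap/   control videos (*.mp4), frame-aligned with the RGB clips
#     keypoint/ seg/         control annotations (*.pickle)
#     t5_xxl/                precomputed T5-XXL caption embeddings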

# mappings between control types and the corresponding sub-folder names in the data folder
CTRL_TYPE_INFO = {
"keypoint": {"folder": "keypoint", "format": "pickle", "data_dict_key": "keypoint"},
"depth": {"folder": "depth", "format": "mp4", "data_dict_key": "depth"},
"lidar": {"folder": "lidar", "format": "mp4", "data_dict_key": "lidar"},
"hdmap": {"folder": "hdmap", "format": "mp4", "data_dict_key": "hdmap"},
"seg": {"folder": "seg", "format": "pickle", "data_dict_key": "segmentation"},
"edge": {"folder": None}, # Canny edge, computed on-the-fly
"vis": {"folder": None}, # Blur, computed on-the-fly
"upscale": {"folder": None}, # Computed on-the-fly
}

class ExampleTransferDataset(Dataset):
def __init__(self, dataset_dir, num_frames, resolution, hint_key="control_input_vis", is_train=True):
"""Dataset class for loading video-text-to-video generation data with control inputs.
Args:
dataset_dir (str): Base path to the dataset directory
num_frames (int): Number of consecutive frames to load per sequence
resolution (str): resolution of the target video size
hint_key (str): The hint key for loading the correct control input data modality
is_train (bool): Whether this is for training
        NOTE: our example dataset does not include a validation split; the is_train flag is kept for custom configurations.
"""
super().__init__()
self.dataset_dir = dataset_dir
self.sequence_length = num_frames
self.is_train = is_train
self.resolution = resolution
assert (
resolution in VIDEO_RES_SIZE_INFO.keys()
), "The provided resolution cannot be found in VIDEO_RES_SIZE_INFO."
# Control input setup with file formats
self.ctrl_type = hint_key.replace("control_input_", "")
self.ctrl_data_pth_config = CTRL_TYPE_INFO[self.ctrl_type]
# Set up directories - only collect paths
video_dir = os.path.join(self.dataset_dir, "videos")
self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")]
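        # Precomputed T5-XXL caption embeddings, stored one file per video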
self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl")
print(f"Finish initializing dataset with {len(self.video_paths)} videos in total.")
# Set up preprocessing and augmentation
augmentor_name = f"video_ctrlnet_augmentor_{hint_key}"
augmentor_cfg = AUGMENTOR_OPTIONS[augmentor_name](resolution=resolution)
self.augmentor = {k: instantiate(v) for k, v in augmentor_cfg.items()}

    def _sample_frames(self, video_path):
"""Sample frames from video and get metadata"""
vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
n_frames = len(vr)
# Calculate valid start frame range
max_start_idx = n_frames - self.sequence_length
if max_start_idx < 0: # Video is too short
return None, None, None
# Sample start frame
start_frame = np.random.randint(0, max_start_idx + 1)
frame_ids = list(range(start_frame, start_frame + self.sequence_length))
# Load frames
frames = vr.get_batch(frame_ids).asnumpy()
frames = frames.astype(np.uint8)
try:
fps = vr.get_avg_fps()
except Exception: # failed to read FPS
fps = 24
return frames, frame_ids, fps

    def _load_control_data(self, sample):
"""Load control data for the video clip."""
data_dict = {}
frame_ids = sample["frame_ids"]
ctrl_path = sample["ctrl_path"]
try:
if self.ctrl_type == "seg":
with open(ctrl_path, "rb") as f:
ctrl_data = pickle.load(f)
# key should match line 982 at cosmos_transfer1/diffusion/datasets/augmentors/control_input.py
data_dict["segmentation"] = ctrl_data
elif self.ctrl_type == "keypoint":
with open(ctrl_path, "rb") as f:
ctrl_data = pickle.load(f)
data_dict["keypoint"] = ctrl_data
elif self.ctrl_type == "depth":
vr = VideoReader(ctrl_path, ctx=cpu(0))
# Ensure the depth video has the same number of frames
assert len(vr) >= frame_ids[-1] + 1, f"Depth video {ctrl_path} has fewer frames than main video"
# Load the corresponding frames
depth_frames = vr.get_batch(frame_ids).asnumpy() # [T,H,W,C]
depth_frames = torch.from_numpy(depth_frames).permute(3, 0, 1, 2) # [C,T,H,W], same as rgb video
data_dict["depth"] = {
"video": depth_frames,
"frame_start": frame_ids[0],
"frame_end": frame_ids[-1],
}
elif self.ctrl_type == "lidar":
vr = VideoReader(ctrl_path, ctx=cpu(0))
# Ensure the lidar depth video has the same number of frames
assert len(vr) >= frame_ids[-1] + 1, f"Lidar video {ctrl_path} has fewer frames than main video"
# Load the corresponding frames
lidar_frames = vr.get_batch(frame_ids).asnumpy() # [T,H,W,C]
lidar_frames = torch.from_numpy(lidar_frames).permute(3, 0, 1, 2) # [C,T,H,W], same as rgb video
data_dict["lidar"] = {
"video": lidar_frames,
"frame_start": frame_ids[0],
"frame_end": frame_ids[-1],
}
elif self.ctrl_type == "hdmap":
vr = VideoReader(ctrl_path, ctx=cpu(0))
# Ensure the hdmap video has the same number of frames
assert len(vr) >= frame_ids[-1] + 1, f"Hdmap video {ctrl_path} has fewer frames than main video"
# Load the corresponding frames
hdmap_frames = vr.get_batch(frame_ids).asnumpy() # [T,H,W,C]
hdmap_frames = torch.from_numpy(hdmap_frames).permute(3, 0, 1, 2) # [C,T,H,W], same as rgb video
data_dict["hdmap"] = {
"video": hdmap_frames,
"frame_start": frame_ids[0],
"frame_end": frame_ids[-1],
}
except Exception as e:
warnings.warn(f"Failed to load control data from {ctrl_path}: {str(e)}")
return None
return data_dict

    def __getitem__(self, index):
        max_retries = 3
        for attempt in range(max_retries):
try:
video_path = self.video_paths[index]
video_name = os.path.basename(video_path).replace(".mp4", "")
# Sample frames
frames, frame_ids, fps = self._sample_frames(video_path)
if frames is None: # Invalid video or too short
index = np.random.randint(len(self.video_paths))
continue
data = dict()
# Process video frames
video = torch.from_numpy(frames).permute(3, 0, 1, 2) # [T,H,W,C] -> [C,T,H,W]
aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2])) # expects (W, H)
# Basic data
data["video"] = video
data["aspect_ratio"] = aspect_ratio
                # Load T5 caption embeddings. AV data (hdmap/lidar) stores a nested caption
                # dict in .pkl files; the other modalities store the raw embedding in .pickle files.
                if self.ctrl_type in ["hdmap", "lidar"]:
                    t5_embedding_path = os.path.join(self.t5_dir, f"{video_name}.pkl")
                else:
                    t5_embedding_path = os.path.join(self.t5_dir, f"{video_name}.pickle")
                data["video_name"] = {
                    "video_path": video_path,
                    "t5_embedding_path": t5_embedding_path,
                    "start_frame_id": str(frame_ids[0]),
                }
                with open(t5_embedding_path, "rb") as f:
                    t5_embedding = pickle.load(f)
                if self.ctrl_type in ["hdmap", "lidar"]:
                    t5_embedding = t5_embedding["pickle"]["ground_truth"]["embeddings"]["t5_xxl"]
                # Ensure t5_embedding is a numpy array
                if isinstance(t5_embedding, list):
                    t5_embedding = np.array(t5_embedding[0] if len(t5_embedding) > 0 else t5_embedding)
                data["t5_text_embeddings"] = torch.from_numpy(t5_embedding)
                data["t5_text_mask"] = torch.ones(512, dtype=torch.int64)
# Add metadata
data["fps"] = fps
data["frame_start"] = frame_ids[0]
data["frame_end"] = frame_ids[-1] + 1
data["num_frames"] = self.sequence_length
data["image_size"] = torch.tensor([704, 1280, 704, 1280]) # .cuda()
data["padding_mask"] = torch.zeros(1, 704, 1280) # .cuda()
if self.ctrl_type:
ctrl_data = self._load_control_data(
{
"ctrl_path": os.path.join(
self.dataset_dir,
self.ctrl_data_pth_config["folder"],
f"{video_name}.{self.ctrl_data_pth_config['format']}",
)
if self.ctrl_data_pth_config["folder"] is not None
else None,
"frame_ids": frame_ids,
}
)
if ctrl_data is None: # Control data loading failed
index = np.random.randint(len(self.video_paths))
continue
data.update(ctrl_data)
# The ctrl_data above is the 'raw' data loaded (e.g. a loaded segmentation pkl).
# Next, we process it into the control input "video" tensor that the model expects.
# This is done in the augmentor.
for _, aug_fn in self.augmentor.items():
data = aug_fn(data)
return data
            except Exception:
                warnings.warn(
                    f"Invalid data encountered: {self.video_paths[index]}. Skipped; "
                    f"retrying with a randomly sampled video from the same dataset.\n"
                    f"FULL TRACEBACK:\n{traceback.format_exc()}"
                )
                if attempt == max_retries - 1:
                    raise RuntimeError(f"Failed to load data after {max_retries} attempts")
                index = np.random.randint(len(self.video_paths))
        # All retries were consumed by `continue` paths (short videos / bad control data)
        raise RuntimeError(f"Failed to load a valid sample after {max_retries} attempts")

    def __len__(self):
        return len(self.video_paths)

    def __str__(self):
        return f"{len(self.video_paths)} samples from {self.dataset_dir}"

class AVTransferDataset(ExampleTransferDataset):
def __init__(
self,
dataset_dir,
num_frames,
resolution,
view_keys,
hint_key="control_input_hdmap",
sample_n_views=-1,
caption_view_idx_map=None,
is_train=True,
load_mv_emb=False,
):
"""Dataset class for loading video-text-to-video generation data with control inputs.
Args:
dataset_dir (str): Base path to the dataset directory
num_frames (int): Number of consecutive frames to load per sequence
resolution (str): resolution of the target video size
            view_keys (list[str]): List of view names that the dataloader should load
            hint_key (str): The hint key for loading the correct control input data modality
sample_n_views (int): Number of views to sample
caption_view_idx_map (dict): Optional dictionary mapping index in view_keys to index in model.view_embeddings
is_train (bool): Whether this is for training
load_mv_emb (bool): Whether to load t5 embeddings for all views, or only for front view
        NOTE: our example dataset does not include a validation split; the is_train flag is kept for custom configurations.
"""
super(ExampleTransferDataset, self).__init__()
self.dataset_dir = dataset_dir
self.sequence_length = num_frames
self.is_train = is_train
self.resolution = resolution
self.view_keys = view_keys
self.load_mv_emb = load_mv_emb
assert (
resolution in VIDEO_RES_SIZE_INFO.keys()
), "The provided resolution cannot be found in VIDEO_RES_SIZE_INFO."
# Control input setup with file formats
self.ctrl_type = hint_key.replace("control_input_", "")
self.ctrl_data_pth_config = CTRL_TYPE_INFO[self.ctrl_type]
# Set up directories - only collect paths
video_dir = os.path.join(self.dataset_dir, "videos", "pinhole_front")
self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")]
self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl")
cache_dir = os.path.join(self.dataset_dir, "cache")
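        # Cached per-view caption-prefix embeddings (camera prompts), one pickle per view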
self.prefix_t5_embeddings = {}
for view_key in view_keys:
with open(os.path.join(cache_dir, f"prefix_{view_key}.pkl"), "rb") as f:
self.prefix_t5_embeddings[view_key] = pickle.load(f)
        if caption_view_idx_map is None:
            self.caption_view_idx_map = {i: i for i in range(len(self.view_keys))}
        else:
            self.caption_view_idx_map = caption_view_idx_map
self.sample_n_views = sample_n_views
print(f"Finish initializing dataset with {len(self.video_paths)} videos in total.")
# Set up preprocessing and augmentation
augmentor_name = f"video_ctrlnet_augmentor_{hint_key}"
augmentor_cfg = AUGMENTOR_OPTIONS[augmentor_name](resolution=resolution)
self.augmentor = {k: instantiate(v) for k, v in augmentor_cfg.items()}

    def _load_video(self, video_path, frame_ids):
vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
assert (np.array(frame_ids) < len(vr)).all()
assert (np.array(frame_ids) >= 0).all()
        vr.seek(0)  # rewind to the first frame before the batched read
frame_data = vr.get_batch(frame_ids).asnumpy()
try:
fps = vr.get_avg_fps()
except Exception: # failed to read FPS
fps = 24
return frame_data, fps

    def __getitem__(self, index):
        max_retries = 3
        for attempt in range(max_retries):
try:
video_path = self.video_paths[index]
video_name = os.path.basename(video_path).replace(".mp4", "")
data = dict()
ctrl_videos = []
videos = []
t5_embeddings = []
t5_masks = []
                view_indices = list(range(len(self.view_keys)))
                view_indices_conditioning = []
                if self.sample_n_views > 1:
                    # Always keep view 0 (the front view in this example) and randomly sample the rest
                    sampled_idx = np.random.choice(
                        np.arange(1, len(view_indices)),
                        size=min(self.sample_n_views - 1, len(view_indices) - 1),
                        replace=False,
                    )
                    sampled_idx = np.concatenate([[0], sampled_idx])
                    sampled_idx.sort()
                    view_indices = sampled_idx.tolist()
frame_ids = None
fps = None
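                # Sample frame indices once (from the first view), then load the same frames
                # from every other view so all views stay time-aligned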
for view_index in view_indices:
view_key = self.view_keys[view_index]
if frame_ids is None:
frames, frame_ids, fps = self._sample_frames(video_path)
if frames is None: # Invalid video or too short
raise Exception(f"Failed to load frames {video_path}")
else:
frames, fps = self._load_video(
os.path.join(self.dataset_dir, "videos", view_key, os.path.basename(video_path)), frame_ids
)
# Process video frames
                    video = torch.from_numpy(frames).permute(3, 0, 1, 2)  # [T,H,W,C] -> [C,T,H,W]
aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2])) # expects (W, H)
videos.append(video)
                    # Names ending in "_<digit>" (presumably a clip index) share one caption embedding
                    if video_name[-2] == "_" and video_name[-1].isdigit():
                        video_name_emb = video_name[:-2]
                    else:
                        video_name_emb = video_name
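                    # Load the full caption embedding when load_mv_emb is set or for the front view;
                    # otherwise fall back to the cached camera-prompt embedding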
if self.load_mv_emb or view_key == "pinhole_front":
t5_embedding_path = os.path.join(self.dataset_dir, "t5_xxl", view_key, f"{video_name_emb}.pkl")
with open(t5_embedding_path, "rb") as f:
t5_embedding = pickle.load(f)[0]
if self.load_mv_emb:
t5_embedding = np.concatenate([self.prefix_t5_embeddings[view_key], t5_embedding], axis=0)
else:
# use camera prompt
t5_embedding = self.prefix_t5_embeddings[view_key]
t5_embedding = torch.from_numpy(t5_embedding)
                    t5_mask = torch.ones(t5_embedding.shape[0], dtype=torch.int64)
                    # Pad or truncate to the fixed 512-token context (T5-XXL embeddings are 1024-dim)
                    if t5_embedding.shape[0] < 512:
                        t5_embedding = torch.cat([t5_embedding, torch.zeros(512 - t5_embedding.shape[0], 1024)], dim=0)
                        t5_mask = torch.cat([t5_mask, torch.zeros(512 - t5_mask.shape[0], dtype=torch.int64)], dim=0)
                    else:
                        t5_embedding = t5_embedding[:512]
                        t5_mask = t5_mask[:512]
t5_embeddings.append(t5_embedding)
t5_masks.append(t5_mask)
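                    # Record one view-embedding index per frame of this view (used for view conditioning)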
caption_viewid = self.caption_view_idx_map[view_index]
view_indices_conditioning.append(torch.ones(video.shape[1]) * caption_viewid)
if self.ctrl_type:
v_ctrl_data = self._load_control_data(
{
"ctrl_path": os.path.join(
self.dataset_dir,
self.ctrl_data_pth_config["folder"],
view_key,
f"{video_name}.{self.ctrl_data_pth_config['format']}",
)
if self.ctrl_data_pth_config["folder"] is not None
else None,
"frame_ids": frame_ids,
}
)
if v_ctrl_data is None: # Control data loading failed
raise Exception("Failed to load v_ctrl_data")
ctrl_videos.append(v_ctrl_data[self.ctrl_type]["video"])
                # Concatenate views: video and control along the temporal axis, text along the token axis
                video = torch.cat(videos, dim=1)
                ctrl_videos = torch.cat(ctrl_videos, dim=1)
                t5_embedding = torch.cat(t5_embeddings, dim=0)
                view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0)
# Basic data
data["video"] = video
data["video_name"] = video_name
data["aspect_ratio"] = aspect_ratio
data["t5_text_embeddings"] = t5_embedding
data["t5_text_mask"] = torch.cat(t5_masks)
data["view_indices"] = view_indices_conditioning.contiguous()
data["frame_repeat"] = torch.zeros(len(view_indices))
# Add metadata
data["fps"] = fps
data["frame_start"] = frame_ids[0]
data["frame_end"] = frame_ids[-1] + 1
data["num_frames"] = self.sequence_length
data["image_size"] = torch.tensor([704, 1280, 704, 1280])
data["padding_mask"] = torch.zeros(1, 704, 1280)
                data[self.ctrl_type] = {"video": ctrl_videos}
                # The ctrl_videos above are the 'raw' control frames (e.g. loaded lidar video).
                # Next, the augmentor processes them into the control input "video" tensor
                # that the model expects.
for _, aug_fn in self.augmentor.items():
data = aug_fn(data)
return data
            except Exception:
                warnings.warn(
                    f"Invalid data encountered: {self.video_paths[index]}. Skipped; "
                    f"retrying with a randomly sampled video from the same dataset.\n"
                    f"FULL TRACEBACK:\n{traceback.format_exc()}"
                )
                if attempt == max_retries - 1:
                    raise RuntimeError(f"Failed to load data after {max_retries} attempts")
                index = np.random.randint(len(self.video_paths))
        # Defensive: the loop above either returns data or raises on the last attempt
        raise RuntimeError(f"Failed to load a valid sample after {max_retries} attempts")

if __name__ == "__main__":
"""
Sanity check for the dataset.
"""
control_input_key = "control_input_lidar"
visualize_control_input = True
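    # The single-view dataset can be sanity-checked similarly (hypothetical dataset path):
    # dataset = ExampleTransferDataset(
    #     dataset_dir="datasets/example_transfer",  # hypothetical
    #     num_frames=121,
    #     resolution="720",
    #     hint_key="control_input_vis",
    # )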
dataset = AVTransferDataset(
dataset_dir="datasets/waymo_transfer1",
view_keys=["pinhole_front"],
hint_key=control_input_key,
num_frames=121,
resolution="720",
is_train=True,
)
print("finished init dataset")
indices = [0, 12, 100, -1]
for idx in indices:
data = dataset[idx]
print(
(
f"{idx=} "
f"{data['frame_start']=}\n"
f"{data['frame_end']=}\n"
f"{data['video'].sum()=}\n"
f"{data['video'].shape=}\n"
f"{data[control_input_key].shape=}\n" # should match the video shape
f"{data['video_name']=}\n"
f"{data['t5_text_embeddings'].shape=}\n"
"---"
)
)
if visualize_control_input:
import imageio
            control_input_tensor = data[control_input_key].permute(1, 2, 3, 0).cpu().numpy()  # [C,T,H,W] -> [T,H,W,C]
video_name = f"{control_input_key}.mp4"
imageio.mimsave(video_name, control_input_tensor, fps=24)