|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Run this command to interactively debug: |
|
PYTHONPATH=. python cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py |
|
""" |
|
|
|
import os |
|
import pickle |
|
import traceback |
|
import warnings |
|
|
|
import numpy as np |
|
import torch |
|
from decord import VideoReader, cpu |
|
from torch.utils.data import Dataset |
|
|
|
from cosmos_transfer1.diffusion.datasets.augmentor_provider import AUGMENTOR_OPTIONS |
|
from cosmos_transfer1.diffusion.datasets.augmentors.control_input import VIDEO_RES_SIZE_INFO |
|
from cosmos_transfer1.diffusion.inference.inference_utils import detect_aspect_ratio |
|
from cosmos_transfer1.utils.lazy_config import instantiate |
|
|
|
|
|
# Per-modality loading spec for control ("hint") inputs:
#   folder        - subdirectory under the dataset root that stores this modality
#   format        - on-disk file format ("pickle" annotations vs. "mp4" control videos)
#   data_dict_key - key the loaded data is stored under in the sample dict
# Entries with folder=None (edge/vis/upscale) have no on-disk control data and no
# format/data_dict_key; presumably those hints are derived from the input video
# further down the pipeline — confirm against the augmentor implementation.
CTRL_TYPE_INFO = {
    "keypoint": {"folder": "keypoint", "format": "pickle", "data_dict_key": "keypoint"},
    "depth": {"folder": "depth", "format": "mp4", "data_dict_key": "depth"},
    "lidar": {"folder": "lidar", "format": "mp4", "data_dict_key": "lidar"},
    "hdmap": {"folder": "hdmap", "format": "mp4", "data_dict_key": "hdmap"},
    "seg": {"folder": "seg", "format": "pickle", "data_dict_key": "segmentation"},
    "edge": {"folder": None},
    "vis": {"folder": None},
    "upscale": {"folder": None},
}
|
|
|
|
|
class ExampleTransferDataset(Dataset):
    def __init__(self, dataset_dir, num_frames, resolution, hint_key="control_input_vis", is_train=True):
        """Dataset class for loading video-text-to-video generation data with control inputs.

        Args:
            dataset_dir (str): Base path to the dataset directory
            num_frames (int): Number of consecutive frames to load per sequence
            resolution (str): resolution of the target video size
            hint_key (str): The hint key for loading the correct control input data modality
            is_train (bool): Whether this is for training

        NOTE: in our example dataset we do not have a validation dataset. The is_train flag is kept here for customized configuration.
        """
        super().__init__()
        self.dataset_dir = dataset_dir
        self.sequence_length = num_frames
        self.is_train = is_train
        self.resolution = resolution
        assert (
            resolution in VIDEO_RES_SIZE_INFO.keys()
        ), "The provided resolution cannot be found in VIDEO_RES_SIZE_INFO."

        # "control_input_depth" -> "depth"; selects the on-disk layout in CTRL_TYPE_INFO.
        self.ctrl_type = hint_key.replace("control_input_", "")
        self.ctrl_data_pth_config = CTRL_TYPE_INFO[self.ctrl_type]

        # Index all main RGB clips up front; control inputs and captions are resolved per item.
        video_dir = os.path.join(self.dataset_dir, "videos")
        self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")]
        self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl")
        print(f"Finish initializing dataset with {len(self.video_paths)} videos in total.")

        # Instantiate the augmentation pipeline for this control modality and resolution.
        augmentor_name = f"video_ctrlnet_augmentor_{hint_key}"
        augmentor_cfg = AUGMENTOR_OPTIONS[augmentor_name](resolution=resolution)
        self.augmentor = {k: instantiate(v) for k, v in augmentor_cfg.items()}

    def _sample_frames(self, video_path):
        """Sample a random contiguous window of ``sequence_length`` frames from a video.

        Returns:
            (frames, frame_ids, fps): ``frames`` is the decoded uint8 frame array,
            ``frame_ids`` the absolute frame indices, ``fps`` the average frame rate.
            Returns (None, None, None) when the video is shorter than sequence_length.
        """
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        n_frames = len(vr)

        max_start_idx = n_frames - self.sequence_length
        if max_start_idx < 0:
            # Video too short to supply a full window; caller resamples another clip.
            return None, None, None

        start_frame = np.random.randint(0, max_start_idx + 1)
        frame_ids = list(range(start_frame, start_frame + self.sequence_length))

        frames = vr.get_batch(frame_ids).asnumpy().astype(np.uint8)
        try:
            fps = vr.get_avg_fps()
        except Exception:
            fps = 24  # Fall back to a common default when fps metadata is unavailable.

        return frames, frame_ids, fps

    def _load_pickle_ctrl(self, ctrl_path):
        """Load a pickled control annotation (keypoint/seg) verbatim."""
        with open(ctrl_path, "rb") as f:
            return pickle.load(f)

    def _load_video_ctrl(self, ctrl_path, frame_ids):
        """Load control-video frames aligned with ``frame_ids`` (depth/lidar/hdmap)."""
        vr = VideoReader(ctrl_path, ctx=cpu(0))
        # The control video must cover the window sampled from the main video.
        assert (
            len(vr) >= frame_ids[-1] + 1
        ), f"{self.ctrl_type.capitalize()} video {ctrl_path} has fewer frames than main video"

        ctrl_frames = vr.get_batch(frame_ids).asnumpy()
        # Channel-first layout, matching the main video tensor.
        ctrl_frames = torch.from_numpy(ctrl_frames).permute(3, 0, 1, 2)
        return {
            "video": ctrl_frames,
            "frame_start": frame_ids[0],
            "frame_end": frame_ids[-1],
        }

    def _load_control_data(self, sample):
        """Load control data for the video clip.

        Args:
            sample (dict): {"ctrl_path": path or None, "frame_ids": list[int]}

        Returns:
            dict keyed by the modality's ``data_dict_key`` (empty for modalities with
            no on-disk data), or None on failure so the caller can resample.
        """
        data_dict = {}
        frame_ids = sample["frame_ids"]
        ctrl_path = sample["ctrl_path"]
        try:
            if self.ctrl_type in ("seg", "keypoint"):
                data_dict[self.ctrl_data_pth_config["data_dict_key"]] = self._load_pickle_ctrl(ctrl_path)
            elif self.ctrl_type in ("depth", "lidar", "hdmap"):
                data_dict[self.ctrl_data_pth_config["data_dict_key"]] = self._load_video_ctrl(ctrl_path, frame_ids)
            # Other modalities (edge/vis/upscale) load nothing from disk.
        except Exception as e:
            warnings.warn(f"Failed to load control data from {ctrl_path}: {str(e)}")
            return None

        return data_dict

    def _load_t5_embedding(self, video_name):
        """Load the T5-XXL caption embedding for a clip.

        hdmap/lidar (AV-style) datasets nest the embedding inside a metadata ``.pkl``;
        the other modalities store the embedding directly in a ``.pickle`` file.

        Returns:
            (t5_embedding_path, embedding) with the embedding as a torch tensor.
        """
        if self.ctrl_type in ["hdmap", "lidar"]:
            t5_embedding_path = os.path.join(self.t5_dir, f"{video_name}.pkl")
            with open(t5_embedding_path, "rb") as f:
                t5_embedding = pickle.load(f)["pickle"]["ground_truth"]["embeddings"]["t5_xxl"]
        else:
            t5_embedding_path = os.path.join(self.t5_dir, f"{video_name}.pickle")
            with open(t5_embedding_path, "rb") as f:
                t5_embedding = pickle.load(f)

        if isinstance(t5_embedding, list):
            # Some pipelines dump a list of per-caption embeddings; use the first one.
            t5_embedding = np.array(t5_embedding[0] if len(t5_embedding) > 0 else t5_embedding)
        return t5_embedding_path, torch.from_numpy(t5_embedding)

    def __getitem__(self, index):
        """Assemble one training sample; on bad data, resample up to 3 times, then raise."""
        max_retries = 3
        for attempt in range(max_retries):
            try:
                video_path = self.video_paths[index]
                video_name = os.path.basename(video_path).replace(".mp4", "")

                frames, frame_ids, fps = self._sample_frames(video_path)
                if frames is None:
                    # Clip shorter than sequence_length: retry with a different clip.
                    index = np.random.randint(len(self.video_paths))
                    continue

                data = dict()

                # Channel-first video tensor; aspect ratio detected from (W, H).
                video = torch.from_numpy(frames).permute(3, 0, 1, 2)
                aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2]))
                data["video"] = video
                data["aspect_ratio"] = aspect_ratio

                t5_embedding_path, t5_embedding = self._load_t5_embedding(video_name)
                data["video_name"] = {
                    "video_path": video_path,
                    "t5_embedding_path": t5_embedding_path,
                    "start_frame_id": str(frame_ids[0]),
                }
                data["t5_text_embeddings"] = t5_embedding
                data["t5_text_mask"] = torch.ones(512, dtype=torch.int64)

                data["fps"] = fps
                data["frame_start"] = frame_ids[0]
                data["frame_end"] = frame_ids[-1] + 1
                data["num_frames"] = self.sequence_length
                # Hard-coded target canvas used by the downstream pipeline.
                data["image_size"] = torch.tensor([704, 1280, 704, 1280])
                data["padding_mask"] = torch.zeros(1, 704, 1280)

                if self.ctrl_type:
                    folder = self.ctrl_data_pth_config["folder"]
                    ctrl_path = (
                        os.path.join(self.dataset_dir, folder, f"{video_name}.{self.ctrl_data_pth_config['format']}")
                        if folder is not None
                        else None
                    )
                    ctrl_data = self._load_control_data({"ctrl_path": ctrl_path, "frame_ids": frame_ids})
                    if ctrl_data is None:
                        index = np.random.randint(len(self.video_paths))
                        continue
                    data.update(ctrl_data)

                for _, aug_fn in self.augmentor.items():
                    data = aug_fn(data)

                return data

            except Exception:
                warnings.warn(
                    f"Invalid data encountered: {self.video_paths[index]}. Skipped "
                    f"(by randomly sampling another sample in the same dataset)."
                )
                warnings.warn("FULL TRACEBACK:")
                warnings.warn(traceback.format_exc())
                if attempt == max_retries - 1:
                    raise RuntimeError(f"Failed to load data after {max_retries} attempts")
                index = np.random.randint(len(self.video_paths))
        # All retries were consumed by resampling (short clips / bad control data);
        # fail loudly instead of silently returning None into the DataLoader.
        raise RuntimeError(f"Failed to load data after {max_retries} attempts")

    def __len__(self):
        """Number of indexed main videos."""
        return len(self.video_paths)

    def __str__(self):
        return f"{len(self.video_paths)} samples from {self.dataset_dir}"
|
|
|
|
|
class AVTransferDataset(ExampleTransferDataset):
    def __init__(
        self,
        dataset_dir,
        num_frames,
        resolution,
        view_keys,
        hint_key="control_input_hdmap",
        sample_n_views=-1,
        caption_view_idx_map=None,
        is_train=True,
        load_mv_emb=False,
    ):
        """Dataset class for loading multi-view video-text-to-video generation data with control inputs.

        Args:
            dataset_dir (str): Base path to the dataset directory
            num_frames (int): Number of consecutive frames to load per sequence
            resolution (str): resolution of the target video size
            hint_key (str): The hint key for loading the correct control input data modality
            view_keys (list[str]): list of view names that the dataloader should load
            sample_n_views (int): Number of views to sample
            caption_view_idx_map (dict): Optional dictionary mapping index in view_keys to index in model.view_embeddings
            is_train (bool): Whether this is for training
            load_mv_emb (bool): Whether to load t5 embeddings for all views, or only for front view
        NOTE: in our example dataset we do not have a validation dataset. The is_train flag is kept here for customized configuration.
        """
        # Deliberately skip ExampleTransferDataset.__init__ (single-view setup) and
        # initialize torch.utils.data.Dataset directly.
        super(ExampleTransferDataset, self).__init__()
        self.dataset_dir = dataset_dir
        self.sequence_length = num_frames
        self.is_train = is_train
        self.resolution = resolution
        self.view_keys = view_keys
        self.load_mv_emb = load_mv_emb
        assert (
            resolution in VIDEO_RES_SIZE_INFO.keys()
        ), "The provided resolution cannot be found in VIDEO_RES_SIZE_INFO."

        self.ctrl_type = hint_key.replace("control_input_", "")
        self.ctrl_data_pth_config = CTRL_TYPE_INFO[self.ctrl_type]

        # The front pinhole camera defines which clips exist; other views mirror its filenames.
        video_dir = os.path.join(self.dataset_dir, "videos", "pinhole_front")
        self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")]
        self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl")

        # Pre-load the fixed per-view caption prefix embeddings.
        cache_dir = os.path.join(self.dataset_dir, "cache")
        self.prefix_t5_embeddings = {}
        for view_key in view_keys:
            with open(os.path.join(cache_dir, f"prefix_{view_key}.pkl"), "rb") as f:
                self.prefix_t5_embeddings[view_key] = pickle.load(f)
        if caption_view_idx_map is None:
            # Identity mapping: view i conditions with embedding index i.
            self.caption_view_idx_map = {i: i for i in range(len(self.view_keys))}
        else:
            self.caption_view_idx_map = caption_view_idx_map
        self.sample_n_views = sample_n_views

        print(f"Finish initializing dataset with {len(self.video_paths)} videos in total.")

        # Instantiate the augmentation pipeline for this control modality and resolution.
        augmentor_name = f"video_ctrlnet_augmentor_{hint_key}"
        augmentor_cfg = AUGMENTOR_OPTIONS[augmentor_name](resolution=resolution)
        self.augmentor = {k: instantiate(v) for k, v in augmentor_cfg.items()}

    def _load_video(self, video_path, frame_ids):
        """Decode exactly the frames in ``frame_ids`` from ``video_path``.

        Returns (frame_data, fps); all requested ids must be valid indices.
        """
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        assert (np.array(frame_ids) < len(vr)).all()
        assert (np.array(frame_ids) >= 0).all()
        vr.seek(0)
        frame_data = vr.get_batch(frame_ids).asnumpy()
        try:
            fps = vr.get_avg_fps()
        except Exception:
            fps = 24  # Fall back to a common default when fps metadata is unavailable.
        return frame_data, fps

    def __getitem__(self, index):
        """Assemble one multi-view sample; on bad data, resample up to 3 times, then raise."""
        max_retries = 3
        for attempt in range(max_retries):
            try:
                video_path = self.video_paths[index]
                video_name = os.path.basename(video_path).replace(".mp4", "")

                data = dict()
                ctrl_videos = []
                videos = []
                t5_embeddings = []
                t5_masks = []
                view_indices = [i for i in range(len(self.view_keys))]
                view_indices_conditioning = []
                if self.sample_n_views > 1:
                    # Always keep the front view (index 0) and sample the rest without replacement.
                    sampled_idx = np.random.choice(
                        np.arange(1, len(view_indices)),
                        size=min(self.sample_n_views - 1, len(view_indices) - 1),
                        replace=False,
                    )
                    sampled_idx = np.concatenate([[0], sampled_idx])
                    sampled_idx.sort()
                    view_indices = sampled_idx.tolist()

                frame_ids = None
                fps = None
                for view_index in view_indices:
                    view_key = self.view_keys[view_index]
                    if frame_ids is None:
                        # The first view picks the temporal window; later views reuse it.
                        frames, frame_ids, fps = self._sample_frames(video_path)
                        if frames is None:
                            raise Exception(f"Failed to load frames {video_path}")
                    else:
                        frames, fps = self._load_video(
                            os.path.join(self.dataset_dir, "videos", view_key, os.path.basename(video_path)), frame_ids
                        )

                    # Channel-first layout, matching the single-view dataset.
                    video = torch.from_numpy(frames).permute(3, 0, 1, 2)
                    aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2]))
                    videos.append(video)

                    # Strip a trailing "_<digit>" chunk suffix so all chunks of a clip
                    # share one caption embedding file.
                    if video_name[-2] == "_" and video_name[-1].isdigit():
                        video_name_emb = video_name[:-2]
                    else:
                        video_name_emb = video_name

                    if self.load_mv_emb or view_key == "pinhole_front":
                        t5_embedding_path = os.path.join(self.dataset_dir, "t5_xxl", view_key, f"{video_name_emb}.pkl")
                        with open(t5_embedding_path, "rb") as f:
                            t5_embedding = pickle.load(f)[0]
                        if self.load_mv_emb:
                            t5_embedding = np.concatenate([self.prefix_t5_embeddings[view_key], t5_embedding], axis=0)
                    else:
                        # Non-front views only carry the fixed per-view prefix caption.
                        t5_embedding = self.prefix_t5_embeddings[view_key]

                    # Pad or truncate the embedding and its mask to a fixed 512 tokens.
                    t5_embedding = torch.from_numpy(t5_embedding)
                    t5_mask = torch.ones(t5_embedding.shape[0], dtype=torch.int64)
                    if t5_embedding.shape[0] < 512:
                        t5_embedding = torch.cat([t5_embedding, torch.zeros(512 - t5_embedding.shape[0], 1024)], dim=0)
                        # Pad with int64 zeros so the mask dtype matches the truncation path
                        # (plain torch.zeros would promote the concatenated mask to float32).
                        t5_mask = torch.cat([t5_mask, torch.zeros(512 - t5_mask.shape[0], dtype=torch.int64)], dim=0)
                    else:
                        t5_embedding = t5_embedding[:512]
                        t5_mask = t5_mask[:512]
                    t5_embeddings.append(t5_embedding)
                    t5_masks.append(t5_mask)
                    caption_viewid = self.caption_view_idx_map[view_index]
                    # One conditioning id per frame of this view.
                    view_indices_conditioning.append(torch.ones(video.shape[1]) * caption_viewid)

                    if self.ctrl_type:
                        folder = self.ctrl_data_pth_config["folder"]
                        ctrl_path = (
                            os.path.join(
                                self.dataset_dir,
                                folder,
                                view_key,
                                f"{video_name}.{self.ctrl_data_pth_config['format']}",
                            )
                            if folder is not None
                            else None
                        )
                        v_ctrl_data = self._load_control_data({"ctrl_path": ctrl_path, "frame_ids": frame_ids})
                        if v_ctrl_data is None:
                            raise Exception("Failed to load v_ctrl_data")
                        ctrl_videos.append(v_ctrl_data[self.ctrl_type]["video"])

                # Concatenate the per-view tensors back-to-back along the time axis.
                video = torch.cat(videos, dim=1)
                ctrl_videos = torch.cat(ctrl_videos, dim=1)
                t5_embedding = torch.cat(t5_embeddings, dim=0)
                view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0)

                data["video"] = video
                data["video_name"] = video_name
                data["aspect_ratio"] = aspect_ratio
                data["t5_text_embeddings"] = t5_embedding
                data["t5_text_mask"] = torch.cat(t5_masks)
                data["view_indices"] = view_indices_conditioning.contiguous()
                data["frame_repeat"] = torch.zeros(len(view_indices))

                data["fps"] = fps
                data["frame_start"] = frame_ids[0]
                data["frame_end"] = frame_ids[-1] + 1
                data["num_frames"] = self.sequence_length
                # Hard-coded target canvas used by the downstream pipeline.
                data["image_size"] = torch.tensor([704, 1280, 704, 1280])
                data["padding_mask"] = torch.zeros(1, 704, 1280)
                data[self.ctrl_type] = dict()
                data[self.ctrl_type]["video"] = ctrl_videos

                for _, aug_fn in self.augmentor.items():
                    data = aug_fn(data)

                return data

            except Exception:
                warnings.warn(
                    f"Invalid data encountered: {self.video_paths[index]}. Skipped "
                    f"(by randomly sampling another sample in the same dataset)."
                )
                warnings.warn("FULL TRACEBACK:")
                warnings.warn(traceback.format_exc())
                if attempt == max_retries - 1:
                    raise RuntimeError(f"Failed to load data after {max_retries} attempts")
                index = np.random.randint(len(self.video_paths))
        # Unreachable in practice (the last failed attempt raises above), but fail
        # loudly instead of silently returning None into the DataLoader.
        raise RuntimeError(f"Failed to load data after {max_retries} attempts")
|
|
|
|
|
if __name__ == "__main__":
    """
    Sanity check for the dataset.
    """
    # Control modality to exercise; must be a "control_input_<type>" key understood
    # by CTRL_TYPE_INFO / AUGMENTOR_OPTIONS.
    control_input_key = "control_input_lidar"
    # When True, dump the loaded control input to an mp4 next to the script for eyeballing.
    visualize_control_input = True

    # NOTE(review): assumes datasets/waymo_transfer1 exists locally with the expected
    # layout (videos/pinhole_front, t5_xxl/, lidar/, cache/) — confirm before running.
    dataset = AVTransferDataset(
        dataset_dir="datasets/waymo_transfer1",
        view_keys=["pinhole_front"],
        hint_key=control_input_key,
        num_frames=121,
        resolution="720",
        is_train=True,
    )
    print("finished init dataset")
    # Probe a few indices, including the last one, to smoke-test indexing and loading.
    indices = [0, 12, 100, -1]
    for idx in indices:
        data = dataset[idx]
        print(
            (
                f"{idx=} "
                f"{data['frame_start']=}\n"
                f"{data['frame_end']=}\n"
                f"{data['video'].sum()=}\n"
                f"{data['video'].shape=}\n"
                f"{data[control_input_key].shape=}\n"
                f"{data['video_name']=}\n"
                f"{data['t5_text_embeddings'].shape=}\n"
                "---"
            )
        )
        if visualize_control_input:
            import imageio

            # Channel-first (C, T, H, W) -> per-frame (T, H, W, C) layout for imageio.
            control_input_tensor = data[control_input_key].permute(1, 2, 3, 0).cpu().numpy()
            video_name = f"{control_input_key}.mp4"
            imageio.mimsave(video_name, control_input_tensor, fps=24)
|
|