roll-ai's picture
Upload 381 files
b6af722 verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data loading utilities for the distributed format:
- RGB from mp4
- Depth from float16 numpy
- Camera data from float32 numpy
"""
import os
import numpy as np
import torch
import cv2
from pathlib import Path
def load_rgb_from_mp4(video_path):
    """
    Load RGB video from mp4 file and convert to tensor.

    Args:
        video_path: str or os.PathLike, path to the mp4 file

    Returns:
        torch.Tensor: RGB tensor of shape [T, C, H, W] with range [-1, 1]

    Raises:
        RuntimeError: if OpenCV cannot open the file.
        ValueError: if no frames could be decoded from the file.
    """
    # cv2.VideoCapture takes a string filename; normalise Path objects so
    # callers may pass either (other argument types are left untouched).
    if isinstance(video_path, os.PathLike):
        video_path = os.fspath(video_path)

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video file: {video_path}")
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes frames in BGR channel order; convert to RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the decoder even if reading or colour conversion raises.
        cap.release()

    if not frames:
        raise ValueError(f"No frames found in video: {video_path}")

    # Convert to numpy array and then tensor
    frames_np = np.stack(frames, axis=0)  # [T, H, W, C]
    frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float()  # [T, C, H, W]
    # Convert from [0, 255] to [-1, 1]
    frames_tensor = (frames_tensor / 127.5) - 1.0
    return frames_tensor
def load_depth_from_numpy(depth_path):
    """
    Load depth data from compressed NPZ file.

    Args:
        depth_path: str or os.PathLike, path to an NPZ archive that holds a
            ``depth`` array (expected layout [T, H, W])

    Returns:
        torch.Tensor: float32 depth tensor of shape [T, 1, H, W]
    """
    # np.load on an .npz returns an NpzFile that keeps the zip file open;
    # the context manager closes it as soon as the array has been read.
    with np.load(depth_path) as data:
        depth_np = data['depth']  # [T, H, W]
    depth_tensor = torch.from_numpy(depth_np.astype(np.float32))
    # Add channel dimension: [T, H, W] -> [T, 1, H, W]
    return depth_tensor.unsqueeze(1)
def load_mask_from_numpy(mask_path):
    """
    Load mask data from compressed NPZ file.

    Args:
        mask_path: str or os.PathLike, path to an NPZ archive that holds a
            ``mask`` array (expected boolean, layout [T, H, W])

    Returns:
        torch.Tensor: float32 mask tensor of shape [T, 1, H, W]
            (True -> 1.0, False -> 0.0)
    """
    # Close the NpzFile's underlying zip handle as soon as the array is read.
    with np.load(mask_path) as data:
        mask_np = data['mask']  # [T, H, W] as bool
    mask_tensor = torch.from_numpy(mask_np.astype(np.float32))  # Convert bool to float32
    # Add channel dimension: [T, H, W] -> [T, 1, H, W]
    return mask_tensor.unsqueeze(1)
def load_camera_from_numpy(data_dir):
    """
    Load camera parameters from compressed NPZ file.

    Args:
        data_dir: str or os.PathLike, directory containing camera.npz

    Returns:
        tuple: (w2c_tensor, intrinsics_tensor)
            - w2c_tensor: torch.Tensor of shape [T, 4, 4]
            - intrinsics_tensor: torch.Tensor of shape [T, 3, 3]
        The dtypes are those stored in the archive (no conversion is done).

    Raises:
        FileNotFoundError: if ``data_dir/camera.npz`` does not exist.
    """
    camera_path = os.path.join(data_dir, "camera.npz")
    if not os.path.exists(camera_path):
        raise FileNotFoundError(f"camera file not found: {camera_path}")
    # Read both arrays inside the context so the NpzFile is closed afterwards.
    with np.load(camera_path) as data:
        w2c_np = data['w2c']
        intrinsics_np = data['intrinsics']
    return torch.from_numpy(w2c_np), torch.from_numpy(intrinsics_np)
def load_data_distributed_format(data_dir):
    """
    Load one sample stored in the distributed format (mp4 + NPZ files).

    ``data_dir`` must contain ``rgb.mp4``, ``depth.npz``, ``mask.npz`` and
    ``camera.npz``. Each file is read by this module's single-file loader
    (load_rgb_from_mp4, load_depth_from_numpy, load_mask_from_numpy,
    load_camera_from_numpy), so validation and resource cleanup live in
    one place instead of being duplicated here.

    Args:
        data_dir: str or os.PathLike, directory holding the sample files

    Returns:
        tuple: (image_tensor, depth_tensor, mask_tensor, w2c_tensor, intrinsics_tensor)
            - image_tensor: float tensor [T, C, H, W] in [-1, 1]
            - depth_tensor, mask_tensor: float32 tensors [T, 1, H, W]
            - w2c_tensor, intrinsics_tensor: as stored in camera.npz

    Raises:
        RuntimeError: if rgb.mp4 cannot be opened.
        ValueError: if rgb.mp4 contains no decodable frames.
        FileNotFoundError: if an NPZ file is missing.
    """
    data_path = Path(data_dir)
    image_tensor = load_rgb_from_mp4(str(data_path / "rgb.mp4"))
    depth_tensor = load_depth_from_numpy(data_path / "depth.npz")
    mask_tensor = load_mask_from_numpy(data_path / "mask.npz")
    w2c_tensor, intrinsics_tensor = load_camera_from_numpy(data_path)
    return image_tensor, depth_tensor, mask_tensor, w2c_tensor, intrinsics_tensor
def load_data_packaged_format(pt_path, *, map_location=None):
    """
    Load data from the packaged pt format for backward compatibility.

    Args:
        pt_path: str or os.PathLike, path to the pt file
        map_location: optional, forwarded to ``torch.load`` (for example
            ``"cpu"`` to load GPU-saved tensors on a CPU-only machine).
            The default None keeps torch.load's default behaviour.

    Returns:
        tuple: (image_tensor, depth_tensor, mask_tensor, w2c_tensor, intrinsics_tensor)
            returned as the list/tuple stored in the file.

    Raises:
        ValueError: if the file does not hold a list or tuple of exactly 5 items.
    """
    # NOTE(review): depending on the torch version / weights_only default,
    # torch.load may unpickle arbitrary objects -- only load trusted files.
    data = torch.load(pt_path, map_location=map_location)
    # A dict or a single tensor also supports len(), but unpacking it would
    # yield keys or rows instead of the five expected tensors.
    if not isinstance(data, (list, tuple)):
        raise ValueError(
            f"Expected a list or tuple of 5 tensors in pt file, got {type(data).__name__}"
        )
    if len(data) != 5:
        raise ValueError(f"Expected 5 tensors in pt file, got {len(data)}")
    return data
def load_data_auto_detect(input_path):
    """
    Load a sample, choosing the loader from what ``input_path`` points at.

    - a directory      -> load_data_distributed_format
    - a ``.pt`` file   -> load_data_packaged_format

    Raises:
        ValueError: for anything else, including paths that do not exist.
    """
    path = Path(input_path)
    if path.is_dir():
        return load_data_distributed_format(path)
    if path.suffix == '.pt' and path.is_file():
        return load_data_packaged_format(path)
    raise ValueError(f"Invalid input path: {path}")