Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
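"""
Data loading utilities for the distributed format:
- RGB from an mp4 video
- Depth from a float16 array stored in an NPZ archive
- Mask from an array stored in an NPZ archive
- Camera data (w2c extrinsics and intrinsics) from float32 arrays stored in an NPZ archive
"""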
import os | |
import numpy as np | |
import torch | |
import cv2 | |
from pathlib import Path | |
def load_rgb_from_mp4(video_path):
    """
    Load RGB video from mp4 file and convert to tensor.

    Args:
        video_path: str or path-like, path to the mp4 file

    Returns:
        torch.Tensor: float32 RGB tensor of shape [T, C, H, W] with range [-1, 1]

    Raises:
        RuntimeError: if the video file cannot be opened.
        ValueError: if no frames can be decoded from the video.
    """
    # cv2.VideoCapture expects a string path; str() lets callers pass pathlib.Path too.
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video file: {video_path}")
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes frames in BGR channel order; convert to RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the decoder, even if reading or color conversion raises.
        cap.release()

    if not frames:
        raise ValueError(f"No frames found in video: {video_path}")

    frames_np = np.stack(frames, axis=0)  # [T, H, W, C]
    # .float() produces float32; permute to channel-first [T, C, H, W].
    frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).float()
    # Map [0, 255] -> [-1, 1].
    return (frames_tensor / 127.5) - 1.0
def load_depth_from_numpy(depth_path):
    """
    Load depth data from compressed NPZ file.

    Args:
        depth_path: str or path-like, path to an NPZ archive holding a 'depth' array

    Returns:
        torch.Tensor: float32 depth tensor of shape [T, 1, H, W]
            (assuming the stored array is [T, H, W])
    """
    # np.load on an .npz returns an NpzFile that keeps the zip file open;
    # the context manager closes it once the array has been read into memory.
    with np.load(depth_path) as data:
        depth_np = data['depth']  # [T, H, W]
    depth_tensor = torch.from_numpy(depth_np.astype(np.float32))
    # Add channel dimension: [T, H, W] -> [T, 1, H, W]
    return depth_tensor.unsqueeze(1)
def load_mask_from_numpy(mask_path):
    """
    Load mask data from compressed NPZ file.

    Args:
        mask_path: str or path-like, path to an NPZ archive holding a 'mask' array

    Returns:
        torch.Tensor: float32 mask tensor of shape [T, 1, H, W]
            (assuming the stored array is [T, H, W])
    """
    # Close the NpzFile's underlying zip handle once the array is read.
    with np.load(mask_path) as data:
        mask_np = data['mask']  # [T, H, W], presumably stored as bool
    # bool -> float32 yields 0.0 / 1.0 values.
    mask_tensor = torch.from_numpy(mask_np.astype(np.float32))
    # Add channel dimension: [T, H, W] -> [T, 1, H, W]
    return mask_tensor.unsqueeze(1)
def load_camera_from_numpy(data_dir):
    """
    Load camera parameters from compressed NPZ file.

    Args:
        data_dir: str or path-like, directory containing camera.npz

    Returns:
        tuple: (w2c_tensor, intrinsics_tensor), dtypes as stored in the file
            - w2c_tensor: torch.Tensor of shape [T, 4, 4]
            - intrinsics_tensor: torch.Tensor of shape [T, 3, 3]

    Raises:
        FileNotFoundError: if camera.npz does not exist in data_dir.
    """
    camera_path = os.path.join(data_dir, "camera.npz")
    if not os.path.exists(camera_path):
        raise FileNotFoundError(f"camera file not found: {camera_path}")
    # Read both arrays, then let the context manager close the zip handle.
    with np.load(camera_path) as data:
        w2c_np = data['w2c']
        intrinsics_np = data['intrinsics']
    return torch.from_numpy(w2c_np), torch.from_numpy(intrinsics_np)
def load_data_distributed_format(data_dir):
    """
    Load data from distributed format (mp4 + numpy files).

    The directory is expected to contain rgb.mp4, depth.npz, mask.npz and
    camera.npz.

    Args:
        data_dir: str or path-like, directory holding the files

    Returns:
        tuple: (image_tensor, depth_tensor, mask_tensor, w2c_tensor, intrinsics_tensor)
            - image_tensor: float32 [T, C, H, W] in [-1, 1]
            - depth_tensor: float32 [T, 1, H, W]
            - mask_tensor: float32 [T, 1, H, W]
            - w2c_tensor, intrinsics_tensor: as stored in camera.npz

    Raises:
        RuntimeError: if rgb.mp4 cannot be opened.
        ValueError: if rgb.mp4 yields no frames.
        FileNotFoundError: if one of the NPZ files is missing.
    """
    data_path = Path(data_dir)
    # Delegate to the single-purpose loaders so open checks, empty-video
    # handling and file-handle cleanup live in exactly one place each.
    image_tensor = load_rgb_from_mp4(str(data_path / "rgb.mp4"))
    depth_tensor = load_depth_from_numpy(data_path / "depth.npz")
    mask_tensor = load_mask_from_numpy(data_path / "mask.npz")
    w2c_tensor, intrinsics_tensor = load_camera_from_numpy(data_path)
    return image_tensor, depth_tensor, mask_tensor, w2c_tensor, intrinsics_tensor
def load_data_packaged_format(pt_path):
    """
    Load data from the packaged pt format for backward compatibility.

    Args:
        pt_path: str or path-like, path to the pt file

    Returns:
        tuple or list: (image_tensor, depth_tensor, mask_tensor, w2c_tensor, intrinsics_tensor),
            returned exactly as stored in the file

    Raises:
        ValueError: if the file does not hold a tuple/list of exactly 5 items.
    """
    # NOTE(review): torch.load unpickles the file. Depending on the installed
    # torch version, weights_only may default to False, so only load .pt files
    # from trusted sources.
    data = torch.load(pt_path)
    # Reject non-sequence payloads (e.g. a dict or a length-5 tensor) that
    # would otherwise pass the length check and break tuple unpacking downstream.
    if not isinstance(data, (tuple, list)):
        raise ValueError(
            f"Expected a tuple or list of 5 tensors in pt file, got {type(data).__name__}"
        )
    if len(data) != 5:
        raise ValueError(f"Expected 5 tensors in pt file, got {len(data)}")
    return data
def load_data_auto_detect(input_path):
    """
    Pick the loader matching the on-disk layout of input_path and run it.

    A regular file with a '.pt' suffix goes to load_data_packaged_format;
    a directory goes to load_data_distributed_format.

    Raises:
        ValueError: if input_path is neither a .pt file nor a directory.
    """
    path = Path(input_path)
    if path.is_file() and path.suffix == '.pt':
        return load_data_packaged_format(path)
    if path.is_dir():
        return load_data_distributed_format(path)
    raise ValueError(f"Invalid input path: {path}")