|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from typing import Tuple |
|
|
|
import cv2 |
|
import magic |
|
import numpy as np |
|
import torch |
|
|
|
from cosmos_transfer1.utils import log |
|
|
|
|
|
# Accepted video containers: lower-case, dot-prefixed file extension -> the
# MIME type a file with that extension is expected to have.
SUPPORTED_VIDEO_TYPES = dict(
    [
        (".mp4", "video/mp4"),
        (".mkv", "video/x-matroska"),
        (".mov", "video/quicktime"),
        (".avi", "video/x-msvideo"),
        (".webm", "video/webm"),
        (".flv", "video/x-flv"),
        (".wmv", "video/x-ms-wmv"),
    ]
)
|
|
|
|
|
def video_to_tensor(video_path: str, output_path: str, normalize: bool = True) -> Tuple[torch.Tensor, float]:
    """Decode a video file into a tensor and save it as a .pt file.

    Every frame is read with OpenCV, converted from BGR to RGB, stacked and
    permuted from [T,H,W,C] to [C,T,H,W]. The resulting tensor is written to
    ``output_path`` with ``torch.save`` and also returned.

    Args:
        video_path (str): Path to the input video file (opened with ``cv2.VideoCapture``)
        output_path (str): Path to save the output .pt tensor file. A missing parent
            directory is created; a bare file name (no directory part) is accepted.
        normalize (bool): Whether to normalize pixel values to [-1,1] range (default: True).
            When True the tensor is converted to float; when False it keeps the dtype
            of the decoded frames.

    Returns:
        Tuple[torch.Tensor, float]: Tuple containing:
            - Video tensor in shape [C,T,H,W]
            - Video FPS as reported by ``cv2.CAP_PROP_FPS``

    Raises:
        ValueError: If the video cannot be opened or no frame can be read from it.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Failed to open video file: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)

        # Read sequentially from the start in a single pass. Probing the first
        # frame and then rewinding would depend on the backend honouring the
        # seek, which could silently drop the first frame.
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR channel order; store frames as RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture on every path, including the error paths.
        cap.release()

    if not frames:
        raise ValueError(f"Failed to read frames from video: {video_path}")

    log.info(f"frames: {len(frames)}")

    # Stack to [T,H,W,C].
    video_tensor = torch.from_numpy(np.array(frames))
    log.info(f"video_tensor shape: {video_tensor.shape}")

    # [T,H,W,C] -> [C,T,H,W]
    video_tensor = video_tensor.permute(3, 0, 1, 2)

    if normalize:
        # Map 0..255 pixel values onto [-1, 1].
        video_tensor = video_tensor.float() / 127.5 - 1.0

    # os.path.dirname() is "" for a bare file name, and os.makedirs("") raises,
    # so only create the directory when there is one to create.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(video_tensor, output_path)

    return video_tensor, fps
|
|
|
|
|
def is_valid_video(file_path: str) -> bool:
    """Report whether a path points at a supported video whose content matches its extension.

    The check passes only when all of the following hold:
      - ``file_path`` is an existing regular file,
      - its extension (compared case-insensitively) is a key of
        ``SUPPORTED_VIDEO_TYPES``,
      - the MIME type reported by ``magic.from_file`` equals the MIME type
        registered for that extension.

    Args:
        file_path (str): Path of the file to check

    Returns:
        bool: True if every check above passes, False otherwise
    """
    if os.path.isfile(file_path):
        _, suffix = os.path.splitext(file_path)
        mime_for_suffix = SUPPORTED_VIDEO_TYPES.get(suffix.lower())
        if mime_for_suffix:
            # Inspect the file content so a renamed non-video file is rejected.
            return magic.from_file(file_path, mime=True) == mime_for_suffix
    return False
|
|