# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import math
import os
from pathlib import Path
from typing import List

import numpy as np
import torch
import torchvision
from PIL import Image

from cosmos_predict1.autoregressive.configs.inference.inference_config import SamplingConfig
from cosmos_predict1.utils import log

_IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]
_VIDEO_EXTENSIONS = [".mp4"]
_SUPPORTED_CONTEXT_LEN = [1, 9]  # Supported numbers of input (context) frames
NUM_TOTAL_FRAMES = 33
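
# NUM_TOTAL_FRAMES is the fixed clip length the discrete video tokenizer expects; inputs
# with fewer context frames (1 or 9) are padded up to this length before tokenization
# (see load_vision_input below).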


def add_common_arguments(parser):
    """Add common command line arguments.

    Args:
        parser (ArgumentParser): Argument parser to add arguments to
    """
    parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
    )
    parser.add_argument(
        "--video_save_name",
        type=str,
        default="output",
        help="Output filename for generating a single video",
    )
    parser.add_argument("--video_save_folder", type=str, default="outputs/", help="Output folder for saving videos")
    parser.add_argument(
        "--input_image_or_video_path",
        type=str,
        help="Path to the input image or video",
    )
    parser.add_argument(
        "--batch_input_path",
        type=str,
        help="Path to a JSONL file listing the input images or videos",
    )
    parser.add_argument(
        "--num_input_frames",
        type=int,
        default=9,
        help="Number of input frames for world generation",
        choices=_SUPPORTED_CONTEXT_LEN,
    )
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
    parser.add_argument("--top_p", type=float, default=0.8, help="Top-p value for sampling")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs used to run inference in parallel.")
    parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder")
    parser.add_argument(
        "--offload_guardrail_models",
        action="store_true",
        help="Offload guardrail models after inference",
    )
    parser.add_argument(
        "--offload_diffusion_decoder",
        action="store_true",
        help="Offload diffusion decoder after inference",
    )
    parser.add_argument(
        "--offload_ar_model",
        action="store_true",
        help="Offload AR model after inference",
    )
    parser.add_argument(
        "--offload_tokenizer",
        action="store_true",
        help="Offload discrete tokenizer model after inference",
    )
    parser.add_argument(
        "--disable_guardrail",
        action="store_true",
        help="Disable guardrail models",
    )


def validate_args(args: argparse.Namespace, inference_type: str):
    """Validate command line arguments for base and video2world generation."""
    assert inference_type in [
        "base",
        "video2world",
    ], "Invalid inference_type, must be 'base' or 'video2world'"
    if args.input_type in ["image", "text_and_image"] and args.num_input_frames != 1:
        args.num_input_frames = 1
        log.info(f"Set num_input_frames to 1 for {args.input_type} input")
    if args.num_input_frames == 1:
        if "4B" in args.ar_model_dir:
            log.warning(
                "The failure rate for the 4B model with image input is ~15%. The 12B / 13B models have a smaller failure rate. Please be cautious and refer to README.md for more details."
            )
        elif "5B" in args.ar_model_dir:
            log.warning(
                "The failure rate for the 5B model with image input is ~7%. The 12B / 13B models have a smaller failure rate. Please be cautious and refer to README.md for more details."
            )
    # Validate prompt/image/video args for single or batch generation
    assert (
        args.input_image_or_video_path or args.batch_input_path
    ), "--input_image_or_video_path or --batch_input_path must be provided."
    if inference_type == "video2world" and (not args.batch_input_path):
        assert args.prompt, "--prompt is required for single video generation."
    args.data_resolution = [640, 1024]
    # Create output folder
    Path(args.video_save_folder).mkdir(parents=True, exist_ok=True)
    sampling_config = SamplingConfig(
        echo=True,
        temperature=args.temperature,
        top_p=args.top_p,
        compile_sampling=True,
    )
    return sampling_config
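
# A minimal usage sketch (not executed here): an inference script would typically build a
# parser, register the shared flags above, add its own script-specific flags such as
# --input_type / --prompt / --ar_model_dir (attribute names taken from validate_args;
# the default values below are illustrative), then call validate_args:
#
#     parser = argparse.ArgumentParser()
#     add_common_arguments(parser)
#     parser.add_argument("--input_type", type=str, default="video")
#     parser.add_argument("--prompt", type=str, default="")
#     parser.add_argument("--ar_model_dir", type=str, default="")
#     args = parser.parse_args()
#     sampling_config = validate_args(args, "base")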


def resize_input(video: torch.Tensor, resolution: list[int]):
    r"""
    Function to perform aspect ratio preserving resizing and center cropping.
    This is needed to make the video into target resolution.

    Args:
        video (torch.Tensor): Input video tensor
        resolution (list[int]): Data resolution

    Returns:
        Cropped video
    """
    orig_h, orig_w = video.shape[2], video.shape[3]
    target_h, target_w = resolution
    scaling_ratio = max((target_w / orig_w), (target_h / orig_h))
    resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w)))
    video_resized = torchvision.transforms.functional.resize(video, resizing_shape)
    video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution)
    return video_cropped
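
# Worked example (illustrative numbers): for a 720x1280 input and resolution=[640, 1024],
# scaling_ratio = max(1024/1280, 640/720) ≈ 0.889, so the clip is first resized to
# (640, 1138) and then center-cropped to (640, 1024), preserving the aspect ratio while
# fully covering the target frame.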


def load_image_from_list(flist, data_resolution: List[int]) -> dict:
    """
    Function to load images from a list of image paths.

    Args:
        flist (List[str]): List of image paths
        data_resolution (List[int]): Data resolution

    Returns:
        Dict containing input images
    """
    all_videos = dict()
    for img_path in flist:
        ext = os.path.splitext(img_path)[1]
        if ext in _IMAGE_EXTENSIONS:
            # Read the image
            img = Image.open(img_path)
            # Convert to tensor
            img = torchvision.transforms.functional.to_tensor(img)
            static_vid = img.unsqueeze(0).repeat(NUM_TOTAL_FRAMES, 1, 1, 1)
            static_vid = static_vid * 2 - 1
            log.debug(
                f"Resizing input image of shape ({static_vid.shape[2]}, {static_vid.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})"
            )
            static_vid = resize_input(static_vid, data_resolution)
            fname = os.path.basename(img_path)
            all_videos[fname] = static_vid.transpose(0, 1).unsqueeze(0)
    return all_videos
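
# Note: each dictionary entry produced above is a tensor of shape (1, C, T, H, W) with
# values in [-1, 1]; the single input image is simply repeated NUM_TOTAL_FRAMES times
# along the temporal axis so it matches the clip length the tokenizer expects.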


def read_input_images(batch_input_path: str, data_resolution: List[int]) -> dict:
    """
    Function to read input images from a JSONL file.

    Args:
        batch_input_path (str): Path to JSONL file containing visual input paths
        data_resolution (List[int]): Data resolution

    Returns:
        Dict containing input images
    """
    # Read visual inputs from JSONL
    flist = []
    with open(batch_input_path, "r") as f:
        for line in f:
            data = json.loads(line.strip())
            flist.append(data["visual_input"])
    return load_image_from_list(flist, data_resolution=data_resolution)
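
# The JSONL file is expected to contain one JSON object per line with a "visual_input"
# field pointing at an image (or a video, for the video readers below), e.g. (paths are
# illustrative):
#
#     {"visual_input": "assets/example_input_0.png"}
#     {"visual_input": "assets/example_input_1.jpg"}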


def read_input_image(input_path: str, data_resolution: List[int]) -> dict:
    """
    Function to read input image.

    Args:
        input_path (str): Path to input image
        data_resolution (List[int]): Data resolution

    Returns:
        Dict containing input image
    """
    flist = [input_path]
    return load_image_from_list(flist, data_resolution=data_resolution)


def read_input_videos(batch_input_path: str, data_resolution: List[int], num_input_frames: int) -> dict:
    r"""
    Function to read input videos.

    Args:
        batch_input_path (str): Path to JSONL file containing visual input paths
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context

    Returns:
        Dict containing input videos
    """
    # Read visual inputs from JSONL
    flist = []
    with open(batch_input_path, "r") as f:
        for line in f:
            data = json.loads(line.strip())
            flist.append(data["visual_input"])
    return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames)


def read_input_video(input_path: str, data_resolution: List[int], num_input_frames: int) -> dict:
    """
    Function to read input video.

    Args:
        input_path (str): Path to input video
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context

    Returns:
        Dict containing input video
    """
    flist = [input_path]
    return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames)


def load_videos_from_list(flist: List[str], data_resolution: List[int], num_input_frames: int) -> dict:
    """
    Function to load videos from a list of video paths.

    Args:
        flist (List[str]): List of video paths
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context

    Returns:
        Dict containing input videos
    """
    all_videos = dict()
    for video_path in flist:
        ext = os.path.splitext(video_path)[-1]
        if ext in _VIDEO_EXTENSIONS:
            video, _, _ = torchvision.io.read_video(video_path, pts_unit="sec")
            video = video.float() / 255.0
            video = video * 2 - 1
            # Keep only the last num_input_frames frames as context
            nframes_in_video = video.shape[0]
            if nframes_in_video < num_input_frames:
                fname = os.path.basename(video_path)
                log.warning(
                    f"Video {fname} has {nframes_in_video} frames, less than the required {num_input_frames} frames. Skipping."
                )
                continue
            video = video[-num_input_frames:, :, :, :]
            # Pad the video to NUM_TOTAL_FRAMES (because the tokenizer expects inputs of NUM_TOTAL_FRAMES)
            video = torch.cat(
                (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_TOTAL_FRAMES - num_input_frames, 1, 1, 1)),
                dim=0,
            )
            video = video.permute(0, 3, 1, 2)
            # Resize the video to the required dimension
            log.debug(
                f"Resizing input video of shape ({video.shape[2]}, {video.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})"
            )
            video = resize_input(video, data_resolution)
            fname = os.path.basename(video_path)
            all_videos[fname] = video.transpose(0, 1).unsqueeze(0)
    return all_videos
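
# As with the image loader above, each returned entry has shape
# (1, C, NUM_TOTAL_FRAMES, H, W) with values in [-1, 1], keyed by the video file name;
# frames beyond the requested context are copies of the last context frame.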


def load_vision_input(
    input_type: str,
    batch_input_path: str,
    input_image_or_video_path: str,
    data_resolution: List[int],
    num_input_frames: int,
):
    """
    Function to load vision input.

    Note: We pad the frames of the input image/video to NUM_TOTAL_FRAMES here, and feed the
    padded video tensors to the video tokenizer to obtain tokens. The tokens will be truncated
    based on num_input_frames when feeding to the autoregressive model.

    Args:
        input_type (str): Type of input
        batch_input_path (str): Path to JSONL file containing visual input paths
        input_image_or_video_path (str): Path to input image or video
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context

    Returns:
        Dict containing input videos
    """
    if batch_input_path:
        log.info(f"Reading batch inputs from path: {batch_input_path}")
        if input_type == "image" or input_type == "text_and_image":
            input_videos = read_input_images(batch_input_path, data_resolution=data_resolution)
        elif input_type == "video" or input_type == "text_and_video":
            input_videos = read_input_videos(
                batch_input_path,
                data_resolution=data_resolution,
                num_input_frames=num_input_frames,
            )
        else:
            raise ValueError(f"Invalid input type {input_type}")
    else:
        if input_type == "image" or input_type == "text_and_image":
            input_videos = read_input_image(input_image_or_video_path, data_resolution=data_resolution)
        elif input_type == "video" or input_type == "text_and_video":
            input_videos = read_input_video(
                input_image_or_video_path,
                data_resolution=data_resolution,
                num_input_frames=num_input_frames,
            )
        else:
            raise ValueError(f"Invalid input type {input_type}")
    return input_videos
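
# A minimal usage sketch (illustrative path and settings): loading a single video as
# context for world generation, using the default resolution set in validate_args:
#
#     input_videos = load_vision_input(
#         input_type="video",
#         batch_input_path=None,
#         input_image_or_video_path="assets/example_input.mp4",
#         data_resolution=[640, 1024],
#         num_input_frames=9,
#     )
#     # input_videos maps "example_input.mp4" -> tensor of shape (1, 3, 33, 640, 1024)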


def prepare_video_batch_for_saving(video_batch: List[torch.Tensor]) -> List[np.ndarray]:
    """
    Function to convert output tensors to numpy format for saving.

    Args:
        video_batch (List[torch.Tensor]): List of output tensors

    Returns:
        List of numpy arrays
    """
    return [(video * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy() for video in video_batch]
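
# A minimal sketch (not executed here) of saving the converted frames; the tensor name,
# file name, and frame rate are illustrative. torchvision's writer expects uint8 frames
# in (T, H, W, C) layout, which is exactly what prepare_video_batch_for_saving returns:
#
#     frames = prepare_video_batch_for_saving([generated_video])[0]
#     torchvision.io.write_video("outputs/output.mp4", torch.from_numpy(frames), fps=24)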