roll-ai's picture
Upload 381 files
b6af722 verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A library for Causal Video Tokenizer inference."""
from typing import Any
import numpy as np
import torch
from tqdm import tqdm
from cosmos_predict1.tokenizer.inference.utils import (
load_decoder_model,
load_encoder_model,
load_model,
numpy2tensor,
pad_video_batch,
tensor2numpy,
unpad_video_batch,
)
class CausalVideoTokenizer(torch.nn.Module):
    """Inference wrapper around a causal video tokenizer.

    The tokenizer can be loaded either as one full autoencoder checkpoint
    (``checkpoint``) or as separate encoder / decoder checkpoints
    (``checkpoint_enc`` / ``checkpoint_dec``). Any checkpoint left as ``None``
    is not loaded, and the corresponding model attribute stays ``None``.
    """

    def __init__(
        self,
        checkpoint: str = None,
        checkpoint_enc: str = None,
        checkpoint_dec: str = None,
        tokenizer_config: dict[str, Any] = None,
        device: str = "cuda",
        dtype: str = "bfloat16",
    ) -> None:
        """Loads the requested checkpoints and casts them to ``dtype``.

        Args:
            checkpoint: Path to a full autoencoder checkpoint, or None.
            checkpoint_enc: Path to an encoder-only checkpoint, or None.
            checkpoint_dec: Path to a decoder-only checkpoint, or None.
            tokenizer_config: Optional config forwarded to the model loaders.
            device: Device passed to the loaders and used for input tensors
                in `forward`.
            dtype: Name of a ``torch`` dtype attribute (e.g. "bfloat16").
                Loaded models are cast to it and `forward` converts inputs to it.
        """
        super().__init__()
        self._device = device
        # Resolve the dtype name (e.g. "bfloat16") to the torch dtype object.
        self._dtype = getattr(torch, dtype)
        self._full_model = (
            load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None
        )
        self._enc_model = (
            load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
            if checkpoint_enc is not None
            else None
        )
        self._dec_model = (
            load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
            if checkpoint_dec is not None
            else None
        )

    @torch.no_grad()
    def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Reconstructs a batch of video tensors after embedding into a latent.

        Uses the full autoencoder when one was loaded. Otherwise it runs the
        separately loaded encoder followed by the decoder.

        Args:
            input_tensor: The input video, Bx3xTxHxW layout, range [-1..1].

        Returns:
            The reconstructed video, layout Bx3xTxHxW, range [-1..1].
        """
        if self._full_model is not None:
            output_tensor = self._full_model(input_tensor)
            # The full model may return a tuple; the reconstruction comes first.
            return output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor
        encoded = self.encode(input_tensor)
        # `encode` returns a bare tensor when the encoder does, otherwise a
        # tuple whose first entry is what the decoder consumes (latent for CV,
        # indices for DV). Indexing a bare tensor with [0] would silently pick
        # the first batch element instead of the whole batch.
        latent = encoded if isinstance(encoded, torch.Tensor) else encoded[0]
        return self.decode(latent)

    @torch.no_grad()
    def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Encodes a video tensor into a CausalVideo latent or code.

        Args:
            input_tensor: The input tensor Bx3xTxHxW layout, range [-1..1].

        Returns:
            If the encoder returns a single tensor, that tensor is returned
            as-is. Otherwise the encoder's output tuple is returned without
            its last element.
            For causal continuous video (CV) tokenizer, the tuple contains:
                - The latent embedding, Bx16x(t)x(h)x(w), where the compression
                  rate is (T/t x H/h x W/w), and channel dimension of 16.
            For causal discrete video (DV) tokenizer, the tuple contains:
                1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which
                   is formed by FSQ levels of (8,8,8,5,5,5).
                2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate
                   is again (T/t x H/h x W/w), and channel dimension of 6.
        """
        assert input_tensor.ndim == 5, "input video should be of 5D."
        output_latent = self._enc_model(input_tensor)
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        # Drop the trailing element of the encoder output (presumably an
        # auxiliary term not needed at inference time -- TODO confirm).
        return output_latent[:-1]

    @torch.no_grad()
    def decode(self, input_latent: torch.Tensor) -> torch.Tensor:
        """Decodes a CausalVideo latent (or discrete indices) back into video.

        Args:
            input_latent: The continuous latent Bx16xtxhxw for CV,
                or the discrete indices Bxtxhxw for DV.

        Returns:
            The reconstructed tensor in range [-1..1], layout Bx3xTxHxW. For a
            tokenizer with 8x temporal and 16x spatial compression this is
            [B,3,1+(t-1)*8,h*16,w*16].
        """
        # Continuous latents are 5D and discrete indices are 4D; reject others.
        assert input_latent.ndim in (4, 5), "input latent should be of 5D for continuous and 4D for discrete."
        return self._dec_model(input_latent)

    def forward(
        self,
        video: np.ndarray,
        temporal_window: int = 17,
    ) -> np.ndarray:
        """Reconstructs video using a pre-trained CausalTokenizer autoencoder.

        Given a video of arbitrary length, the forward invokes the
        CausalVideoTokenizer over consecutive, non-overlapping chunks of
        `temporal_window` frames (the last chunk may be shorter). Each chunk is
        padded, autoencoded, un-padded, and the results are concatenated along
        the time axis.

        Args:
            video: The input video BxTxHxWx3 layout, range [0..255].
            temporal_window: The length of the temporal window to process,
                default=17.

        Returns:
            The reconstructed video in range [0..255], layout BxTxHxWx3.

        Raises:
            ValueError: If `temporal_window` is not positive or the video has
                no frames.
        """
        assert video.ndim == 5, "input video should be of 5D."
        if temporal_window <= 0:
            raise ValueError(f"temporal_window must be positive, got {temporal_window}.")
        num_frames = video.shape[1]  # can be of any length.
        if num_frames == 0:
            raise ValueError("input video must contain at least one frame.")
        # ceil(num_frames / temporal_window) windows.
        num_windows = (num_frames - 1) // temporal_window + 1
        output_video_list = []
        for idx in tqdm(range(num_windows)):
            # Input video for the current window.
            start, end = idx * temporal_window, (idx + 1) * temporal_window
            input_video = video[:, start:end, ...]
            # Spatio-temporally pad input_video so it's evenly divisible.
            padded_input_video, crop_region = pad_video_batch(input_video)
            input_tensor = numpy2tensor(padded_input_video, dtype=self._dtype, device=self._device)
            output_tensor = self.autoencode(input_tensor)
            padded_output_video = tensor2numpy(output_tensor)
            output_video = unpad_video_batch(padded_output_video, crop_region)
            output_video_list.append(output_video)
        return np.concatenate(output_video_list, axis=1)