# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A library for Causal Video Tokenizer inference."""

from typing import Any, Optional, Union

import numpy as np
import torch
from tqdm import tqdm

from cosmos_predict1.tokenizer.inference.utils import (
    load_decoder_model,
    load_encoder_model,
    load_model,
    numpy2tensor,
    pad_video_batch,
    tensor2numpy,
    unpad_video_batch,
)


class CausalVideoTokenizer(torch.nn.Module):
    """Inference wrapper around a causal video tokenizer.

    Up to three models can be loaded, each one only when its checkpoint is given:
    a full autoencoder (`checkpoint`), a stand-alone encoder (`checkpoint_enc`)
    and a stand-alone decoder (`checkpoint_dec`). Models that were not loaded
    are stored as `None`.
    """

    def __init__(
        self,
        checkpoint: Optional[str] = None,
        checkpoint_enc: Optional[str] = None,
        checkpoint_dec: Optional[str] = None,
        tokenizer_config: Optional[dict[str, Any]] = None,
        device: str = "cuda",
        dtype: str = "bfloat16",
    ) -> None:
        """Loads the requested models and casts them to `dtype`.

        Args:
            checkpoint: Checkpoint of the full autoencoder, or None to skip it.
            checkpoint_enc: Checkpoint of the encoder, or None to skip it.
            checkpoint_dec: Checkpoint of the decoder, or None to skip it.
            tokenizer_config: Configuration forwarded to the model loaders.
            device: Device passed to the loaders and used for input tensors.
            dtype: Name of a `torch` dtype attribute, e.g. "bfloat16".
        """
        super().__init__()
        self._device = device
        # Resolve the dtype name (e.g. "bfloat16") to the torch dtype object.
        self._dtype = getattr(torch, dtype)
        self._full_model = (
            load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None
        )
        self._enc_model = (
            load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
            if checkpoint_enc is not None
            else None
        )
        self._dec_model = (
            load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
            if checkpoint_dec is not None
            else None
        )

    @torch.no_grad()
    def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Reconstructs a batch of video tensors after embedding into a latent.

        The full model is used when it was loaded; otherwise the input goes
        through `encode` followed by `decode`.

        Args:
            input_tensor: The input video Bx3xTxHxW layout, range [-1..1].
        Returns:
            The reconstructed video, layout Bx3xTxHxW, range [-1..1].
        """
        if self._full_model is not None:
            output_tensor = self._full_model(input_tensor)
            # Some models return a tuple; the reconstruction is its first entry.
            return output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor

        encoded = self.encode(input_tensor)
        # encode() passes a bare tensor through unchanged. Indexing that with
        # [0] would select batch element 0 rather than the first tuple entry,
        # so only unpack when a tuple actually came back.
        output_latent = encoded if isinstance(encoded, torch.Tensor) else encoded[0]
        return self.decode(output_latent)

    @torch.no_grad()
    def encode(self, input_tensor: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encodes a video tensor into a CausalVideo latent or code.

        Args:
            input_tensor: The input tensor Bx3xTxHxW layout, range [-1..1].
        Returns:
            If the encoder returns a single tensor, that tensor is returned
            as-is. Otherwise all encoder outputs except the last one are
            returned as a tuple.

            For causal continuous video (CV) tokenizer, the tuple contains:
                - The latent embedding, Bx16x(t)x(h)x(w), where the compression
                rate is (T/t x H/h x W/w), and channel dimension of 16.
            For causal discrete video (DV) tokenizer, the tuple contains:
              1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which
                is formed by FSQ levels of (8,8,8,5,5,5).
              2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate
                is again (T/t x H/h x W/w), and channel dimension of 6.
        """
        assert input_tensor.ndim == 5, "input video should be of 5D."

        output_latent = self._enc_model(input_tensor)
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        # Drop the trailing encoder output; only the leading entries are returned.
        return output_latent[:-1]

    @torch.no_grad()
    def decode(self, input_latent: torch.Tensor) -> torch.Tensor:
        """Decodes a CausalVideo latent or code back into a video tensor.

        Args:
            input_latent: The continuous latent Bx16xtxhxw for CV,
                        or the discrete indices Bxtxhxw for DV.
        Returns:
            The reconstructed tensor, layout [B,3,1+(T-1)*8,H*16,W*16] in range [-1..1].
        """
        assert input_latent.ndim >= 4, "input latent should be of 5D for continuous and 4D for discrete."
        return self._dec_model(input_latent)

    def forward(
        self,
        video: np.ndarray,
        temporal_window: int = 17,
    ) -> np.ndarray:
        """Reconstructs video using a pre-trained CausalTokenizer autoencoder.

        Given a video of arbitrary length, the forward invokes the CausalVideoTokenizer
        in a sliding manner with a `temporal_window` size.

        Args:
            video: The input video BxTxHxWx3 layout, range [0..255].
            temporal_window: The length of the temporal window to process, default=17.
        Returns:
            The reconstructed video in range [0..255], layout BxTxHxWx3.
        Raises:
            ValueError: If `temporal_window` is smaller than 1 or the video has
                no frames.
        """
        assert video.ndim == 5, "input video should be of 5D."
        if temporal_window < 1:
            raise ValueError(f"temporal_window must be a positive integer, got {temporal_window}.")

        num_frames = video.shape[1]  # can be of any length.
        if num_frames == 0:
            raise ValueError("input video has no frames.")

        output_video_list = []
        # Walk over the start frame of each window; the last window may be shorter.
        for start in tqdm(range(0, num_frames, temporal_window)):
            # Input video for the current window.
            input_video = video[:, start : start + temporal_window, ...]

            # Spatio-temporally pad input_video so it's evenly divisible.
            padded_input_video, crop_region = pad_video_batch(input_video)
            input_tensor = numpy2tensor(padded_input_video, dtype=self._dtype, device=self._device)
            output_tensor = self.autoencode(input_tensor)
            padded_output_video = tensor2numpy(output_tensor)
            # Undo the padding so each window matches its input extent.
            output_video = unpad_video_batch(padded_output_video, crop_region)

            output_video_list.append(output_video)
        # Stitch the windows back together along the time axis.
        return np.concatenate(output_video_list, axis=1)