"""A library for Causal Video Tokenizer inference.""" |

from typing import Any, Optional

import numpy as np
import torch
from tqdm import tqdm

from cosmos_transfer1.auxiliary.tokenizer.inference.utils import (
    load_decoder_model,
    load_encoder_model,
    load_model,
    numpy2tensor,
    pad_video_batch,
    tensor2numpy,
    unpad_video_batch,
)


class CausalVideoTokenizer(torch.nn.Module):
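    """Convenience wrapper for Causal Video Tokenizer inference.

    Depending on which checkpoints are given, wraps a full autoencoder
    (`checkpoint`), an encoder-only model (`checkpoint_enc`), and/or a
    decoder-only model (`checkpoint_dec`), each loaded onto `device` and
    cast to `dtype`.
    """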

    def __init__(
        self,
        checkpoint: Optional[str] = None,
        checkpoint_enc: Optional[str] = None,
        checkpoint_dec: Optional[str] = None,
        tokenizer_config: Optional[dict[str, Any]] = None,
        device: str = "cuda",
        dtype: str = "bfloat16",
    ) -> None:
        super().__init__()
        self._device = device
        self._dtype = getattr(torch, dtype)
        self._full_model = (
            load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None
        )
        self._enc_model = (
            load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
            if checkpoint_enc is not None
            else None
        )
        self._dec_model = (
            load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
            if checkpoint_dec is not None
            else None
        )

    @torch.no_grad()
    def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Reconstructs a batch of video tensors after embedding into a latent.

        Args:
            input_tensor: The input video, Bx3xTxHxW layout, range [-1..1].
        Returns:
            The reconstructed video, layout Bx3xTxHxW, range [-1..1].
        """
        if self._full_model is not None:
            output_tensor = self._full_model(input_tensor)
            output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor
        else:
            output_latent = self.encode(input_tensor)[0]
            output_tensor = self.decode(output_latent)
        return output_tensor
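
    # A minimal sketch of `autoencode` usage, assuming a full autoencoder
    # checkpoint; the path below is a hypothetical placeholder:
    #
    #   tokenizer = CausalVideoTokenizer(checkpoint="checkpoints/autoencoder.jit")
    #   video = torch.randn(1, 3, 17, 256, 256, dtype=torch.bfloat16, device="cuda")
    #   recon = tokenizer.autoencode(video)  # Bx3xTxHxW, range [-1..1]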

    @torch.no_grad()
    def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
        """Encodes a video tensor into a CausalVideo latent or discrete code.

        Args:
            input_tensor: The input tensor, Bx3xTxHxW layout, range [-1..1].
        Returns:
            For the causal continuous video (CV) tokenizer, the tuple contains:
                1) The latent embedding, Bx16x(t)x(h)x(w), where the compression
                   rate is (T/t x H/h x W/w), with a channel dimension of 16.
            For the causal discrete video (DV) tokenizer, the tuple contains:
                1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which
                   is formed by FSQ levels of (8,8,8,5,5,5).
                2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate
                   is again (T/t x H/h x W/w), with a channel dimension of 6.
        """
        assert input_tensor.ndim == 5, "input video should be 5D."

        output_latent = self._enc_model(input_tensor)
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        return output_latent[:-1]
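
    # A minimal sketch of `encode` usage for the continuous (CV) tokenizer,
    # assuming an encoder-only checkpoint (hypothetical path):
    #
    #   tokenizer = CausalVideoTokenizer(checkpoint_enc="checkpoints/encoder.jit")
    #   video = torch.randn(1, 3, 17, 256, 256, dtype=torch.bfloat16, device="cuda")
    #   latent = tokenizer.encode(video)[0]  # Bx16x(t)x(h)x(w)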

    @torch.no_grad()
    def decode(self, input_latent: torch.Tensor) -> torch.Tensor:
        """Decodes a CausalVideo latent or code back into a video tensor.

        Args:
            input_latent: The continuous latent Bx16xtxhxw for CV,
                or the discrete indices Bxtxhxw for DV.
        Returns:
            The reconstructed tensor, layout [B,3,1+(T-1)*8,H*16,W*16] in range [-1..1].
        """
        assert input_latent.ndim >= 4, "input latent should be 5D for continuous and 4D for discrete tokenizers."
        return self._dec_model(input_latent)
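
    # A minimal sketch of `decode` usage, assuming a decoder-only checkpoint
    # (hypothetical path) and a `latent` produced by `encode` above:
    #
    #   tokenizer = CausalVideoTokenizer(checkpoint_dec="checkpoints/decoder.jit")
    #   video = tokenizer.decode(latent)  # Bx3xTxHxW, range [-1..1]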

    def forward(
        self,
        video: np.ndarray,
        temporal_window: int = 17,
    ) -> np.ndarray:
        """Reconstructs a video using a pre-trained CausalTokenizer autoencoder.

        Given a video of arbitrary length, invokes the CausalVideoTokenizer
        in a sliding manner over windows of `temporal_window` frames.

        Args:
            video: The input video, BxTxHxWx3 layout, range [0..255].
            temporal_window: The length of the temporal window to process, default=17.
        Returns:
            The reconstructed video in range [0..255], layout BxTxHxWx3.
        """
        assert video.ndim == 5, "input video should be 5D."
        num_frames = video.shape[1]
        output_video_list = []
        for idx in tqdm(range(0, (num_frames - 1) // temporal_window + 1)):
            # Slice out the next window of frames.
            start, end = idx * temporal_window, (idx + 1) * temporal_window
            input_video = video[:, start:end, ...]

            # Pad to a valid spatio-temporal size, autoencode, then crop back.
            padded_input_video, crop_region = pad_video_batch(input_video)
            input_tensor = numpy2tensor(padded_input_video, dtype=self._dtype, device=self._device)
            output_tensor = self.autoencode(input_tensor)
            padded_output_video = tensor2numpy(output_tensor)
            output_video = unpad_video_batch(padded_output_video, crop_region)

            output_video_list.append(output_video)
        return np.concatenate(output_video_list, axis=1)
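

if __name__ == "__main__":
    # A minimal end-to-end sketch, assuming a full autoencoder checkpoint;
    # the checkpoint path below is a hypothetical placeholder.
    tokenizer = CausalVideoTokenizer(checkpoint="checkpoints/autoencoder.jit")
    # A random 1-batch, 33-frame RGB video, BxTxHxWx3 layout, range [0..255].
    input_video = np.random.randint(0, 256, size=(1, 33, 128, 128, 3), dtype=np.uint8)
    output_video = tokenizer(input_video, temporal_window=17)
    assert output_video.shape == input_video.shape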
|
|