# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from einops import rearrange

from cosmos_predict1.autoregressive.tokenizer.quantizers import FSQuantizer

# Make sure the JIT model outputs consistently during consecutive calls.
# Check here: https://github.com/pytorch/pytorch/issues/74534
torch._C._jit_set_texpr_fuser_enabled(False)


def load_jit_model(jit_filepath: Optional[str] = None, device: str = "cuda") -> torch.jit.ScriptModule:
    """Loads a torch.jit.ScriptModule from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default=cuda.

    Returns:
        The JIT-compiled model loaded onto the device and set to eval mode.
    """
    # Make sure the JIT model outputs consistently during consecutive calls.
    # Check here: https://github.com/pytorch/pytorch/issues/74534
    torch._C._jit_set_texpr_fuser_enabled(False)
    model = torch.jit.load(jit_filepath)
    return model.eval().to(device)
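

# A minimal usage sketch (the checkpoint path and input tensor below are
# hypothetical):
#
#   encoder = load_jit_model("checkpoints/encoder.jit", device="cuda")
#   with torch.no_grad():
#       out = encoder(video_batch)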


class BaseDiscreteVideoFSQTokenizer(torch.nn.Module):
    """
    A base class for Discrete Video FSQ Tokenizers that handles data type conversions and normalization
    using provided mean and standard deviation values for the latent space representation.
    Derived classes should load pre-trained encoder and decoder components into the encoder and decoder attributes.

    Attributes:
        encoder (Module | Callable): Encoder loaded from storage.
        decoder (Module | Callable): Decoder loaded from storage.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        name (str): Name of the model, used for differentiating cache file paths.
        latent_ch (int, optional): Number of latent channels (default is 6).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
        pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
        latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
        max_enc_batch_size (int): The maximum encoder batch size to process in one pass, to avoid memory overflow.
        max_dec_batch_size (int): The maximum decoder batch size to process in one pass, to avoid memory overflow.
        levels (list[int]): The quantization levels used by the FSQ quantizer.
        compression_ratio (list[int]): The compression factor for (T, H, W).
    """

    def __init__(
        self,
        name: str,
        latent_ch: int = 6,
        is_bf16: bool = True,
        pixel_chunk_duration: int = 25,
        latent_chunk_duration: int = 4,
        max_enc_batch_size: int = 8,
        max_dec_batch_size: int = 4,
        levels: list[int] = [8, 8, 8, 5, 5, 5],
        compression_ratio: list[int] = [8, 16, 16],
    ):
        super().__init__()
        self.channel = latent_ch
        self.name = name
        dtype = torch.bfloat16 if is_bf16 else torch.float32
        self.dtype = dtype
        self.pixel_chunk_duration = pixel_chunk_duration
        self.latent_chunk_duration = latent_chunk_duration
        self.max_enc_batch_size = max_enc_batch_size
        self.max_dec_batch_size = max_dec_batch_size
        self.levels = levels
        self.compress_ratio = compression_ratio
        self.fsq_quantizer = FSQuantizer(levels)
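
    # With the default levels [8, 8, 8, 5, 5, 5], the implicit FSQ codebook has
    # 8 * 8 * 8 * 5 * 5 * 5 = 64,000 entries, i.e. each spatiotemporal position
    # in the latent grid maps to one of 64,000 discrete tokens.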

    @property
    def latent_ch(self) -> int:
        """
        Returns the number of latent channels in the tokenizer.
        """
        return self.channel

    def encode(
        self, state: torch.Tensor, pixel_chunk_duration: Optional[int] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        B, C, T, H, W = state.shape
        if pixel_chunk_duration is None:
            # Use the default pixel chunk duration and latent chunk duration.
            pixel_chunk_duration = self.pixel_chunk_duration
            latent_chunk_duration = self.latent_chunk_duration
        else:
            # Update the latent chunk duration based on the given pixel chunk duration.
            latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0]
        assert (
            T % pixel_chunk_duration == 0
        ), f"Temporal dimension {T} is not divisible by chunk_length {pixel_chunk_duration}"
        state = rearrange(state, "b c (n t) h w -> (b n) c t h w", t=pixel_chunk_duration)

        # Split the batch into slices of at most max_enc_batch_size to avoid OOM.
        if state.shape[0] > self.max_enc_batch_size:
            quantized_out_list = []
            indices_list = []
            for i in range(0, state.shape[0], self.max_enc_batch_size):
                indices, quantized_out, _ = self.encoder(state[i : i + self.max_enc_batch_size].to(self.dtype))
                quantized_out_list.append(quantized_out)
                indices_list.append(indices)
            quantized_out = torch.cat(quantized_out_list, dim=0)
            indices = torch.cat(indices_list, dim=0)
        else:
            indices, quantized_out, _ = self.encoder(state.to(self.dtype))
        assert quantized_out.shape[2] == latent_chunk_duration
        return rearrange(quantized_out, "(b n) c t h w -> b c (n t) h w", b=B), rearrange(
            indices, "(b n) t h w -> b (n t) h w", b=B
        )
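
    # Worked example of the chunking arithmetic above, with the defaults
    # compression_ratio=[8, 16, 16] and pixel_chunk_duration=25: the causal
    # temporal compression keeps the first frame, so each 25-frame pixel chunk
    # maps to 1 + (25 - 1) // 8 = 4 latent frames, and a (B, 3, 50, 512, 512)
    # input is split into n=2 chunks per sample, yielding indices of shape
    # (B, 8, 32, 32).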

    def decode(self, indices: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor:
        B, T, _, _ = indices.shape
        if pixel_chunk_duration is None:
            pixel_chunk_duration = self.pixel_chunk_duration
            latent_chunk_duration = self.latent_chunk_duration
        else:
            latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0]
        assert (
            T % latent_chunk_duration == 0
        ), f"Temporal dimension {T} is not divisible by chunk_length {latent_chunk_duration}"
        indices = rearrange(indices, "b (n t) h w -> (b n) t h w", t=latent_chunk_duration)

        # Split the batch into slices of at most max_dec_batch_size to avoid OOM.
        if indices.shape[0] > self.max_dec_batch_size:
            state = []
            for i in range(0, indices.shape[0], self.max_dec_batch_size):
                state.append(self.decoder(indices[i : i + self.max_dec_batch_size]))
            state = torch.cat(state, dim=0)
        else:
            state = self.decoder(indices)
        assert state.shape[2] == pixel_chunk_duration
        return rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B)
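
    # decode() is the inverse of encode(): with the defaults, indices of shape
    # (B, 8, 32, 32) are split into 4-frame latent chunks and decoded back to a
    # video of shape (B, 3, 50, 512, 512), at most max_dec_batch_size chunks
    # per decoder pass (the 3-channel RGB output is an assumption about the
    # decoder checkpoint).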

    def reset_dtype(self, *args, **kwargs):
        """
        Resets the data type of the encoder and decoder to the model's default data type.

        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.decoder.to(self.dtype)
        self.encoder.to(self.dtype)


class DiscreteVideoFSQJITTokenizer(BaseDiscreteVideoFSQTokenizer):
    """
    A JIT-compiled Discrete Video FSQ Tokenizer that loads pre-trained encoder
    and decoder components from a remote store, handles data type conversions, and performs normalization
    using provided mean and standard deviation values for the latent space representation.

    Attributes:
        encoder (Module): The JIT-compiled encoder loaded from storage.
        decoder (Module): The JIT-compiled decoder loaded from storage.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        enc_fp (str): File path to the encoder's JIT file on the remote store.
        dec_fp (str): File path to the decoder's JIT file on the remote store.
        name (str): Name of the model, used for differentiating cache file paths.
        latent_ch (int, optional): Number of latent channels (default is 6).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
        pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
        latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
        max_enc_batch_size (int): The maximum encoder batch size to process in one pass, to avoid memory overflow.
        max_dec_batch_size (int): The maximum decoder batch size to process in one pass, to avoid memory overflow.
        levels (list[int]): The quantization levels used by the FSQ quantizer.
        compression_ratio (list[int]): The compression factor for (T, H, W).
    """

    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        name: str,
        latent_ch: int = 6,
        is_bf16: bool = True,
        pixel_chunk_duration: int = 25,
        latent_chunk_duration: int = 4,
        max_enc_batch_size: int = 8,
        max_dec_batch_size: int = 4,
        levels: list[int] = [8, 8, 8, 5, 5, 5],
        compression_ratio: list[int] = [8, 16, 16],
    ):
        super().__init__(
            name,
            latent_ch,
            is_bf16,
            pixel_chunk_duration,
            latent_chunk_duration,
            max_enc_batch_size,
            max_dec_batch_size,
            levels,
            compression_ratio,
        )
        self.load_encoder(enc_fp)
        self.load_decoder(dec_fp)

    def load_encoder(self, enc_fp: str) -> None:
        """
        Load the encoder from the remote store.

        Args:
        - enc_fp (str): File path to the encoder's JIT file on the remote store.
        """
        self.encoder = load_jit_model(enc_fp, device="cuda")
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.encoder.to(self.dtype)

    def load_decoder(self, dec_fp: str) -> None:
        """
        Load the decoder from the remote store.

        Args:
        - dec_fp (str): File path to the decoder's JIT file on the remote store.
        """
        self.decoder = load_jit_model(dec_fp, device="cuda")
        self.decoder.eval()
        for param in self.decoder.parameters():
            param.requires_grad = False
        self.decoder.to(self.dtype)
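

# A minimal end-to-end sketch for DiscreteVideoFSQJITTokenizer, assuming local
# JIT checkpoints (the paths below are hypothetical):
#
#   tokenizer = DiscreteVideoFSQJITTokenizer(
#       enc_fp="checkpoints/encoder.jit",
#       dec_fp="checkpoints/decoder.jit",
#       name="discrete_video_fsq",
#   )
#   video = torch.randn(1, 3, 25, 512, 512, device="cuda", dtype=torch.bfloat16)
#   quantized, indices = tokenizer.encode(video)  # indices: (1, 4, 32, 32)
#   recon = tokenizer.decode(indices)             # recon: (1, 3, 25, 512, 512)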


class DiscreteVideoFSQStateDictTokenizer(BaseDiscreteVideoFSQTokenizer):
    """
    A Discrete Video FSQ Tokenizer that loads the weights of a pre-trained JIT-compiled encoder into a regular
    nn.Module, so that the encoder can be optimized with torch.compile(), and pairs it with the JIT-compiled
    decoder. It handles data type conversions and normalization using provided mean and standard deviation
    values for the latent space representation.

    Attributes:
        tokenizer_module (Module): Tokenizer module with weights loaded from the JIT checkpoints.
        encoder (Callable): tokenizer_module's encode method.
        decoder (Module): The JIT-compiled decoder loaded from storage.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        enc_fp (str): File path to the encoder's JIT file on the remote store.
        dec_fp (str): File path to the decoder's JIT file on the remote store.
        tokenizer_module (Module): Tokenizer module that will have its weights loaded.
        name (str): Name of the model, used for differentiating cache file paths.
        latent_ch (int, optional): Number of latent channels (default is 6).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
        pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
        latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
        max_enc_batch_size (int): The maximum encoder batch size to process in one pass, to avoid memory overflow.
        max_dec_batch_size (int): The maximum decoder batch size to process in one pass, to avoid memory overflow.
        levels (list[int]): The quantization levels used by the FSQ quantizer.
        compression_ratio (list[int]): The compression factor for (T, H, W).
    """

    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        tokenizer_module: torch.nn.Module,
        name: str,
        latent_ch: int = 6,
        is_bf16: bool = True,
        pixel_chunk_duration: int = 25,
        latent_chunk_duration: int = 4,
        max_enc_batch_size: int = 8,
        max_dec_batch_size: int = 4,
        levels: list[int] = [8, 8, 8, 5, 5, 5],
        compression_ratio: list[int] = [8, 16, 16],
    ):
        super().__init__(
            name,
            latent_ch,
            is_bf16,
            pixel_chunk_duration,
            latent_chunk_duration,
            max_enc_batch_size,
            max_dec_batch_size,
            levels,
            compression_ratio,
        )
        self.load_encoder_and_decoder(enc_fp, dec_fp, tokenizer_module)

    def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, tokenizer_module: torch.nn.Module) -> None:
        """
        Load the encoder and decoder from the remote store.

        Args:
        - enc_fp (str): File path to the encoder's JIT file on the remote store.
        - dec_fp (str): File path to the decoder's JIT file on the remote store.
        - tokenizer_module (Module): Tokenizer module that was used to create the JIT checkpoints.
        """
        self.decoder = load_jit_model(dec_fp)
        self.decoder.eval()
        for param in self.decoder.parameters():
            param.requires_grad = False
        self.decoder.to(self.dtype)

        encoder_sd = load_jit_model(enc_fp).state_dict()

        # Only the encoder weights are loaded into tokenizer_module, so drop its
        # decoder-side components.
        del tokenizer_module.post_quant_conv
        del tokenizer_module.decoder

        state_dict = {
            k: v
            for k, v in encoder_sd.items()
            # Variables captured by JIT
            if k
            not in (
                "encoder.patcher3d.wavelets",
                "encoder.patcher3d._arange",
                "encoder.patcher3d.patch_size_buffer",
                "quantizer._levels",
                "quantizer._basis",
                "quantizer.implicit_codebook",
            )
        }
        tokenizer_module.load_state_dict(state_dict)

        tokenizer_module.eval()
        for param in tokenizer_module.parameters():
            param.requires_grad = False
        tokenizer_module.to(self.dtype)

        self.tokenizer_module = tokenizer_module
        self.encoder = self.tokenizer_module.encode

    def reset_dtype(self, *args, **kwargs):
        """
        Resets the data type of the decoder and the tokenizer module (which holds the encoder)
        to the model's default data type.

        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.decoder.to(self.dtype)
        self.tokenizer_module.to(self.dtype)
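

# A minimal sketch for DiscreteVideoFSQStateDictTokenizer, assuming a
# tokenizer_module whose encoder/quantizer submodules match the JIT checkpoint
# (the paths and the tokenizer_module below are hypothetical):
#
#   tokenizer = DiscreteVideoFSQStateDictTokenizer(
#       enc_fp="checkpoints/encoder.jit",
#       dec_fp="checkpoints/decoder.jit",
#       tokenizer_module=tokenizer_module,
#       name="discrete_video_fsq",
#   )
#   # Because the encoder here is a regular nn.Module method (not a
#   # ScriptModule), it can be optimized with torch.compile():
#   tokenizer.encoder = torch.compile(tokenizer.tokenizer_module.encode)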