roll-ai's picture
Upload 381 files
b6af722 verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
from cosmos_predict1.utils.distributed import rank0_first
from cosmos_predict1.utils.misc import load_from_s3_with_cache
class BaseVAE(torch.nn.Module, ABC):
    """
    Abstract interface shared by the VAE wrappers in this module.

    Concrete subclasses must provide ``encode`` and ``decode`` and, where it is
    meaningful, override ``spatial_compression_factor``. The number of latent
    channels reported by ``latent_ch`` is the ``channel`` value given at
    construction time.
    """

    def __init__(self, channel: int = 3, name: str = "vae"):
        super().__init__()
        # Both values are stored verbatim; ``latent_ch`` simply reports ``channel``.
        self.name = name
        self.channel = channel

    @property
    def latent_ch(self) -> int:
        """
        Number of channels in the latent representation (``self.channel``).
        """
        return self.channel

    @abstractmethod
    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        Map an input tensor to its latent representation.
        Args:
            - state (torch.Tensor): Tensor to encode.
        Returns:
            - torch.Tensor: The latent tensor.
        """
        ...

    @abstractmethod
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Map a latent tensor back to the input space.
        Args:
            - latent (torch.Tensor): Latent tensor to decode.
        Returns:
            - torch.Tensor: The reconstructed tensor.
        """
        ...

    @property
    def spatial_compression_factor(self) -> int:
        """
        Spatial down-sampling factor between the input and the latent.

        Raises:
            NotImplementedError: unless a subclass overrides this property.
        """
        raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.")
class BasePretrainedImageVAE(BaseVAE):
    """
    Base class for pretrained VAEs whose latents are standardized with stored statistics.

    A ``(mean, std)`` pair is read with ``torch.load`` from ``mean_std_fp`` and
    registered as the non-persistent buffers ``latent_mean`` / ``latent_std``.
    ``encode`` standardizes the encoder output with them, and ``decode`` undoes
    that standardization before running the decoder. Subclasses must provide
    ``self.encoder`` and ``self.decoder`` callables.

    Attributes:
        latent_mean (Tensor): Per-channel mean, shaped ``[1, C, 1, 1]`` (image) or ``[1, C, 1, 1, 1]``.
        latent_std (Tensor): Per-channel std, shaped like ``latent_mean``.
        dtype (dtype): ``torch.bfloat16`` when ``is_bf16`` is True, else ``torch.float32``.
    Args:
        name (str): Name of the model.
        mean_std_fp (str): Path of the file holding the ``(mean, std)`` pair of the latent space.
        latent_ch (int, optional): Number of latent channels (default is 16).
        is_image (bool, optional): If True the stat buffers are 4-D, otherwise 5-D (default is True).
        is_bf16 (bool, optional): Run encoder/decoder in bfloat16 instead of float32 (default is True).
    """

    def __init__(
        self,
        name: str,
        mean_std_fp: str,
        latent_ch: int = 16,
        is_image: bool = True,
        is_bf16: bool = True,
    ) -> None:
        # BaseVAE stores both ``channel`` (= latent_ch) and ``name``.
        super().__init__(latent_ch, name)
        self.dtype = torch.bfloat16 if is_bf16 else torch.float32
        self.is_image = is_image
        self.mean_std_fp = mean_std_fp
        self.backend_args = None
        self.register_mean_std(mean_std_fp)

    def register_mean_std(self, mean_std_fp: str) -> None:
        """
        Load the latent ``(mean, std)`` pair onto CUDA and register it as non-persistent buffers.
        """
        mean, std = torch.load(mean_std_fp, map_location="cuda", weights_only=True)
        # Broadcastable shape: [1, C, 1, 1] for images, [1, C, 1, 1, 1] otherwise.
        rank = 4 if self.is_image else 5
        shape = (1, self.latent_ch) + (1,) * (rank - 2)
        for buffer_name, value in (("latent_mean", mean), ("latent_std", std)):
            self.register_buffer(buffer_name, value.to(self.dtype).reshape(*shape), persistent=False)

    @torch.no_grad()
    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        Encode ``state`` and standardize the result with the stored latent stats.

        The encoder runs in ``self.dtype``; the output is cast back to the dtype of
        ``state``. A tuple output is reduced to its first element.

        Raises:
            ValueError: if the encoder returns neither a tensor nor a tuple.
        """
        in_dtype = state.dtype
        mean = self.latent_mean.to(in_dtype)
        std = self.latent_std.to(in_dtype)
        out = self.encoder(state.to(self.dtype))
        if isinstance(out, tuple):
            assert isinstance(out[0], torch.Tensor)
            out = out[0]
        elif not isinstance(out, torch.Tensor):
            raise ValueError("Invalid type of encoded state")
        return (out.to(in_dtype) - mean) / std

    @torch.no_grad()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Undo the latent standardization, then decode in ``self.dtype``.

        The result is cast back to the dtype of ``latent``.
        """
        in_dtype = latent.dtype
        restored = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype)
        decoded = self.decoder(restored.to(self.dtype))
        return decoded.to(in_dtype)

    def reset_dtype(self, *args, **kwargs):
        """
        Cast the decoder and then the encoder back to ``self.dtype``.
        Args:
            *args, **kwargs: Ignored; accepted so the method can be used as a generic callback.
        """
        del args, kwargs
        for module in (self.decoder, self.encoder):
            module.to(self.dtype)
class JITVAE(BasePretrainedImageVAE):
    """
    Pretrained VAE whose encoder and decoder are TorchScript modules.

    Each module is loaded with ``torch.jit.load`` onto CUDA, switched to eval
    mode, frozen (``requires_grad = False`` on every parameter) and cast to
    ``self.dtype``. Latent standardization is inherited from the base class.

    Attributes:
        encoder (Module): TorchScript encoder loaded from ``enc_fp``.
        decoder (Module): TorchScript decoder loaded from ``dec_fp``.
        latent_mean (Tensor): Per-channel latent mean buffer.
        latent_std (Tensor): Per-channel latent std buffer.
        dtype (dtype): ``torch.bfloat16`` when ``is_bf16`` is True, else ``torch.float32``.
    Args:
        enc_fp (str): Path of the encoder's TorchScript file.
        dec_fp (str): Path of the decoder's TorchScript file.
        name (str): Name of the model.
        mean_std_fp (str): Path of the file holding the ``(mean, std)`` pair of the latent space.
        latent_ch (int, optional): Number of latent channels (default is 16).
        is_image (bool, optional): If True the stat buffers are 4-D, otherwise 5-D (default is True).
        is_bf16 (bool, optional): Run encoder/decoder in bfloat16 instead of float32 (default is True).
    """

    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        name: str,
        mean_std_fp: str,
        latent_ch: int = 16,
        is_image: bool = True,
        is_bf16: bool = True,
    ):
        super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16)
        self.load_encoder(enc_fp)
        self.load_decoder(dec_fp)

    def _load_frozen_jit(self, path: str) -> torch.jit.ScriptModule:
        """
        Load a TorchScript module onto CUDA, put it in eval mode, freeze it and cast it to ``self.dtype``.
        """
        module = torch.jit.load(path, map_location="cuda")
        module.eval()
        for weight in module.parameters():
            weight.requires_grad = False
        module.to(self.dtype)
        return module

    def load_encoder(self, enc_fp: str) -> None:
        """
        Load and freeze the TorchScript encoder.
        Args:
            - enc_fp (str): Path of the encoder's TorchScript file.
        """
        self.encoder = self._load_frozen_jit(enc_fp)

    def load_decoder(self, dec_fp: str) -> None:
        """
        Load and freeze the TorchScript decoder.
        Args:
            - dec_fp (str): Path of the decoder's TorchScript file.
        """
        self.decoder = self._load_frozen_jit(dec_fp)
class StateDictVAE(BasePretrainedImageVAE):
    """
    A Variational Autoencoder (VAE) that copies pre-trained weights from TorchScript
    encoder/decoder checkpoints on a remote store into a provided eager ``vae`` module,
    and handles dtype conversion and latent standardization via the base class.
    Attributes:
        vae (Module): The provided module, with loaded weights, frozen and in eval mode.
        encoder (Callable): Bound ``vae.encode`` method.
        decoder (Callable): Bound ``vae.decode`` method.
        latent_mean (Tensor): The mean used for normalizing the latent representation.
        latent_std (Tensor): The standard deviation used for normalizing the latent representation.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.
    Args:
        enc_fp (str): File path to the encoder's JIT file on the remote store.
        dec_fp (str): File path to the decoder's JIT file on the remote store.
        vae (Module): VAE instance whose weights have not been loaded yet.
        name (str): Name of the model, used for differentiating cache file paths.
        mean_std_fp (str): File path to the file containing mean and std of the latent space.
        latent_ch (int, optional): Number of latent channels (default is 16).
        is_image (bool, optional): Flag to indicate whether the output is an image (default is True).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
    """

    # Global variables captured by JIT. They are filtered out of the merged JIT
    # state dict before the strict ``vae.load_state_dict``; presumably the eager
    # ``vae`` does not expose them as state-dict entries.
    _JIT_ONLY_KEYS = frozenset(
        (
            "encoder.patcher.wavelets",
            "encoder.patcher._arange",
            "decoder.unpatcher.wavelets",
            "decoder.unpatcher._arange",
        )
    )

    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        vae: torch.nn.Module,
        name: str,
        mean_std_fp: str,
        latent_ch: int = 16,
        is_image: bool = True,
        is_bf16: bool = True,
    ):
        super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16)
        self.load_encoder_and_decoder(enc_fp, dec_fp, vae)

    def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, vae: torch.nn.Module) -> None:
        """
        Copy the weights of the TorchScript encoder and decoder checkpoints into ``vae``.

        Both checkpoints are fetched with ``load_from_s3_with_cache`` (cache paths
        ``vae/{name}_enc.jit`` / ``vae/{name}_dec.jit``) onto the current CUDA device.
        Their state dicts are merged, the JIT-only entries in ``_JIT_ONLY_KEYS`` are
        dropped, and the result is loaded into ``vae``. ``vae`` is then put in eval
        mode, frozen and cast to ``self.dtype``. ``self.encoder`` / ``self.decoder``
        are bound to ``vae.encode`` / ``vae.decode``.
        Args:
            - enc_fp (str): File path to the encoder's JIT file on the remote store.
            - dec_fp (str): File path to the decoder's JIT file on the remote store.
            - vae (torch.nn.Module): Module that receives the weights and provides ``encode``/``decode``.
        """
        device = torch.device(torch.cuda.current_device())
        # These are TorchScript modules; only their state dicts are used.
        jit_encoder = load_from_s3_with_cache(
            enc_fp,
            f"vae/{self.name}_enc.jit",
            easy_io_kwargs={"map_location": device},
            backend_args=self.backend_args,
        )
        jit_decoder = load_from_s3_with_cache(
            dec_fp,
            f"vae/{self.name}_dec.jit",
            easy_io_kwargs={"map_location": device},
            backend_args=self.backend_args,
        )
        merged_state_dict = jit_encoder.state_dict() | jit_decoder.state_dict()
        weights = {k: v for k, v in merged_state_dict.items() if k not in self._JIT_ONLY_KEYS}
        vae.load_state_dict(weights)
        vae.eval()
        for param in vae.parameters():
            param.requires_grad = False
        vae.to(self.dtype)
        self.vae = vae
        # Bound methods rather than modules, which is why reset_dtype is overridden below.
        self.encoder = self.vae.encode
        self.decoder = self.vae.decode

    def reset_dtype(self, *args, **kwargs):
        """
        Casts the wrapped ``vae`` back to the model's default data type.

        Overrides the base implementation because ``encoder``/``decoder`` are
        bound methods here, not modules with a ``.to`` method.
        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.vae.to(self.dtype)
class SDVAE(BaseVAE):
    """
    Wrapper around the diffusers ``AutoencoderKL`` "stabilityai/sd-vae-ft-mse".

    Latents have 4 channels and are standardized per channel with the fixed
    ``scale`` / ``bias`` buffers. The VAE runs in bfloat16, and results are cast
    back to the caller's dtype.

    Args:
        batch_size (int): Maximum number of samples passed to the VAE per call in
            both ``encode`` and ``decode``.
        count_std (bool): If True, ``encode`` returns ``mean + randn * std`` of the
            posterior; otherwise it returns the posterior mean.
        is_downsample (bool): If True, ``encode`` bilinearly halves H and W first and
            ``decode`` doubles them at the end.
    """

    def __init__(self, batch_size: int = 16, count_std: bool = False, is_downsample: bool = True) -> None:
        super().__init__(channel=4, name="sd_vae")
        self.dtype = torch.bfloat16
        # encode applies z * scale + bias == (z - m) / s with
        # s = [4.17, 4.62, 3.71, 3.28] and m = [5.81, 3.25, 0.12, -2.15]
        # (presumably per-channel latent statistics). Built in bf16, as before.
        self.register_buffer(
            "scale",
            torch.tensor([4.17, 4.62, 3.71, 3.28], dtype=self.dtype).reciprocal().reshape(1, -1, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "bias",
            -1.0 * torch.tensor([5.81, 3.25, 0.12, -2.15], dtype=self.dtype).reshape(1, -1, 1, 1) * self.scale,
            persistent=False,
        )
        self.batch_size = batch_size
        self.count_std = count_std
        self.is_downsample = is_downsample
        self.load_vae()
        self.reset_dtype()

    def reset_dtype(self, *args, **kwargs):
        """Cast the wrapped diffusers VAE to ``self.dtype``; arguments are ignored."""
        del args, kwargs
        self.vae.to(self.dtype)

    @rank0_first
    def load_vae(self) -> None:
        """
        Load the diffusers VAE, trying the local Hugging Face cache first, and freeze it.
        """
        os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        import diffusers

        vae_name = "stabilityai/sd-vae-ft-mse"
        try:
            vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, local_files_only=True)
        except Exception:
            # Could not load the model from cache; try without local_files_only.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
            vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name)
        self.vae = vae.eval().requires_grad_(False)

    @torch.no_grad()
    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        Encode pixels into standardized SD latents.

        state : pixel range [-1, 1]
        """
        if self.is_downsample:
            _h, _w = state.shape[-2:]
            state = F.interpolate(state, size=(_h // 2, _w // 2), mode="bilinear", align_corners=False)
        in_dtype = state.dtype
        state = state.to(self.dtype)
        # Shift pixels from [-1, 1] to [0, 1]; decode applies the inverse map.
        state = (state + 1.0) / 2.0
        # Encode in chunks of ``batch_size`` (as decode already does) to bound peak memory.
        means, stds = [], []
        for chunk in state.split(self.batch_size):
            latent_dist = self.vae.encode(chunk)["latent_dist"]
            means.append(latent_dist.mean)
            if self.count_std:
                stds.append(latent_dist.std)
        mean = torch.cat(means)
        if self.count_std:
            # One noise draw over the full batch, matching the unchunked behavior.
            latent = mean + torch.randn_like(mean) * torch.cat(stds)
        else:
            latent = mean
        latent = latent * self.scale
        latent = latent + self.bias
        return latent.to(in_dtype)

    @torch.no_grad()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Inverse of ``encode``: undo standardization, decode in chunks of ``batch_size``,
        optionally upsample 2x, then apply ``x * 2 - 1`` to mirror encode's ``(x + 1) / 2``.
        """
        in_dtype = latent.dtype
        latent = latent.to(self.dtype)
        latent = latent - self.bias
        latent = latent / self.scale
        latent = torch.cat([self.vae.decode(batch)["sample"] for batch in latent.split(self.batch_size)])
        if self.is_downsample:
            _h, _w = latent.shape[-2:]
            latent = F.interpolate(latent, size=(_h * 2, _w * 2), mode="bilinear", align_corners=False)
        return latent.to(in_dtype) * 2 - 1.0

    @property
    def spatial_compression_factor(self) -> int:
        # NOTE(review): with is_downsample=True, encode halves H/W before the VAE,
        # so the effective input-to-latent factor looks like 16, not 8 — confirm
        # what callers expect before changing this value.
        return 8