# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Optional, Tuple

import torch
from einops import rearrange
from megatron.core import parallel_state
from torch import Tensor, distributed
from torch.distributed import broadcast_object_list, get_process_group_ranks
from torch.distributed.utils import _verify_param_shape_across_processes

from cosmos_transfer1.diffusion.conditioner import BaseVideoCondition, CosmosCondition
from cosmos_transfer1.diffusion.diffusion.functional.batch_ops import batch_mul
from cosmos_transfer1.diffusion.diffusion.modules.denoiser_scaling import EDMScaling
from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import Sampler
from cosmos_transfer1.diffusion.diffusion.types import DenoisePrediction
from cosmos_transfer1.diffusion.module import parallel
from cosmos_transfer1.diffusion.module.blocks import FourierFeatures
from cosmos_transfer1.diffusion.module.pretrained_vae import BaseVAE
from cosmos_transfer1.diffusion.networks.general_dit import GeneralDIT
from cosmos_transfer1.utils import log, misc
from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate

# from cosmos_transfer1.diffusion.training.models.model import _broadcast

IS_PREPROCESSED_KEY = "is_preprocessed"
class DataType(Enum):
IMAGE = "image"
VIDEO = "video"
MIX = "mix"
class EDMSDE:
    """Container for the EDM SDE noise-level range [sigma_min, sigma_max]."""

    def __init__(
        self,
        sigma_max: float,
        sigma_min: float,
    ):
self.sigma_max = sigma_max
self.sigma_min = sigma_min
class DiffusionT2WModel(torch.nn.Module):
"""Text-to-world diffusion model that generates video frames from text descriptions.
This model implements a diffusion-based approach for generating videos conditioned on text input.
It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling,
and classifier-free guidance.
"""
def __init__(self, config):
"""Initialize the diffusion model.
Args:
config: Configuration object containing model parameters and architecture settings
"""
super().__init__()
self.config = config
self.precision = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}[config.precision]
self.tensor_kwargs = {"device": "cuda", "dtype": self.precision}
log.debug(f"DiffusionModel: precision {self.precision}")
# 1. set data keys and data information
self.sigma_data = config.sigma_data
self.state_shape = list(config.latent_shape)
self.setup_data_key()
        # 2. set up diffusion processing, scaling (preconditioning), and the sampler
self.sde = EDMSDE(sigma_max=80, sigma_min=0.0002)
self.sampler = Sampler()
self.scaling = EDMScaling(self.sigma_data)
self.tokenizer = None
self.model = None
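    # A minimal sketch of the config fields this class reads (names inferred from the
    # attribute accesses in __init__, set_up_tokenizer, and build_model; the concrete
    # config class lives in the lazy-config system, so treat this as illustrative only):
    #
    #   config.precision        -> "bfloat16" | "float16" | "float32"
    #   config.sigma_data       -> float, the EDM data standard deviation
    #   config.latent_shape     -> list[int], the latent state shape
    #   config.input_data_key   -> str, key of video tensors in a data batch
    #   config.input_image_key  -> str, key of image tensors in a data batch
    #   config.net              -> lazy config for the DiT network
    #   config.conditioner      -> lazy config for the conditioner
    #   config.tokenizer        -> lazy config for the VAE tokenizer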
@property
def net(self):
return self.model.net
@property
def conditioner(self):
return self.model.conditioner
@property
def logvar(self):
return self.model.logvar
    def set_up_tokenizer(self, tokenizer_dir: str):
        """Instantiate the tokenizer (VAE) and load its weights from tokenizer_dir."""
self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer)
self.tokenizer.load_weights(tokenizer_dir)
if hasattr(self.tokenizer, "reset_dtype"):
self.tokenizer.reset_dtype()
@misc.timer("DiffusionModel: set_up_model")
def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format):
"""Initialize the core model components including network, conditioner and logvar."""
self.model = self.build_model()
self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs)
def build_model(self) -> torch.nn.ModuleDict:
"""Construct the model's neural network components.
Returns:
ModuleDict containing the network, conditioner and logvar components
"""
config = self.config
net = lazy_instantiate(config.net)
conditioner = lazy_instantiate(config.conditioner)
logvar = torch.nn.Sequential(
FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False)
)
return torch.nn.ModuleDict(
{
"net": net,
"conditioner": conditioner,
"logvar": logvar,
}
)
@torch.no_grad()
def encode(self, state: torch.Tensor) -> torch.Tensor:
"""Encode input state into latent representation using VAE.
Args:
state: Input tensor to encode
Returns:
Encoded latent representation scaled by sigma_data
"""
return self.tokenizer.encode(state) * self.sigma_data
@torch.no_grad()
def decode(self, latent: torch.Tensor) -> torch.Tensor:
"""Decode latent representation back to pixel space using VAE.
Args:
latent: Latent tensor to decode
Returns:
Decoded tensor in pixel space
"""
return self.tokenizer.decode(latent / self.sigma_data)
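    # Illustrative round trip (assumes the tokenizer has been set up and `frames` is a
    # pixel-space tensor the VAE accepts; not part of the public API):
    #
    #   latent = model.encode(frames)  # VAE latent, scaled by sigma_data
    #   recon = model.decode(latent)   # back to pixel space
    #
    # Scaling by sigma_data keeps the latent at the standard deviation the EDM
    # preconditioner expects.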
def setup_data_key(self) -> None:
"""Configure input data keys for video and image data."""
        self.input_data_key = self.config.input_data_key  # by default, the video key for video diffusion models
self.input_image_key = self.config.input_image_key
    def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction:
        """
        Denoise the noisy input given the noise level and conditioning information.

        Args:
            xt (torch.Tensor): The noisy input data.
            sigma (torch.Tensor): The noise level.
            condition (CosmosCondition): Conditioning information generated by self.conditioner.

        Returns:
            DenoisePrediction: The denoised prediction, including the clean-data prediction (x0_pred),
                the noise prediction (eps_pred), and an optional log-variance (logvar).
        """
xt = xt.to(**self.tensor_kwargs)
sigma = sigma.to(**self.tensor_kwargs)
# get precondition for the network
c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma)
# forward pass through the network
net_output = self.net(
x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf
timesteps=c_noise, # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf
**condition.to_dict(),
)
logvar = self.model.logvar(c_noise)
x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output)
# get noise prediction based on sde
eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma)
return DenoisePrediction(x0_pred, eps_pred, logvar)
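    # For reference, the standard EDM preconditioning (Karras et al., 2022, Eq. 7) that
    # EDMScaling is expected to implement (illustrative; the authoritative definitions
    # live in denoiser_scaling.EDMScaling):
    #
    #   c_skip  = sigma_data^2 / (sigma^2 + sigma_data^2)
    #   c_out   = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
    #   c_in    = 1 / sqrt(sigma^2 + sigma_data^2)
    #   c_noise = ln(sigma) / 4
    #
    # eps_pred follows from the forward model xt = x0 + sigma * eps, which rearranges to
    # eps = (xt - x0) / sigma, matching the batch_mul expression above.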
def robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor:
    """
    Perform a broadcast that works regardless of the tensor shapes on different ranks.

    Args:
        tensor (torch.Tensor): The tensor to broadcast (on the src rank) or receive (on other ranks).
        src (int): The source rank for the broadcast.
        pg: The process group to broadcast within.
        is_check_shape (bool): If True, verify that the shape tensor matches across processes.

    Returns:
        torch.Tensor: The broadcast tensor on all ranks.
    """
    # First, broadcast the shape of the tensor so receivers can allocate a matching buffer.
    if distributed.get_rank() == src:
        shape = torch.tensor(tensor.shape).cuda()
    else:
        shape = torch.empty(tensor.dim(), dtype=torch.long).cuda()
    if is_check_shape:
        _verify_param_shape_across_processes(pg, [shape])
    torch.distributed.broadcast(shape, src, group=pg)

    # Resize the tensor on non-src ranks if necessary.
    if distributed.get_rank() != src:
        tensor = tensor.new_empty(shape.tolist()).type_as(tensor)

    # Now broadcast the tensor data.
    torch.distributed.broadcast(tensor, src, group=pg)
    return tensor
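# Illustrative usage (assumes torch.distributed is initialized and `group` is a valid
# process group; non-src tensors only need the same ndim/dtype, not the same sizes):
#
#   if distributed.get_rank() == 0:
#       payload = torch.randn(3, 7, device="cuda")
#   else:
#       payload = torch.empty(1, 1, device="cuda")  # resized to (3, 7) inside the call
#   payload = robust_broadcast(payload, src=0, pg=group)
#   # all ranks now hold the (3, 7) tensor from rank 0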
def _broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None:
"""
Broadcast the item from the minimum rank in the specified group(s).
Since global rank = tp_rank + cp_rank * tp_size + ...
First broadcast in the tp_group and then in the cp_group will
ensure that the item is broadcasted across ranks in cp_group and tp_group.
Parameters:
- item: The item to broadcast (can be a torch.Tensor, str, or None).
- to_tp: Whether to broadcast to the tensor model parallel group.
- to_cp: Whether to broadcast to the context parallel group.
"""
if not parallel_state.is_initialized():
return item
tp_group = parallel_state.get_tensor_model_parallel_group()
cp_group = parallel_state.get_context_parallel_group()
to_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1
to_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1
if to_tp:
min_tp_rank = min(get_process_group_ranks(tp_group))
if to_cp:
min_cp_rank = min(get_process_group_ranks(cp_group))
if isinstance(item, torch.Tensor): # assume the device is cuda
if to_tp:
item = robust_broadcast(item, min_tp_rank, tp_group)
if to_cp:
item = robust_broadcast(item, min_cp_rank, cp_group)
elif item is not None:
broadcastable_list = [item]
if to_tp:
broadcast_object_list(broadcastable_list, min_tp_rank, group=tp_group)
if to_cp:
broadcast_object_list(broadcastable_list, min_cp_rank, group=cp_group)
item = broadcastable_list[0]
return item
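# Illustrative usage (assumes megatron parallel_state is initialized; tensors go through
# robust_broadcast, while strings and other picklable objects go through
# broadcast_object_list, as shown in the branches above):
#
#   prompt = _broadcast("a robot arm stacking blocks", to_tp=True, to_cp=False)
#   latent = _broadcast(latent, to_tp=True, to_cp=True)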
class DistillT2WModel(DiffusionT2WModel):
"""Base Video Distillation Model."""
def __init__(self, config):
super().__init__(config)
def get_data_and_condition(self, data_batch: dict[str, Tensor]) -> Tuple[Tensor, CosmosCondition]:
self._normalize_video_databatch_inplace(data_batch)
self._augment_image_dim_inplace(data_batch)
input_key = self.input_data_key # by default it is video key
is_image_batch = self.is_image_batch(data_batch)
is_video_batch = not is_image_batch
# Broadcast data and condition across TP and CP groups.
        # Sort the keys so every rank iterates the batch in the same order; a mismatched
        # order makes the NCCL collectives below hang.
        local_keys = sorted(list(data_batch.keys()))
for key in local_keys:
data_batch[key] = _broadcast(data_batch[key], to_tp=True, to_cp=is_video_batch)
if is_image_batch:
input_key = self.input_image_key
# Latent state
raw_state = data_batch[input_key]
latent_state = self.encode(raw_state).contiguous()
# Condition
condition = self.conditioner(data_batch)
if is_image_batch:
condition.data_type = DataType.IMAGE
else:
condition.data_type = DataType.VIDEO
        # The VAE has randomness, so ranks within a CP/TP group must share the same encoded output.
        latent_state = _broadcast(latent_state, to_tp=True, to_cp=is_video_batch)
        condition = broadcast_condition(condition, to_tp=True, to_cp=is_video_batch)
        return latent_state, condition
def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool:
"""We hanlde two types of data_batch. One comes from a joint_dataloader where "dataset_name" can be used to differenciate image_batch and video_batch.
Another comes from a dataloader which we by default assumes as video_data for video model training.
"""
is_image = self.input_image_key in data_batch
is_video = self.input_data_key in data_batch
assert (
is_image != is_video
), "Only one of the input_image_key or input_data_key should be present in the data_batch."
return is_image
    def _normalize_video_databatch_inplace(self, data_batch: dict[str, Tensor], input_key: Optional[str] = None) -> None:
"""
Normalizes video data in-place on a CUDA device to reduce data loading overhead.
This function modifies the video data tensor within the provided data_batch dictionary
in-place, scaling the uint8 data from the range [0, 255] to the normalized range [-1, 1].
Warning:
A warning is issued if the data has not been previously normalized.
Args:
data_batch (dict[str, Tensor]): A dictionary containing the video data under a specific key.
This tensor is expected to be on a CUDA device and have dtype of torch.uint8.
Side Effects:
Modifies the 'input_data_key' tensor within the 'data_batch' dictionary in-place.
Note:
This operation is performed directly on the CUDA device to avoid the overhead associated
with moving data to/from the GPU. Ensure that the tensor is already on the appropriate device
and has the correct dtype (torch.uint8) to avoid unexpected behaviors.
"""
input_key = self.input_data_key if input_key is None else input_key
# only handle video batch
if input_key in data_batch:
# Check if the data has already been normalized and avoid re-normalizing
if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True:
assert torch.is_floating_point(data_batch[input_key]), "Video data is not in float format."
            assert torch.all(
                (data_batch[input_key] >= -1.0001) & (data_batch[input_key] <= 1.0001)
            ), f"Video data is not in the range [-1, 1]: got range [{data_batch[input_key].min()}, {data_batch[input_key].max()}]"
else:
assert data_batch[input_key].dtype == torch.uint8, "Video data is not in uint8 format."
data_batch[input_key] = data_batch[input_key].to(**self.tensor_kwargs) / 127.5 - 1.0
data_batch[IS_PREPROCESSED_KEY] = True
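                # Quick sanity check of the uint8 -> [-1, 1] mapping:
                #   0   / 127.5 - 1.0 = -1.0
                #   255 / 127.5 - 1.0 =  1.0
                # The -1.0001/1.0001 tolerance above allows for float rounding after the cast.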
    def _augment_image_dim_inplace(self, data_batch: dict[str, Tensor], input_key: Optional[str] = None) -> None:
        """Insert a singleton temporal dimension into image batches in-place: (B, C, H, W) -> (B, C, 1, H, W)."""
input_key = self.input_image_key if input_key is None else input_key
if input_key in data_batch:
# Check if the data has already been augmented and avoid re-augmenting
if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True:
                assert (
                    data_batch[input_key].shape[2] == 1
                ), f"Image data is claimed to be augmented while its shape is {data_batch[input_key].shape}"
return
else:
data_batch[input_key] = rearrange(data_batch[input_key], "b c h w -> b c 1 h w").contiguous()
data_batch[IS_PREPROCESSED_KEY] = True
def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition:
    """Broadcast every tensor field of a condition object across the TP/CP groups, then rebuild it."""
    condition_kwargs = {}
    for k, v in condition.to_dict().items():
        if isinstance(v, torch.Tensor):
            assert not v.requires_grad, f"{k} requires gradient; the current implementation does not support it"
            condition_kwargs[k] = parallel.broadcast(v, to_tp=to_tp, to_cp=to_cp)
        else:
            # Keep non-tensor fields (e.g. data_type) so they survive the reconstruction below.
            condition_kwargs[k] = v
    condition = type(condition)(**condition_kwargs)
    return condition
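# Illustrative end-to-end setup (a sketch only: `cfg` must come from the project's
# lazy-config system and `tokenizer_dir` from a downloaded checkpoint; both names are
# hypothetical here):
#
#   model = DiffusionT2WModel(cfg)
#   model.set_up_model()
#   model.set_up_tokenizer(tokenizer_dir)
#   pred = model.denoise(xt, sigma, condition)  # xt: noisy latent; see denoise() above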