# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
from megatron.core import parallel_state
from torch import Tensor

from cosmos_transfer1.diffusion.conditioner import VideoExtendCondition
from cosmos_transfer1.diffusion.config.base.conditioner import VideoCondBoolConfig
from cosmos_transfer1.diffusion.diffusion.functional.batch_ops import batch_mul
from cosmos_transfer1.diffusion.model.model_t2w import DataType, DiffusionT2WModel, DistillT2WModel
from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp
from cosmos_transfer1.utils import log, misc


@dataclass
class VideoDenoisePrediction:
    x0: torch.Tensor  # clean data prediction
    eps: Optional[torch.Tensor] = None  # noise prediction
    logvar: Optional[torch.Tensor] = None  # log variance of noise prediction, can be used as confidence / uncertainty
    xt: Optional[torch.Tensor] = None  # input to the network, before multiplying with c_in
    x0_pred_replaced: Optional[torch.Tensor] = None  # x0 prediction with condition region replaced by gt_latent


class DiffusionV2WModel(DiffusionT2WModel):
    def __init__(self, config):
        super().__init__(config)

    def augment_conditional_latent_frames(
        self,
        condition: VideoExtendCondition,
        cfg_video_cond_bool: VideoCondBoolConfig,
        gt_latent: Tensor,
        condition_video_augment_sigma_in_inference: float = 0.001,
        sigma: Optional[Tensor] = None,
        seed: int = 1,
    ) -> Tuple[VideoExtendCondition, Tensor]:
        """Augments the conditional frames with noise during inference.

        Args:
            condition (VideoExtendCondition): condition object
                condition_video_indicator: binary tensor indicating whether a region is condition (value=1) or generation (value=0). Bx1xTx1x1 tensor.
                condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. Will be concatenated with the network input.
            cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config
            gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            sigma (Tensor): noise level for the generation region
            seed (int): random seed for reproducibility
        Returns:
            VideoExtendCondition: updated condition object
                condition_video_augment_sigma: sigma for the condition region, fed to the network
            augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W
        """
        # Inference only: use a fixed sigma for the condition region
        assert (
            condition_video_augment_sigma_in_inference is not None
        ), "condition_video_augment_sigma_in_inference should be provided"
        augment_sigma = condition_video_augment_sigma_in_inference

        if augment_sigma >= sigma.flatten()[0]:
            # This is an inference trick! If the sampling sigma is smaller than the augment sigma, we start denoising the condition region together.
            # This is achieved by marking every region as `generation`, i.e. value=0.
            log.debug("augment_sigma larger than sigma or other frame, remove condition")
            condition.condition_video_indicator = condition.condition_video_indicator * 0

        B = gt_latent.shape[0]
        augment_sigma = torch.full((B,), augment_sigma, **self.tensor_kwargs)

        # Now apply the augment_sigma to the gt_latent
        noise = misc.arch_invariant_rand(
            gt_latent.shape,
            torch.float32,
            self.tensor_kwargs["device"],
            seed,
        )
        augment_latent = gt_latent + noise * augment_sigma[:, None, None, None, None]
        _, _, c_in_augment, _ = self.scaling(sigma=augment_sigma)

        # Multiply the whole latent with c_in_augment
        augment_latent_cin = batch_mul(augment_latent, c_in_augment)

        # Since the whole latent will be multiplied by c_in later, divide by it here to cancel the effect
        _, _, c_in, _ = self.scaling(sigma=sigma)
        augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in)

        return condition, augment_latent_cin
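    # Note on the c_in cancellation above (a sketch, assuming EDM-style preconditioning where
    # c_in(s) = 1 / sqrt(s^2 + sigma_data^2)): the caller later multiplies the whole network input by
    # c_in(sigma), so after that multiplication the condition frames carry a net scale of
    # c_in(augment_sigma), i.e. exactly the scaling they would receive if noised at augment_sigma directly.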
    def denoise(
        self,
        noise_x: Tensor,
        sigma: Tensor,
        condition: VideoExtendCondition,
        condition_video_augment_sigma_in_inference: float = 0.001,
        seed: int = 1,
    ) -> VideoDenoisePrediction:
        """Denoises input tensor using conditional video generation.

        Args:
            noise_x (Tensor): Noisy input tensor.
            sigma (Tensor): Noise level.
            condition (VideoExtendCondition): Condition for denoising.
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            seed (int): Random seed for reproducibility

        Returns:
            VideoDenoisePrediction containing:
            - x0: Denoised prediction
            - eps: Noise prediction
            - logvar: Log variance of noise prediction
            - xt: Input before c_in multiplication
            - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth
        """
        assert (
            condition.gt_latent is not None
        ), f"Found None gt_latent in condition; likely self.add_condition_video_indicator_and_video_input_mask was not called when preparing the condition, or this is an image batch but condition.data_type is wrong, got {noise_x.shape}"
        gt_latent = condition.gt_latent
        cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool

        condition_latent = gt_latent

        # Augment the latent with a different sigma value, and add the augment_sigma to the condition object if needed
        condition, augment_latent = self.augment_conditional_latent_frames(
            condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed
        )
        condition_video_indicator = condition.condition_video_indicator  # [B, 1, T, 1, 1]
        if parallel_state.get_context_parallel_world_size() > 1:
            cp_group = parallel_state.get_context_parallel_group()
            condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group)
            augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group)
            gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group)

        # Compose the model input with the condition region (augment_latent) and the generation region (noise_x)
        new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x
        # Call the base model
        denoise_pred = super().denoise(new_noise_xt, sigma, condition)

        x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0

        x0_pred = x0_pred_replaced

        return VideoDenoisePrediction(
            x0=x0_pred,
            eps=batch_mul(noise_x - x0_pred, 1.0 / sigma),
            logvar=denoise_pred.logvar,
            xt=new_noise_xt,
            x0_pred_replaced=x0_pred_replaced,
        )

    CHUNKING_MODE = "rand_order"  # one of ["shuffle", "in_order", "rand_order"]
    IS_STAGGERED = True
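    # Illustrative sketch of the chunking behaviour below (the numbers are examples, not fixed values): with
    # total_flen=10, chunking_size=4 and IS_STAGGERED=True, a staggered first chunk of random length 1..4 is
    # drawn, e.g. [0, 1], followed by [2, 3, 4, 5] and [6, 7, 8, 9]; "rand_order" then reverses the chunk order
    # with probability 0.5, while "shuffle" instead splits [0, total_flen) into fixed-size chunks and shuffles them.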
    def get_chunks_indices(self, total_flen, chunking_size) -> List[torch.Tensor]:
        """Splits frame indices [0, total_flen) into chunks of size chunking_size, ordered according to CHUNKING_MODE."""
        chunks_indices = []
        if self.CHUNKING_MODE == "shuffle":
            for index in torch.arange(0, total_flen, 1).split(chunking_size):
                chunks_indices.append(index)
            np.random.shuffle(chunks_indices)
        else:
            first_chunk_end = (
                int(torch.randint(low=0, high=chunking_size, size=(1,)) + 1) if self.IS_STAGGERED else chunking_size
            )
            if first_chunk_end >= total_flen:
                chunks_indices.append(torch.arange(total_flen))
            else:
                chunks_indices.append(torch.arange(first_chunk_end))
                for index in torch.arange(first_chunk_end, total_flen, 1).split(chunking_size):
                    chunks_indices.append(index)

            if self.CHUNKING_MODE == "in_order":
                pass
            elif self.CHUNKING_MODE == "rand_order":
                if np.random.rand() > 0.5:
                    chunks_indices = chunks_indices[::-1]
            else:
                raise NotImplementedError(f"{self.CHUNKING_MODE} mode not implemented")

        return chunks_indices

    def generate_samples_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: Tuple | None = None,
        n_sample: int | None = None,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
        condition_latent: Union[torch.Tensor, None] = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: Optional[float] = None,
        add_input_frames_guidance: bool = False,
        x_sigma_max: Optional[torch.Tensor] = None,
        sigma_max: Optional[float] = None,
        chunking: Optional[int] = None,
        **kwargs,
    ) -> Tensor:
        """Generates video samples conditioned on input frames.

        Args:
            data_batch: Input data dictionary
            guidance: Classifier-free guidance scale
            seed: Random seed for reproducibility
            state_shape: Shape of output tensor (defaults to model's state shape)
            n_sample: Number of samples to generate (defaults to batch size)
            is_negative_prompt: Whether to use negative prompting
            num_steps: Number of denoising steps
            condition_latent: Conditioning frames tensor (B,C,T,H,W)
            num_condition_t: Number of frames to condition on
            condition_video_augment_sigma_in_inference: Noise level for condition augmentation
            add_input_frames_guidance: Whether to apply guidance to input frames
            x_sigma_max: Maximum noise level tensor
            sigma_max: Maximum noise level (defaults to the SDE's sigma_max)
            chunking: Chunking size; if None, chunking is disabled

        Returns:
            Generated video samples tensor
        """
        if n_sample is None:
            input_key = self.input_data_key
            n_sample = data_batch[input_key].shape[0]
        if state_shape is None:
            log.debug(f"Default Video state shape is used. {self.state_shape}")
            state_shape = self.state_shape

        assert condition_latent is not None, "condition_latent should be provided"

        # Build the denoising (x0 prediction) function, optionally using chunked conditioning
        log.info("Constructing x0_fn")
        x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
            data_batch,
            guidance,
            is_negative_prompt=is_negative_prompt,
            condition_latent=condition_latent,
            num_condition_t=num_condition_t,
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
            add_input_frames_guidance=add_input_frames_guidance,
            seed=seed,
            chunking=chunking,
        )

        if sigma_max is None:
            sigma_max = self.sde.sigma_max

        if x_sigma_max is None:
            x_sigma_max = (
                misc.arch_invariant_rand(
                    (n_sample,) + tuple(state_shape),
                    torch.float32,
                    self.tensor_kwargs["device"],
                    seed,
                )
                * sigma_max
            )

        if self.net.is_context_parallel_enabled:
            x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group)

        samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max)
        if self.net.is_context_parallel_enabled:
            samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group)

        return samples
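    # Hypothetical usage sketch (the `model`, `input_frames` and `data_batch` names below are illustrative and not
    # defined in this module); chunking=None reproduces the standard video2world path:
    #   cond_latent = model.encode(input_frames)  # (B, C, T, H, W) latent of the conditioning frames
    #   samples = model.generate_samples_from_batch(
    #       data_batch, guidance=7.0, condition_latent=cond_latent, num_condition_t=1, num_steps=35
    #   )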
log.info("x0_fn") x0_fn = self.get_x0_fn_from_batch_with_condition_latent( data_batch, guidance, is_negative_prompt=is_negative_prompt, condition_latent=condition_latent, num_condition_t=num_condition_t, condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, add_input_frames_guidance=add_input_frames_guidance, seed=seed, chunking=chunking, ) if sigma_max is None: sigma_max = self.sde.sigma_max if x_sigma_max is None: x_sigma_max = ( misc.arch_invariant_rand( (n_sample,) + tuple(state_shape), torch.float32, self.tensor_kwargs["device"], seed, ) * sigma_max ) if self.net.is_context_parallel_enabled: x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max) if self.net.is_context_parallel_enabled: samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) return samples def get_x0_fn_from_batch_with_condition_latent( self, data_batch: Dict, guidance: float = 1.5, is_negative_prompt: bool = False, condition_latent: torch.Tensor = None, num_condition_t: Union[int, None] = None, condition_video_augment_sigma_in_inference: float = None, add_input_frames_guidance: bool = False, seed: int = 1, chunking: Optional[int] = None, ) -> Callable: """Creates denoising function for conditional video generation. Args: data_batch: Input data dictionary guidance: Classifier-free guidance scale is_negative_prompt: Whether to use negative prompting condition_latent: Conditioning frames tensor (B,C,T,H,W) num_condition_t: Number of frames to condition on condition_video_augment_sigma_in_inference: Noise level for condition augmentation add_input_frames_guidance: Whether to apply guidance to input frames seed: Random seed for reproducibility chunking: Chunking size, if None, chunking is disabled Returns: Function that takes noisy input and noise level and returns denoised prediction """ if chunking is None: log.info("no chunking") if is_negative_prompt: condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) else: condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) condition.video_cond_bool = True condition = self.add_condition_video_indicator_and_video_input_mask( condition_latent, condition, num_condition_t ) uncondition.video_cond_bool = False if add_input_frames_guidance else True uncondition = self.add_condition_video_indicator_and_video_input_mask( condition_latent, uncondition, num_condition_t ) def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: cond_x0 = self.denoise( noise_x, sigma, condition, condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, seed=seed, ).x0_pred_replaced uncond_x0 = self.denoise( noise_x, sigma, uncondition, condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, seed=seed, ).x0_pred_replaced return cond_x0 + guidance * (cond_x0 - uncond_x0) return x0_fn else: log.info("chunking !!!") def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: if is_negative_prompt: condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) else: condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) noises = torch.zeros_like(condition_latent) T = condition_latent.shape[2] for chunk_idx in self.get_chunks_indices(T, chunking): latents_ = condition_latent[:, :, chunk_idx, :, :] log.info(f"chunk_idx: {chunk_idx}, chunk shape: 
{latents_.shape}") # controlnet_cond_ = self.controlnet_data[:, chunk_idx] condition.video_cond_bool = True condition = self.add_condition_video_indicator_and_video_input_mask( latents_, condition, num_condition_t ) uncondition.video_cond_bool = False if add_input_frames_guidance else True uncondition = self.add_condition_video_indicator_and_video_input_mask( latents_, uncondition, num_condition_t ) cond_x0 = self.denoise( noise_x, sigma, condition, condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, seed=seed, ).x0_pred_replaced uncond_x0 = self.denoise( noise_x, sigma, uncondition, condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, seed=seed, ).x0_pred_replaced noises[:, :, chunk_idx, :, :] = cond_x0 + guidance * (cond_x0 - uncond_x0) # TODO: need scheduler ? return noises return x0_fn def add_condition_video_indicator_and_video_input_mask( self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None ) -> VideoExtendCondition: """Adds conditioning masks to VideoExtendCondition object. Creates binary indicators and input masks for conditional video generation. Args: latent_state: Input latent tensor (B,C,T,H,W) condition: VideoExtendCondition object to update num_condition_t: Number of frames to condition on Returns: Updated VideoExtendCondition with added masks: - condition_video_indicator: Binary tensor marking condition regions - condition_video_input_mask: Input mask for network - gt_latent: Ground truth latent tensor """ T = latent_state.shape[2] latent_dtype = latent_state.dtype condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( latent_dtype ) # 1 for condition region # Only in inference to decide the condition region assert num_condition_t is not None, "num_condition_t should be provided" assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" log.debug( f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" ) condition_video_indicator[:, :, :num_condition_t] += 1.0 condition.gt_latent = latent_state condition.condition_video_indicator = condition_video_indicator B, C, T, H, W = latent_state.shape # Create additional input_mask channel, this will be concatenated to the input of the network # See design doc section (Implementation detail A.1 and A.2) for visualization ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) assert condition.video_cond_bool is not None, "video_cond_bool should be set" # The input mask indicate whether the input is conditional region or not if condition.video_cond_bool: # Condition one given video frames condition.condition_video_input_mask = ( condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding ) else: # Unconditional case, use for cfg condition.condition_video_input_mask = zeros_padding return condition class DistillV2WModel(DistillT2WModel): """ControlNet Video2World Distillation Model.""" def augment_conditional_latent_frames( self, condition: VideoExtendCondition, cfg_video_cond_bool: VideoCondBoolConfig, gt_latent: Tensor, condition_video_augment_sigma_in_inference: float = 0.001, sigma: Tensor = None, seed: int = 1, ) -> Union[VideoExtendCondition, Tensor]: """Augments the conditional frames with noise during 
    def drop_out_condition_region(
        self, augment_latent: Tensor, noise_x: Tensor, cfg_video_cond_bool: VideoCondBoolConfig
    ) -> Tensor:
        """Used for CFG on input frames; drops out the conditional region.

        There are two options:
        1. when we drop out, we set the region to zero
        2. when we drop out, we set the region to noise_x
        """
        # Unconditional case, used for cfg
        if cfg_video_cond_bool.cfg_unconditional_type == "zero_condition_region_condition_mask":
            # Set the condition location input to zero
            augment_latent_drop = torch.zeros_like(augment_latent)
        elif cfg_video_cond_bool.cfg_unconditional_type == "noise_x_condition_region":
            # Set the condition location input to noise_x, i.e., the same as base model training
            augment_latent_drop = noise_x
        else:
            raise NotImplementedError(
                f"cfg_unconditional_type {cfg_video_cond_bool.cfg_unconditional_type} not implemented"
            )
        return augment_latent_drop
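    # Illustrative sketch (hypothetical tensors): for a (B, C, T, H, W) augment_latent, the unconditional branch calls
    #   self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool)
    # and receives either torch.zeros_like(augment_latent) or noise_x back, depending on cfg_unconditional_type.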
    def denoise(
        self,
        noise_x: Tensor,
        sigma: Tensor,
        condition: VideoExtendCondition,
        condition_video_augment_sigma_in_inference: float = 0.001,
        seed: int = 1,
    ) -> VideoDenoisePrediction:
        """Denoises input tensor using conditional video generation.

        Args:
            noise_x (Tensor): Noisy input tensor.
            sigma (Tensor): Noise level.
            condition (VideoExtendCondition): Condition for denoising.
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            seed (int): Random seed for reproducibility

        Returns:
            VideoDenoisePrediction containing:
            - x0: Denoised prediction
            - eps: Noise prediction
            - logvar: Log variance of noise prediction
            - xt: Input before c_in multiplication
            - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth
        """
        inputs_to_check = [noise_x, sigma, condition.gt_latent]
        for i, tensor in enumerate(inputs_to_check):
            if torch.isnan(tensor).any():
                log.warning(f"NaN found in input {i}")

        assert (
            condition.gt_latent is not None
        ), f"Found None gt_latent in condition; likely self.add_condition_video_indicator_and_video_input_mask was not called when preparing the condition, or this is an image batch but condition.data_type is wrong, got {noise_x.shape}"
        gt_latent = condition.gt_latent
        cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool

        condition_latent = gt_latent

        # Augment the latent with a different sigma value, and add the augment_sigma to the condition object if needed
        condition, augment_latent = self.augment_conditional_latent_frames(
            condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed
        )
        condition_video_indicator = condition.condition_video_indicator  # [B, 1, T, 1, 1]
        if parallel_state.get_context_parallel_world_size() > 1:
            cp_group = parallel_state.get_context_parallel_group()
            condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group)
            augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group)
            gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group)

        if not condition.video_cond_bool:
            # Unconditional case, drop out the condition region
            augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool)

        # Compose the model input with the condition region (augment_latent) and the generation region (noise_x)
        new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x
        # Call the base model
        denoise_pred = super().denoise(new_noise_xt, sigma, condition)

        x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0

        x0_pred = x0_pred_replaced

        return VideoDenoisePrediction(
            x0=x0_pred,
            eps=batch_mul(noise_x - x0_pred, 1.0 / sigma),
            logvar=denoise_pred.logvar,
            xt=new_noise_xt,
            x0_pred_replaced=x0_pred_replaced,
        )
    def add_condition_video_indicator_and_video_input_mask(
        self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None
    ) -> VideoExtendCondition:
        """Adds conditioning masks to VideoExtendCondition object.

        Creates binary indicators and input masks for conditional video generation.

        Args:
            latent_state: Input latent tensor (B,C,T,H,W)
            condition: VideoExtendCondition object to update
            num_condition_t: Number of frames to condition on

        Returns:
            Updated VideoExtendCondition with added masks:
            - condition_video_indicator: Binary tensor marking condition regions
            - condition_video_input_mask: Input mask for network
            - gt_latent: Ground truth latent tensor
        """
        T = latent_state.shape[2]
        latent_dtype = latent_state.dtype
        condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type(
            latent_dtype
        )  # 1 for condition region

        # Only used in inference to decide the condition region
        assert num_condition_t is not None, "num_condition_t should be provided"
        assert num_condition_t <= T, f"num_condition_t should be less than or equal to T, got {num_condition_t}, {T}"
        log.debug(
            f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}"
        )
        condition_video_indicator[:, :, :num_condition_t] += 1.0

        condition.gt_latent = latent_state
        condition.condition_video_indicator = condition_video_indicator

        B, C, T, H, W = latent_state.shape
        # Create an additional input_mask channel; this will be concatenated to the input of the network
        ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
        zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
        assert condition.video_cond_bool is not None, "video_cond_bool should be set"

        # The input mask indicates whether the input is a conditional region or not
        if condition.video_cond_bool:  # Condition on given video frames
            condition.condition_video_input_mask = (
                condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding
            )
        else:  # Unconditional case, used for cfg
            condition.condition_video_input_mask = zeros_padding

        return condition