roll-ai's picture
Upload 381 files
b6af722 verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from cosmos_predict1.diffusion.conditioner import GeneralConditioner
from cosmos_predict1.diffusion.functional.batch_ops import batch_mul
from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp
from cosmos_predict1.utils.misc import count_params
class DataType(Enum):
    """Enumerates the data-type tags ``image``, ``video`` and ``mix``."""

    IMAGE = "image"
    VIDEO = "video"
    MIX = "mix"
class AbstractEmbModel(nn.Module):
    """Base ``nn.Module`` for condition embedders.

    Stores four configuration values (``is_trainable``, ``dropout_rate``,
    ``input_key`` and ``is_return_dict``) behind properties that each provide a
    getter, a setter and a deleter. It also offers a per-sample random dropout
    helper and a text summary of the embedder.
    """

    def __init__(self):
        super().__init__()
        # Configuration starts unset; assign it through the properties below.
        self._is_trainable = None
        self._dropout_rate = None
        self._input_key = None
        self._return_dict = False

    # --- is_trainable -----------------------------------------------------
    @property
    def is_trainable(self) -> bool:
        return self._is_trainable

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    # --- dropout_rate -----------------------------------------------------
    @property
    def dropout_rate(self) -> Union[float, torch.Tensor]:
        return self._dropout_rate

    @dropout_rate.setter
    def dropout_rate(self, value: Union[float, torch.Tensor]):
        self._dropout_rate = value

    @dropout_rate.deleter
    def dropout_rate(self):
        del self._dropout_rate

    # --- input_key --------------------------------------------------------
    @property
    def input_key(self) -> str:
        return self._input_key

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @input_key.deleter
    def input_key(self):
        del self._input_key

    # --- is_return_dict ---------------------------------------------------
    @property
    def is_return_dict(self) -> bool:
        return self._return_dict

    @is_return_dict.setter
    def is_return_dict(self, value: bool):
        self._return_dict = value

    @is_return_dict.deleter
    def is_return_dict(self):
        del self._return_dict

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        """Randomly drop whole entries of ``in_tensor`` along dimension 0.

        Args:
            in_tensor: Tensor to mask; its first dimension sets how many
                independent keep/drop draws are made.
            dropout_rate: Drop probability. Falls back to ``self.dropout_rate``
                when ``None``.
            key: Accepted for interface compatibility and ignored.

        Returns:
            The result of ``batch_mul(mask, in_tensor)``, where ``mask`` holds
            one Bernoulli sample per entry with keep probability
            ``1 - dropout_rate``.
            NOTE(review): ``batch_mul`` presumably broadcasts the mask over the
            remaining dimensions -- confirm against its implementation.
        """
        del key  # unused
        rate = self.dropout_rate if dropout_rate is None else dropout_rate
        keep_probability = (1.0 - rate) * torch.ones(in_tensor.shape[0])
        # Sample on the default device, then match in_tensor's dtype/device.
        keep_mask = torch.bernoulli(keep_probability).type_as(in_tensor)
        return batch_mul(keep_mask, in_tensor)

    def details(self) -> str:
        """Return extra description text for ``summary``; empty in the base class."""
        return ""

    def summary(self) -> str:
        """Return a multi-line description of this embedder.

        Reports the class name, the input key (or the ``input_keys`` attribute
        when ``input_key`` is ``None``), the parameter count from
        ``count_params``, the trainable flag, the dropout rate and the text
        from ``details()``.
        """
        input_key = self.input_key
        if input_key is None:
            input_key = getattr(self, "input_keys", None)
        return (
            f"{self.__class__.__name__} \n\tinput key: {input_key}"
            f"\n\tParam count: {count_params(self, False)} \n\tTrainable: {self.is_trainable}"
            f"\n\tDropout rate: {self.dropout_rate}"
            f"\n\t{self.details()}"
        )
class TrajectoryAttr(AbstractEmbModel):
    """Embedder that passes a trajectory tensor through under the key ``trajectory``."""

    def __init__(self, traj_dim: int):
        """Record ``traj_dim``; it is only reported by ``details()`` in this class."""
        super().__init__()
        self.traj_dim = traj_dim

    def forward(self, traj: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return ``traj`` unchanged, wrapped in a single-entry dict."""
        result: Dict[str, torch.Tensor] = {}
        result["trajectory"] = traj
        return result

    def details(self) -> str:
        """Describe the configured trajectory dimension and the output key."""
        return f"Traj dim : {self.traj_dim} \n\tOutput key: [trajectory]"
class FrameRepeatAttr(AbstractEmbModel):
    """Embedder that emits the frame-repeat value divided by 10 under ``frame_repeat``."""

    def __init__(self):
        super().__init__()

    def forward(self, frame_repeat: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Scale ``frame_repeat`` by 1/10 and return it in a single-entry dict."""
        scaled = frame_repeat / 10.0
        result: Dict[str, torch.Tensor] = {}
        result["frame_repeat"] = scaled
        return result

    def details(self) -> str:
        """Describe the output key of this embedder."""
        return "Frame repeat, Output key: [frame_repeat]"
@dataclass
class BaseVideoCondition:
    """Container for the conditioning values handed to the network.

    ``crossattn_emb`` and ``crossattn_mask`` are required; every other field
    has a default (``DataType.VIDEO`` for ``data_type``, ``None`` for the rest).
    """

    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    data_type: DataType = DataType.VIDEO
    padding_mask: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    image_size: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None
    trajectory: Optional[torch.Tensor] = None
    frame_repeat: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a shallow ``{field name: value}`` mapping of all dataclass fields.

        The mapping is built from ``dataclasses.fields(self)``, so fields that
        subclasses declare are included as well. Values are the stored objects
        themselves, not copies. Not every value is a tensor (``data_type`` is
        a ``DataType``), which is why the return type is ``Dict[str, Any]``.
        """
        return {f.name: getattr(self, f.name) for f in fields(self)}
@dataclass
class VideoExtendCondition(BaseVideoCondition):
    """``BaseVideoCondition`` extended with optional video-conditioning fields.

    All added fields default to ``None``.
    """

    # Whether the sample is conditioned on video.
    video_cond_bool: Optional[torch.Tensor] = None
    gt_latent: Optional[torch.Tensor] = None
    # Indicator tensor: 1 marks the condition region.
    condition_video_indicator: Optional[torch.Tensor] = None
    # Mask that is concatenated with the network input tensor along the
    # channel dimension.
    condition_video_input_mask: Optional[torch.Tensor] = None
    # (B, T) tensor of sigma values used to augment the conditional input.
    # Only valid when apply_corruption_to_condition_region is
    # "noise_with_sigma" or "noise_with_sigma_fixed".
    condition_video_augment_sigma: Optional[torch.Tensor] = None
    # Pose conditional input; concatenated with the input tensor.
    condition_video_pose: Optional[torch.Tensor] = None
@dataclass
class VideoLatentDiffusionDecoderCondition(BaseVideoCondition):
    """``BaseVideoCondition`` extended with a latent condition and its sigma."""

    # Concatenated with the network input along the channel dimension;
    # classifier-free guidance replaces it with all-zero padding.
    latent_condition: Optional[torch.Tensor] = None
    latent_condition_sigma: Optional[torch.Tensor] = None

    def get_condition_for_cp(self, cp_group):
        """Split both latent fields for ``cp_group``, in place.

        ``latent_condition`` and then ``latent_condition_sigma`` are each
        replaced by ``split_inputs_cp(x=<field>, seq_dim=2, cp_group=cp_group)``.
        Nothing is returned.
        NOTE(review): both fields default to ``None``; whether
        ``split_inputs_cp`` accepts ``None`` is not visible here -- confirm.
        """
        for field_name in ("latent_condition", "latent_condition_sigma"):
            current = getattr(self, field_name)
            setattr(self, field_name, split_inputs_cp(x=current, seq_dim=2, cp_group=cp_group))
class VideoConditioner(GeneralConditioner):
    """Conditioner whose ``forward`` returns a ``BaseVideoCondition``."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseVideoCondition:
        """Run the parent ``_forward`` and pack its output as keyword arguments.

        Args:
            batch: Input batch, passed through to ``_forward``.
            override_dropout_rate: Optional per-embedder dropout rates, passed
                through to ``_forward``.
        """
        condition_kwargs = super()._forward(batch, override_dropout_rate)
        return BaseVideoCondition(**condition_kwargs)
class VideoDiffusionDecoderConditioner(GeneralConditioner):
    """Conditioner whose ``forward`` returns a ``VideoLatentDiffusionDecoderCondition``."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoLatentDiffusionDecoderCondition:
        """Run the parent ``_forward`` and pack its output as keyword arguments.

        Args:
            batch: Input batch, passed through to ``_forward``.
            override_dropout_rate: Optional per-embedder dropout rates, passed
                through to ``_forward``.
        """
        condition_kwargs = super()._forward(batch, override_dropout_rate)
        return VideoLatentDiffusionDecoderCondition(**condition_kwargs)
class VideoExtendConditioner(GeneralConditioner):
    """Conditioner whose ``forward`` returns a ``VideoExtendCondition``."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoExtendCondition:
        """Run the parent ``_forward`` and pack its output as keyword arguments.

        Args:
            batch: Input batch, passed through to ``_forward``.
            override_dropout_rate: Optional per-embedder dropout rates, passed
                through to ``_forward``.
        """
        condition_kwargs = super()._forward(batch, override_dropout_rate)
        return VideoExtendCondition(**condition_kwargs)
class VideoConditionerWithTraingOnlyEmb(GeneralConditioner):
    """Conditioner that can also produce a conditioned/unconditioned pair.

    ``FrameRepeatAttr`` embedders are treated specially: they are fully
    dropped (rate 1.0) when building the conditioned output.
    """

    def get_condition_uncondition(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Build the conditioned and the unconditioned outputs for one data batch.

        The conditioner is called twice on ``data_batch`` with different
        per-embedder dropout overrides:

        - Conditioned: rate 0.0 for every embedder, except ``FrameRepeatAttr``
          embedders, which get rate 1.0.
        - Unconditioned: rate 1.0 for every embedder whose configured
          ``dropout_rate`` exceeds 1e-4, and rate 0.0 for the others.

        Parameters:
            data_batch (Dict): The input batch, forwarded unchanged to both calls.

        Returns:
            Tuple[Any, Any]: ``(condition, un_condition)``, i.e. the conditioned
            output followed by the unconditioned output.
        """
        conditioned_rates = {
            name: (1.0 if isinstance(emb, FrameRepeatAttr) else 0.0) for name, emb in self.embedders.items()
        }
        unconditioned_rates = {
            name: (1.0 if emb.dropout_rate > 1e-4 else 0.0) for name, emb in self.embedders.items()
        }

        condition: Any = self(data_batch, override_dropout_rate=conditioned_rates)
        un_condition: Any = self(data_batch, override_dropout_rate=unconditioned_rates)
        return condition, un_condition

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseVideoCondition:
        """Run the parent ``_forward`` and wrap its output in a ``BaseVideoCondition``."""
        condition_kwargs = super()._forward(batch, override_dropout_rate)
        return BaseVideoCondition(**condition_kwargs)
class VideoExtendConditionerWithTraingOnlyEmb(VideoConditionerWithTraingOnlyEmb):
    """Variant of the parent conditioner whose ``forward`` returns a ``VideoExtendCondition``."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoExtendCondition:
        """Run the inherited ``_forward`` and wrap its output in a ``VideoExtendCondition``.

        Args:
            batch: Input batch, passed through to ``_forward``.
            override_dropout_rate: Optional per-embedder dropout rates, passed
                through to ``_forward``.
        """
        condition_kwargs = super()._forward(batch, override_dropout_rate)
        return VideoExtendCondition(**condition_kwargs)
@dataclass
class BaseWithCtrlCondition(VideoExtendCondition):
    """``VideoExtendCondition`` extended with control-input fields and control settings.

    Every added field is optional. ``control_weight`` defaults to ``1.0`` and
    ``num_layers_to_use`` to ``-1``; all other added fields default to ``None``.
    """

    # One optional tensor slot per named control input.
    control_input_canny: Optional[torch.Tensor] = None
    control_input_blur: Optional[torch.Tensor] = None
    control_input_canny_blur: Optional[torch.Tensor] = None
    control_input_depth: Optional[torch.Tensor] = None
    control_input_segmentation: Optional[torch.Tensor] = None
    control_input_depth_segmentation: Optional[torch.Tensor] = None
    control_input_mask: Optional[torch.Tensor] = None
    control_input_human_kpts: Optional[torch.Tensor] = None
    control_input_upscale: Optional[torch.Tensor] = None
    control_input_identity: Optional[torch.Tensor] = None
    control_input_multi: Optional[torch.Tensor] = None

    # Non-tensor settings carried alongside the control inputs.
    base_model: Optional[torch.nn.Module] = None
    hint_key: Optional[str] = None
    control_weight: Optional[float] = 1.0
    num_layers_to_use: Optional[int] = -1
class VideoConditionerWithCtrl(VideoExtendConditioner):
    """Conditioner whose ``forward`` returns a ``BaseWithCtrlCondition``."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseWithCtrlCondition:
        """Run the inherited ``_forward`` and add control settings from ``batch``.

        ``batch["hint_key"]`` is always copied into the output, so a missing
        key raises ``KeyError``. ``control_weight`` and ``num_layers_to_use``
        are copied only when present in ``batch``; otherwise the dataclass
        defaults apply.
        """
        condition_kwargs = super()._forward(batch, override_dropout_rate)
        condition_kwargs["hint_key"] = batch["hint_key"]
        for optional_key in ("control_weight", "num_layers_to_use"):
            if optional_key in batch:
                condition_kwargs[optional_key] = batch[optional_key]
        return BaseWithCtrlCondition(**condition_kwargs)
class BooleanFlag(AbstractEmbModel):
    """Embedder that emits a randomly sampled boolean flag instead of an embedding.

    The flag is drawn in ``random_dropout_input`` and stored on ``self.flag``;
    ``forward`` then returns it. ``self.flag`` does not exist until
    ``random_dropout_input`` has been called at least once.
    """

    def __init__(self, output_key: Optional[str] = None):
        """Store ``output_key``; when it is falsy, ``input_key`` is used as the output key."""
        super().__init__()
        self.output_key = output_key

    def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
        """Return ``{key: self.flag}``, ignoring all arguments."""
        del args, kwargs  # unused
        key = self.output_key or self.input_key
        return {key: self.flag}

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        """Sample the flag and return ``in_tensor`` unchanged.

        A single Bernoulli draw with probability ``1 - dropout_rate`` (falling
        back to ``self.dropout_rate`` when ``dropout_rate`` is ``None``) is
        converted to a boolean tensor of shape ``(1,)`` on ``in_tensor``'s
        device and stored in ``self.flag``.
        """
        del key  # unused
        rate = self.dropout_rate if dropout_rate is None else dropout_rate
        keep_probability = (1.0 - rate) * torch.ones(1)
        sampled = torch.bernoulli(keep_probability).bool()
        self.flag = sampled.to(device=in_tensor.device)
        return in_tensor

    def details(self) -> str:
        """Describe the output key and note that this embedder is a boolean flag."""
        key = self.output_key or self.input_key
        return f"Output key: {key} \n\t This is a boolean flag"