# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from cosmos_predict1.diffusion.conditioner import GeneralConditioner
from cosmos_predict1.diffusion.functional.batch_ops import batch_mul
from cosmos_predict1.diffusion.training.context_parallel import split_inputs_cp
from cosmos_predict1.utils.misc import count_params


class DataType(Enum):
    IMAGE = "image"
    VIDEO = "video"
    MIX = "mix"


class AbstractEmbModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._is_trainable = None
        self._dropout_rate = None
        self._input_key = None
        self._return_dict = False

    @property
    def is_trainable(self) -> bool:
        return self._is_trainable

    @property
    def dropout_rate(self) -> Union[float, torch.Tensor]:
        return self._dropout_rate

    @property
    def input_key(self) -> str:
        return self._input_key

    @property
    def is_return_dict(self) -> bool:
        return self._return_dict

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @dropout_rate.setter
    def dropout_rate(self, value: Union[float, torch.Tensor]):
        self._dropout_rate = value

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @is_return_dict.setter
    def is_return_dict(self, value: bool):
        self._return_dict = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    @dropout_rate.deleter
    def dropout_rate(self):
        del self._dropout_rate

    @input_key.deleter
    def input_key(self):
        del self._input_key

    @is_return_dict.deleter
    def is_return_dict(self):
        del self._return_dict

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        # Draw one keep/drop Bernoulli sample per batch element and zero out the
        # conditioning of dropped samples wholesale.
        del key
        dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate
        return batch_mul(
            torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor),
            in_tensor,
        )

    def details(self) -> str:
        return ""

    def summary(self) -> str:
        input_key = self.input_key if self.input_key is not None else getattr(self, "input_keys", None)
        return (
            f"{self.__class__.__name__} \n\tinput key: {input_key}"
            f"\n\tParam count: {count_params(self, False)} \n\tTrainable: {self.is_trainable}"
            f"\n\tDropout rate: {self.dropout_rate}"
            f"\n\t{self.details()}"
        )
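
# Illustrative usage sketch (not part of the original module; `ScalarAttr` is a
# hypothetical subclass). `random_dropout_input` samples one Bernoulli keep/drop
# decision per batch element, so each sample's conditioning is either kept intact
# or zeroed entirely, as in classifier-free guidance training:
#
#     class ScalarAttr(AbstractEmbModel):
#         def forward(self, value: torch.Tensor) -> Dict[str, torch.Tensor]:
#             return {"scalar_feature": value}
#
#     emb = ScalarAttr()
#     emb.dropout_rate = 0.2                   # exercises the property setter
#     emb.input_key = "scalar_feature"
#     x = torch.ones(4, 8)
#     x_dropped = emb.random_dropout_input(x)  # whole rows kept or zeroed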


class TrajectoryAttr(AbstractEmbModel):
    def __init__(self, traj_dim: int):
        super().__init__()
        self.traj_dim = traj_dim

    def forward(self, traj: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {
            "trajectory": traj,
        }

    def details(self) -> str:
        return f"Traj dim : {self.traj_dim} \n\tOutput key: [trajectory]"


class FrameRepeatAttr(AbstractEmbModel):
    def __init__(self):
        super().__init__()

    def forward(self, frame_repeat: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {
            # Scale the raw frame-repeat count down before it is used as conditioning.
            "frame_repeat": frame_repeat / 10.0,
        }

    def details(self) -> str:
        return "Frame repeat, Output key: [frame_repeat]"


@dataclass
class BaseVideoCondition:
    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    data_type: DataType = DataType.VIDEO
    padding_mask: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    image_size: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None
    trajectory: Optional[torch.Tensor] = None
    frame_repeat: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        return {f.name: getattr(self, f.name) for f in fields(self)}
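
# Example (illustrative; the shapes below are placeholders, not values the model
# requires). Because the condition containers are dataclasses, `to_dict` pairs
# naturally with keyword construction:
#
#     cond = BaseVideoCondition(
#         crossattn_emb=torch.zeros(1, 77, 1024),
#         crossattn_mask=torch.ones(1, 77),
#     )
#     as_dict = cond.to_dict()              # includes fields still set to None
#     same = BaseVideoCondition(**as_dict)  # round-trips back to a condition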


@dataclass
class VideoExtendCondition(BaseVideoCondition):
    video_cond_bool: Optional[torch.Tensor] = None  # whether or not the model is conditioned on video
    gt_latent: Optional[torch.Tensor] = None
    condition_video_indicator: Optional[torch.Tensor] = None  # 1 for condition region

    # condition_video_input_mask will be concatenated to the network input along the channel dim.
    condition_video_input_mask: Optional[torch.Tensor] = None
    # condition_video_augment_sigma: (B, T) tensor of sigma values for the conditional input
    # augmentation; only valid when apply_corruption_to_condition_region is "noise_with_sigma"
    # or "noise_with_sigma_fixed".
    condition_video_augment_sigma: Optional[torch.Tensor] = None
    # Pose conditional input, concatenated with the input tensor.
    condition_video_pose: Optional[torch.Tensor] = None


@dataclass
class VideoLatentDiffusionDecoderCondition(BaseVideoCondition):
    # latent_condition will be concatenated to the network input along the channel dim;
    # classifier-free guidance replaces it with all-zero padding.
    latent_condition: Optional[torch.Tensor] = None
    latent_condition_sigma: Optional[torch.Tensor] = None

    def get_condition_for_cp(self, cp_group):
        # Split the temporal axis across the context-parallel group.
        self.latent_condition = split_inputs_cp(x=self.latent_condition, seq_dim=2, cp_group=cp_group)
        self.latent_condition_sigma = split_inputs_cp(x=self.latent_condition_sigma, seq_dim=2, cp_group=cp_group)
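
# Context-parallel sketch (assumes `cp_group` is a torch.distributed process group
# and latents are laid out (B, C, T, H, W), so seq_dim=2 is the temporal axis):
#
#     cond = VideoLatentDiffusionDecoderCondition(**output)
#     cond.get_condition_for_cp(cp_group)  # each rank keeps only its slice of T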


class VideoConditioner(GeneralConditioner):
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseVideoCondition:
        output = super()._forward(batch, override_dropout_rate)
        return BaseVideoCondition(**output)
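
# Conditioner pattern (sketch; the batch keys below are hypothetical and embedder
# configuration comes from elsewhere in the repo). GeneralConditioner._forward runs
# each registered embedder on its input key and merges the returned dicts, so the
# merged keys must line up with the fields of the target condition dataclass:
#
#     condition = video_conditioner({"t5_text_embeddings": ..., "fps": ...})
#     assert isinstance(condition, BaseVideoCondition)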


class VideoDiffusionDecoderConditioner(GeneralConditioner):
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoLatentDiffusionDecoderCondition:
        output = super()._forward(batch, override_dropout_rate)
        return VideoLatentDiffusionDecoderCondition(**output)


class VideoExtendConditioner(GeneralConditioner):
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoExtendCondition:
        output = super()._forward(batch, override_dropout_rate)
        return VideoExtendCondition(**output)


class VideoConditionerWithTraingOnlyEmb(GeneralConditioner):
    def get_condition_uncondition(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Processes the provided data batch to generate two sets of outputs: conditioned and unconditioned. This method
        manipulates the dropout rates of the embedders to simulate two scenarios: one where all conditions are applied
        (conditioned), and one where they are removed or reduced to the minimum (unconditioned).

        For the conditioned scenario, it sets the dropout rates to zero to fully apply the embedders' effects. For the
        unconditioned scenario, it sets the dropout rates to 1 (or to 0 if the initial unconditional dropout rate is
        insignificant) to minimize the embedders' influence, simulating unconditioned generation.

        Parameters:
            data_batch (Dict): The input data batch that contains all necessary information for embedding processing.
                The data is expected to match the format and keys expected by the embedders.

        Returns:
            Tuple[Any, Any]: A tuple containing two conditions:
                - The first contains the outputs with all embedders fully applied (conditioned outputs).
                - The second contains the outputs with embedders minimized or not applied (unconditioned outputs).
        """
        cond_dropout_rates, dropout_rates = {}, {}
        for emb_name, embedder in self.embedders.items():
            if isinstance(embedder, FrameRepeatAttr):
                # FrameRepeatAttr is a training-only embedder: drop it even in the
                # conditioned branch.
                cond_dropout_rates[emb_name] = 1.0
            else:
                cond_dropout_rates[emb_name] = 0.0
            dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0

        condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
        un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates)
        return condition, un_condition

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseVideoCondition:
        output = super()._forward(batch, override_dropout_rate)
        return BaseVideoCondition(**output)
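
# Usage sketch (hypothetical `data_batch`): the conditioned/unconditioned pair
# feeds classifier-free guidance, e.g. run the denoiser on both and blend
# uncond_out + guidance_scale * (cond_out - uncond_out):
#
#     conditioner = VideoConditionerWithTraingOnlyEmb(...)  # embedder kwargs omitted
#     condition, un_condition = conditioner.get_condition_uncondition(data_batch)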


class VideoExtendConditionerWithTraingOnlyEmb(VideoConditionerWithTraingOnlyEmb):
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoExtendCondition:
        output = super()._forward(batch, override_dropout_rate)
        return VideoExtendCondition(**output)


@dataclass
class BaseWithCtrlCondition(VideoExtendCondition):
    control_input_canny: Optional[torch.Tensor] = None
    control_input_blur: Optional[torch.Tensor] = None
    control_input_canny_blur: Optional[torch.Tensor] = None
    control_input_depth: Optional[torch.Tensor] = None
    control_input_segmentation: Optional[torch.Tensor] = None
    control_input_depth_segmentation: Optional[torch.Tensor] = None
    control_input_mask: Optional[torch.Tensor] = None
    control_input_human_kpts: Optional[torch.Tensor] = None
    control_input_upscale: Optional[torch.Tensor] = None
    control_input_identity: Optional[torch.Tensor] = None
    control_input_multi: Optional[torch.Tensor] = None
    base_model: Optional[torch.nn.Module] = None
    hint_key: Optional[str] = None
    control_weight: Optional[float] = 1.0
    num_layers_to_use: Optional[int] = -1


class VideoConditionerWithCtrl(VideoExtendConditioner):
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseWithCtrlCondition:
        output = super()._forward(batch, override_dropout_rate)
        output["hint_key"] = batch["hint_key"]
        if "control_weight" in batch:
            output["control_weight"] = batch["control_weight"]
        if "num_layers_to_use" in batch:
            output["num_layers_to_use"] = batch["num_layers_to_use"]
        return BaseWithCtrlCondition(**output)
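
# Sketch of the extra keys this conditioner reads from `batch` (the values below
# are illustrative placeholders, not defaults from the original code):
#
#     batch["hint_key"] = "control_input_canny"  # required; names the control field
#     batch["control_weight"] = 0.8               # optional
#     batch["num_layers_to_use"] = -1             # optional; -1 presumably uses all layers
#     condition = conditioner(batch)              # -> BaseWithCtrlCondition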


class BooleanFlag(AbstractEmbModel):
    def __init__(self, output_key: Optional[str] = None):
        super().__init__()
        self.output_key = output_key

    def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
        del args, kwargs
        key = self.output_key if self.output_key else self.input_key
        return {key: self.flag}

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        del key
        dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate
        # Instead of zeroing the input, sample a boolean flag for forward() to emit;
        # the input tensor itself passes through unchanged.
        self.flag = torch.bernoulli((1.0 - dropout_rate) * torch.ones(1)).bool().to(device=in_tensor.device)
        return in_tensor

    def details(self) -> str:
        key = self.output_key if self.output_key else self.input_key
        return f"Output key: {key} \n\t This is a boolean flag"