# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from cosmos_predict1.utils import log
from cosmos_predict1.utils.lazy_config import instantiate


class BaseConditionEntry(nn.Module):
    def __init__(self):
        super().__init__()

        self._dropout_rate = None
        self._input_key = None
        self._return_dict = False

    @property
    def dropout_rate(self) -> Union[float, torch.Tensor]:
        return self._dropout_rate

    @property
    def input_key(self) -> str:
        return self._input_key

    @property
    def is_return_dict(self) -> bool:
        return self._return_dict

    @dropout_rate.setter
    def dropout_rate(self, value: Union[float, torch.Tensor]):
        self._dropout_rate = value

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @is_return_dict.setter
    def is_return_dict(self, value: bool):
        self._return_dict = value

    @dropout_rate.deleter
    def dropout_rate(self):
        del self._dropout_rate

    @input_key.deleter
    def input_key(self):
        del self._input_key

    @is_return_dict.deleter
    def is_return_dict(self):
        del self._return_dict

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        del key
        dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate
        bernoulli = torch.bernoulli((1.0 - dropout_rate) * torch.ones(len(in_tensor))).type_as(in_tensor)
        bernoulli_expand = bernoulli.view((-1,) + (1,) * (in_tensor.dim() - 1))
        return bernoulli_expand * in_tensor

    def summary(self) -> str:
        pass
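
    # Illustrative sketch of the per-sample dropout above (names and values are
    # hypothetical): one Bernoulli draw per batch element is broadcast over all
    # remaining dimensions, so each sample is either kept intact with probability
    # 1 - dropout_rate or zeroed entirely:
    #
    #     entry.dropout_rate = 0.2
    #     x = torch.randn(4, 77, 1024)        # (B, L, D)
    #     y = entry.random_dropout_input(x)   # each y[i] is either x[i] or zeros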


class DataType(Enum):
    IMAGE = "image"
    VIDEO = "video"


class TextAttr(BaseConditionEntry):
    def __init__(self):
        super().__init__()

    def forward(self, token: torch.Tensor, mask: torch.Tensor):
        return {"crossattn_emb": token, "crossattn_mask": mask}

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        if key is not None and "mask" in key:
            return in_tensor
        return super().random_dropout_input(in_tensor, dropout_rate, key)
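
    # Illustrative sketch (key names follow the batch keys used elsewhere in this
    # module): the mask input bypasses dropout because its key contains "mask",
    # while the token embeddings are dropped as usual:
    #
    #     attr = TextAttr()
    #     attr.dropout_rate = 1.0
    #     attr.random_dropout_input(mask, key="t5_text_mask")           # returned as-is
    #     attr.random_dropout_input(tokens, key="t5_text_embeddings")   # zeroed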


@dataclass
class BaseVideoCondition:
    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    data_type: DataType = DataType.VIDEO
    padding_mask: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    image_size: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None
    frame_repeat: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        return {f.name: getattr(self, f.name) for f in fields(self)}


@dataclass
class VideoExtendCondition(BaseVideoCondition):
    video_cond_bool: Optional[torch.Tensor] = None  # whether the model is conditioned on video
    gt_latent: Optional[torch.Tensor] = None
    condition_video_indicator: Optional[torch.Tensor] = None  # 1 for the condition region
    # condition_video_input_mask will be concatenated to the network input along the channel dim
    condition_video_input_mask: Optional[torch.Tensor] = None
    # condition_video_augment_sigma: (B, T) tensor of sigma values for augmenting the conditional input,
    # only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed"
    condition_video_augment_sigma: Optional[torch.Tensor] = None
    condition_video_pose: Optional[torch.Tensor] = None


class GeneralConditioner(nn.Module, ABC):
    """
    An abstract module designed to handle various embedding models with conditional and
    unconditional configurations. This abstract base class initializes and manages a collection
    of embedders that can dynamically adjust their dropout rates based on conditioning.

    Attributes:
        KEY2DIM (dict): A mapping from output keys to the dimensions used for concatenation.
        embedders (nn.ModuleDict): A dictionary containing all embedder models, initialized and
            configured based on the provided configurations.

    Parameters:
        **emb_models (Union[List, Any]): Keyword arguments where keys are embedder names and
            values are configurations for initializing the embedders.
    """

    KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1}

    def __init__(self, **emb_models: Union[List, Any]):
        super().__init__()
        self.embedders = nn.ModuleDict()
        for n, (emb_name, embconfig) in enumerate(emb_models.items()):
            embedder = instantiate(embconfig.obj)
            assert isinstance(
                embedder, BaseConditionEntry
            ), f"embedder model {embedder.__class__.__name__} has to inherit from BaseConditionEntry"
            embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0)

            if hasattr(embconfig, "input_key"):
                embedder.input_key = embconfig.input_key
            elif hasattr(embconfig, "input_keys"):
                embedder.input_keys = embconfig.input_keys
            else:
                raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")

            log.debug(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}")
            self.embedders[emb_name] = embedder
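
    # Illustrative config sketch (the SimpleNamespace and field values are
    # hypothetical; the real project builds these with its lazy-config system):
    # each config object must expose .obj for instantiate(), plus .input_key or
    # .input_keys and an optional .dropout_rate:
    #
    #     text_cfg = SimpleNamespace(
    #         obj=...,  # lazy config that instantiate() resolves to TextAttr()
    #         dropout_rate=0.2,
    #         input_keys=["t5_text_embeddings", "t5_text_mask"],
    #     )
    #     conditioner = VideoConditioner(text=text_cfg)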

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> Any:
        """Should be implemented in subclasses to handle the condition data type."""
        raise NotImplementedError

    def _forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> Dict:
        """
        Processes the input batch through all configured embedders, applying conditional dropout rates if specified.
        Output tensors for each key are concatenated along the dimensions specified in KEY2DIM.

        Parameters:
            batch (Dict): The input data batch to process.
            override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary to override default dropout rates
                per embedder key.

        Returns:
            Dict: A dictionary of output tensors concatenated along the specified dimensions.

        Note:
            In case the network code is sensitive to the order of concatenation, you can either control the order
            via the config file or make sure the embedders return a unique key for each output.
        """
        output = defaultdict(list)
        if override_dropout_rate is None:
            override_dropout_rate = {}

        # make sure every emb_name in override_dropout_rate is valid
        for emb_name in override_dropout_rate.keys():
            assert emb_name in self.embedders, f"invalid name found {emb_name}"

        for emb_name, embedder in self.embedders.items():
            with torch.no_grad():
                if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                    emb_out = embedder(
                        embedder.random_dropout_input(
                            batch[embedder.input_key], override_dropout_rate.get(emb_name, None)
                        )
                    )
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(
                        *[
                            embedder.random_dropout_input(batch[k], override_dropout_rate.get(emb_name, None), k)
                            for k in embedder.input_keys
                        ]
                    )
            for k, v in emb_out.items():
                output[k].append(v)

        # concatenate the per-embedder outputs along the dimension registered in KEY2DIM
        return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()}
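
    # Illustrative sketch of the concatenation rule (shapes are hypothetical):
    # if two embedders both emit "crossattn_emb", the pieces are joined along
    # dim 1 per KEY2DIM; any key not listed in KEY2DIM falls back to the last
    # dimension:
    #
    #     embedder A -> {"crossattn_emb": (B, 77, D)}
    #     embedder B -> {"crossattn_emb": (B, 10, D)}
    #     merged     -> {"crossattn_emb": (B, 87, D)}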

    def get_condition_uncondition(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Processes the provided data batch to generate conditioned and unconditioned outputs.

        This method manipulates dropout rates to simulate two scenarios:
        1. All conditions applied (conditioned)
        2. Conditions removed/reduced to minimum (unconditioned)

        It sets dropout rates to zero for the conditioned scenario to fully apply the
        embedders' effects. For the unconditioned scenario, it sets rates to 1 (or 0 if the
        initial rate is insignificant) to minimize the embedders' influence.

        Parameters:
            data_batch (Dict): Input data batch containing all necessary information for
                embedding processing.

        Returns:
            Tuple[Any, Any]: A tuple containing:
                - Outputs with all embedders fully applied (conditioned)
                - Outputs with embedders minimized/not applied (unconditioned)
        """
        cond_dropout_rates, dropout_rates = {}, {}
        for emb_name, embedder in self.embedders.items():
            cond_dropout_rates[emb_name] = 0.0
            dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0

        condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
        un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates)
        return condition, un_condition
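
    # Illustrative usage sketch (classifier-free-guidance style; `net`, `x_t`,
    # `sigma`, and the guidance scale `w` are hypothetical):
    #
    #     cond, uncond = conditioner.get_condition_uncondition(data_batch)
    #     eps_cond = net(x_t, sigma, **cond.to_dict())
    #     eps_uncond = net(x_t, sigma, **uncond.to_dict())
    #     eps = eps_uncond + w * (eps_cond - eps_uncond)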

    def get_condition_with_negative_prompt(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Same functionality as get_condition_uncondition, but uses negative prompts
        for the unconditioned branch.
        """
        cond_dropout_rates, uncond_dropout_rates = {}, {}
        for emb_name, embedder in self.embedders.items():
            cond_dropout_rates[emb_name] = 0.0
            if isinstance(embedder, TextAttr):
                uncond_dropout_rates[emb_name] = 0.0
            else:
                uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0

        data_batch_neg_prompt = copy.deepcopy(data_batch)
        if "neg_t5_text_embeddings" in data_batch_neg_prompt:
            if isinstance(data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor):
                data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"]
                data_batch_neg_prompt["t5_text_mask"] = data_batch_neg_prompt["neg_t5_text_mask"]

        condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
        un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates)
        return condition, un_condition
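
    # Illustrative batch sketch: the negative prompt is supplied under dedicated
    # keys that mirror the positive ones and is swapped in for the unconditioned
    # branch above (tensor shapes are hypothetical):
    #
    #     data_batch = {
    #         "t5_text_embeddings": pos_emb,       # (B, L, D)
    #         "t5_text_mask": pos_mask,            # (B, L)
    #         "neg_t5_text_embeddings": neg_emb,   # (B, L, D)
    #         "neg_t5_text_mask": neg_mask,        # (B, L)
    #         ...
    #     }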


@dataclass
class CosmosCondition:
    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    padding_mask: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        return {f.name: getattr(self, f.name) for f in fields(self)}


class VideoConditioner(GeneralConditioner):
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseVideoCondition:
        output = super()._forward(batch, override_dropout_rate)
        return BaseVideoCondition(**output)


class VideoExtendConditioner(GeneralConditioner):
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoExtendCondition:
        output = super()._forward(batch, override_dropout_rate)
        return VideoExtendCondition(**output)
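

if __name__ == "__main__":
    # Minimal smoke test (illustrative; shapes are hypothetical): exercise
    # TextAttr and BaseVideoCondition directly, without the config machinery.
    tokens = torch.randn(2, 77, 1024)  # (B, L, D) text embeddings
    mask = torch.ones(2, 77)  # (B, L) attention mask
    out = TextAttr()(tokens, mask)
    cond = BaseVideoCondition(**out)
    assert cond.to_dict()["crossattn_emb"].shape == (2, 77, 1024)
    assert cond.data_type == DataType.VIDEO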