from enum import Enum
from typing import Tuple

import torch
from einops import rearrange
from megatron.core import parallel_state
from torch import Tensor
from torch.distributed import broadcast_object_list, get_process_group_ranks
from torch.distributed.utils import _verify_param_shape_across_processes
|
|
|
from cosmos_transfer1.diffusion.conditioner import BaseVideoCondition, CosmosCondition |
|
from cosmos_transfer1.diffusion.diffusion.functional.batch_ops import batch_mul |
|
from cosmos_transfer1.diffusion.diffusion.modules.denoiser_scaling import EDMScaling |
|
from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import Sampler |
|
from cosmos_transfer1.diffusion.diffusion.types import DenoisePrediction |
|
from cosmos_transfer1.diffusion.module import parallel |
|
from cosmos_transfer1.diffusion.module.blocks import FourierFeatures |
|
from cosmos_transfer1.diffusion.module.pretrained_vae import BaseVAE |
|
from cosmos_transfer1.diffusion.networks.general_dit import GeneralDIT |
|
from cosmos_transfer1.utils import distributed, log, misc
|
from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate |
|
|
|
|
|
IS_PREPROCESSED_KEY = "is_preprocessed"
|
|
|
|
|
class DataType(Enum):
    """Modality of a data batch: image, video, or a mix of the two."""
|
IMAGE = "image" |
|
VIDEO = "video" |
|
MIX = "mix" |
|
|
|
|
|
class EDMSDE:
    """Lightweight container for the EDM SDE noise-level bounds (sigma_min, sigma_max)."""
|
def __init__( |
|
self, |
|
sigma_max: float, |
|
sigma_min: float, |
|
): |
|
self.sigma_max = sigma_max |
|
self.sigma_min = sigma_min |
|
|
|
|
|
class DiffusionT2WModel(torch.nn.Module): |
|
"""Text-to-world diffusion model that generates video frames from text descriptions. |
|
|
|
This model implements a diffusion-based approach for generating videos conditioned on text input. |
|
It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling, |
|
and classifier-free guidance. |
|
""" |
|
|
|
def __init__(self, config): |
|
"""Initialize the diffusion model. |
|
|
|
Args: |
|
config: Configuration object containing model parameters and architecture settings |
|
""" |
|
super().__init__() |
|
|
|
self.config = config |
|
|
|
self.precision = { |
|
"float32": torch.float32, |
|
"float16": torch.float16, |
|
"bfloat16": torch.bfloat16, |
|
}[config.precision] |
|
self.tensor_kwargs = {"device": "cuda", "dtype": self.precision} |
|
log.debug(f"DiffusionModel: precision {self.precision}") |
|
|
|
|
|
self.sigma_data = config.sigma_data |
|
self.state_shape = list(config.latent_shape) |
|
self.setup_data_key() |
|
|
|
|
|
self.sde = EDMSDE(sigma_max=80, sigma_min=0.0002) |
|
self.sampler = Sampler() |
|
self.scaling = EDMScaling(self.sigma_data) |
|
self.tokenizer = None |
|
self.model = None |
|
|
|
@property |
|
def net(self): |
|
return self.model.net |
|
|
|
@property |
|
def conditioner(self): |
|
return self.model.conditioner |
|
|
|
@property |
|
def logvar(self): |
|
return self.model.logvar |
|
|
|
def set_up_tokenizer(self, tokenizer_dir: str): |
|
self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer) |
|
self.tokenizer.load_weights(tokenizer_dir) |
|
if hasattr(self.tokenizer, "reset_dtype"): |
|
self.tokenizer.reset_dtype() |
|
|
|
@misc.timer("DiffusionModel: set_up_model") |
|
def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format): |
|
"""Initialize the core model components including network, conditioner and logvar.""" |
|
self.model = self.build_model() |
|
self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) |
|
|
|
def build_model(self) -> torch.nn.ModuleDict: |
|
"""Construct the model's neural network components. |
|
|
|
Returns: |
|
ModuleDict containing the network, conditioner and logvar components |
|
""" |
|
config = self.config |
|
net = lazy_instantiate(config.net) |
|
conditioner = lazy_instantiate(config.conditioner) |
|
logvar = torch.nn.Sequential( |
|
FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False) |
|
) |
|
|
|
return torch.nn.ModuleDict( |
|
{ |
|
"net": net, |
|
"conditioner": conditioner, |
|
"logvar": logvar, |
|
} |
|
) |
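
    # Shape sketch for the `logvar` head built above (illustrative; assumes
    # FourierFeatures maps a (B,) noise-level embedding to (B, num_channels)):
    #     c_noise: (B,) -> FourierFeatures: (B, 128) -> Linear: (B, 1)
    # i.e. a learned per-sample log-variance of the denoising error, returned as
    # the optional confidence in `denoise` below.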
|
|
|
@torch.no_grad() |
|
def encode(self, state: torch.Tensor) -> torch.Tensor: |
|
"""Encode input state into latent representation using VAE. |
|
|
|
Args: |
|
state: Input tensor to encode |
|
|
|
Returns: |
|
Encoded latent representation scaled by sigma_data |
|
""" |
|
return self.tokenizer.encode(state) * self.sigma_data |
|
|
|
@torch.no_grad() |
|
def decode(self, latent: torch.Tensor) -> torch.Tensor: |
|
"""Decode latent representation back to pixel space using VAE. |
|
|
|
Args: |
|
latent: Latent tensor to decode |
|
|
|
Returns: |
|
Decoded tensor in pixel space |
|
""" |
|
return self.tokenizer.decode(latent / self.sigma_data) |
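
    # Hedged usage sketch for the encode/decode pair; `model` and `frames` are
    # illustrative names, not part of this module:
    #
    #     latent = model.encode(frames)      # tokenizer.encode(frames) * sigma_data
    #     frames_hat = model.decode(latent)  # tokenizer.decode(latent / sigma_data)
    #
    # Scaling by sigma_data keeps latents at the variance the EDM preconditioning
    # in `denoise` assumes; `decode` undoes the scaling before the VAE decoder runs.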
|
|
|
def setup_data_key(self) -> None: |
|
"""Configure input data keys for video and image data.""" |
|
self.input_data_key = self.config.input_data_key |
|
self.input_image_key = self.config.input_image_key |
|
|
|
def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction: |
|
""" |
|
        Perform denoising given the noisy input, its noise level, and conditioning information.
|
|
|
Args: |
|
            xt (torch.Tensor): The noisy input data.
|
sigma (torch.Tensor): The noise level. |
|
condition (CosmosCondition): conditional information, generated from self.conditioner |
|
|
|
Returns: |
|
            DenoisePrediction: The denoised prediction, including the clean-data prediction (x0),
                noise prediction (eps_pred), and optional confidence (logvar).
|
""" |
|
|
|
xt = xt.to(**self.tensor_kwargs) |
|
sigma = sigma.to(**self.tensor_kwargs) |
|
|
|
c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) |
|
|
|
|
|
net_output = self.net( |
|
x=batch_mul(c_in, xt), |
|
timesteps=c_noise, |
|
**condition.to_dict(), |
|
) |
|
|
|
logvar = self.model.logvar(c_noise) |
|
x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) |
|
|
|
|
|
eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) |
|
|
|
return DenoisePrediction(x0_pred, eps_pred, logvar) |
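

# Illustrative sanity check, not called anywhere in this module: with EDM
# preconditioning, eps_pred = (xt - x0_pred) / sigma, so the identity
# xt == x0_pred + sigma * eps_pred holds by construction. The `x0` / `eps`
# field names on DenoisePrediction are assumptions about that dataclass.
def _check_denoise_identity(pred: DenoisePrediction, xt: Tensor, sigma: Tensor) -> bool:
    # Recompose the noisy input from the clean-data and noise predictions.
    recon = pred.x0 + batch_mul(sigma, pred.eps)
    return torch.allclose(recon, xt, rtol=1e-3, atol=1e-4)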
|
|
|
|
|
def robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor: |
|
""" |
|
Perform a robust broadcast operation that works regardless of tensor shapes on different ranks. |
|
|
|
Args: |
|
tensor (torch.Tensor): The tensor to broadcast (on src rank) or receive (on other ranks). |
|
        src (int): The source rank for the broadcast.

        pg: The process group used for the broadcast.

        is_check_shape (bool): If True, verify that the shape tensor is consistent across processes before broadcasting.
|
|
|
Returns: |
|
torch.Tensor: The broadcasted tensor on all ranks. |
|
""" |
|
|
|
if distributed.get_rank() == src: |
|
shape = torch.tensor(tensor.shape).cuda() |
|
else: |
|
shape = torch.empty(tensor.dim(), dtype=torch.long).cuda() |
|
if is_check_shape: |
|
_verify_param_shape_across_processes(pg, [shape]) |
|
torch.distributed.broadcast(shape, src, group=pg) |
|
|
|
|
|
if distributed.get_rank() != src: |
|
tensor = tensor.new_empty(shape.tolist()).type_as(tensor) |
|
|
|
|
|
torch.distributed.broadcast(tensor, src, group=pg) |
|
|
|
return tensor |
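

# Minimal usage sketch for `robust_broadcast` (assumes torch.distributed is
# initialized and `pg` contains the participating ranks). Because the shape is
# exchanged through a buffer of `tensor.dim()` elements, the placeholder on
# non-source ranks must have the same number of dimensions as the source tensor:
#
#     if distributed.get_rank() == 0:
#         t = torch.randn(4, 8, device="cuda")
#     else:
#         t = torch.empty(1, 1, device="cuda")  # any 2-D placeholder; overwritten
#     t = robust_broadcast(t, src=0, pg=pg)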
|
|
|
|
|
def _broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None: |
|
""" |
|
Broadcast the item from the minimum rank in the specified group(s). |
|
    Since global rank = tp_rank + cp_rank * tp_size + ..., broadcasting first
    within the tp_group and then within the cp_group ensures the item reaches
    every rank in both groups.
|
|
|
Parameters: |
|
- item: The item to broadcast (can be a torch.Tensor, str, or None). |
|
- to_tp: Whether to broadcast to the tensor model parallel group. |
|
- to_cp: Whether to broadcast to the context parallel group. |
|
""" |
|
if not parallel_state.is_initialized(): |
|
return item |
|
tp_group = parallel_state.get_tensor_model_parallel_group() |
|
cp_group = parallel_state.get_context_parallel_group() |
|
|
|
to_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1 |
|
to_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1 |
|
|
|
if to_tp: |
|
min_tp_rank = min(get_process_group_ranks(tp_group)) |
|
|
|
if to_cp: |
|
min_cp_rank = min(get_process_group_ranks(cp_group)) |
|
|
|
if isinstance(item, torch.Tensor): |
|
if to_tp: |
|
item = robust_broadcast(item, min_tp_rank, tp_group) |
|
if to_cp: |
|
item = robust_broadcast(item, min_cp_rank, cp_group) |
|
elif item is not None: |
|
broadcastable_list = [item] |
|
if to_tp: |
|
broadcast_object_list(broadcastable_list, min_tp_rank, group=tp_group) |
|
if to_cp: |
|
broadcast_object_list(broadcastable_list, min_cp_rank, group=cp_group) |
|
|
|
item = broadcastable_list[0] |
|
return item |
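

# Worked example of the rank layout `_broadcast` relies on (tp_size=2, cp_size=2):
# with global rank = tp_rank + cp_rank * tp_size, the grid is
#     cp_rank=0: global ranks 0, 1  (tp_rank 0, 1)
#     cp_rank=1: global ranks 2, 3  (tp_rank 0, 1)
# Broadcasting first from the minimum rank of each tp_group (0 -> 1, 2 -> 3) and
# then from the minimum rank of each cp_group (0 -> 2, 1 -> 3) leaves every rank
# holding rank 0's item.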
|
|
|
|
|
class DistillT2WModel(DiffusionT2WModel): |
|
"""Base Video Distillation Model.""" |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
    def get_data_and_condition(self, data_batch: dict[str, Tensor]) -> Tuple[Tensor, Tensor, CosmosCondition]:
|
self._normalize_video_databatch_inplace(data_batch) |
|
self._augment_image_dim_inplace(data_batch) |
|
input_key = self.input_data_key |
|
is_image_batch = self.is_image_batch(data_batch) |
|
is_video_batch = not is_image_batch |
|
|
|
|
|
|
|
local_keys = sorted(list(data_batch.keys())) |
|
for key in local_keys: |
|
data_batch[key] = _broadcast(data_batch[key], to_tp=True, to_cp=is_video_batch) |
|
|
|
if is_image_batch: |
|
input_key = self.input_image_key |
|
|
|
|
|
raw_state = data_batch[input_key] |
|
latent_state = self.encode(raw_state).contiguous() |
|
|
|
|
|
condition = self.conditioner(data_batch) |
|
if is_image_batch: |
|
condition.data_type = DataType.IMAGE |
|
else: |
|
condition.data_type = DataType.VIDEO |
|
|
|
|
|
|
|
latent_state = _broadcast(latent_state, to_tp=True, to_cp=is_video_batch) |
|
        condition = broadcast_condition(condition, to_tp=True, to_cp=is_video_batch)

        return raw_state, latent_state, condition
|
|
|
def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool: |
|
"""We hanlde two types of data_batch. One comes from a joint_dataloader where "dataset_name" can be used to differenciate image_batch and video_batch. |
|
Another comes from a dataloader which we by default assumes as video_data for video model training. |
|
""" |
|
is_image = self.input_image_key in data_batch |
|
is_video = self.input_data_key in data_batch |
|
assert ( |
|
is_image != is_video |
|
), "Only one of the input_image_key or input_data_key should be present in the data_batch." |
|
return is_image |
|
|
|
    def _normalize_video_databatch_inplace(self, data_batch: dict[str, Tensor], input_key: str | None = None) -> None:
|
""" |
|
Normalizes video data in-place on a CUDA device to reduce data loading overhead. |
|
|
|
This function modifies the video data tensor within the provided data_batch dictionary |
|
in-place, scaling the uint8 data from the range [0, 255] to the normalized range [-1, 1]. |
|
|
|
Warning: |
|
A warning is issued if the data has not been previously normalized. |
|
|
|
Args: |
|
data_batch (dict[str, Tensor]): A dictionary containing the video data under a specific key. |
|
This tensor is expected to be on a CUDA device and have dtype of torch.uint8. |
|
|
|
Side Effects: |
|
Modifies the 'input_data_key' tensor within the 'data_batch' dictionary in-place. |
|
|
|
Note: |
|
This operation is performed directly on the CUDA device to avoid the overhead associated |
|
with moving data to/from the GPU. Ensure that the tensor is already on the appropriate device |
|
and has the correct dtype (torch.uint8) to avoid unexpected behaviors. |
|
""" |
|
input_key = self.input_data_key if input_key is None else input_key |
|
|
|
if input_key in data_batch: |
|
|
|
if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: |
|
assert torch.is_floating_point(data_batch[input_key]), "Video data is not in float format." |
|
assert torch.all( |
|
(data_batch[input_key] >= -1.0001) & (data_batch[input_key] <= 1.0001) |
|
), f"Video data is not in the range [-1, 1]. get data range [{data_batch[input_key].min()}, {data_batch[input_key].max()}]" |
|
else: |
|
assert data_batch[input_key].dtype == torch.uint8, "Video data is not in uint8 format." |
|
data_batch[input_key] = data_batch[input_key].to(**self.tensor_kwargs) / 127.5 - 1.0 |
|
data_batch[IS_PREPROCESSED_KEY] = True |
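
    # Quick check of the normalization above (uint8 [0, 255] -> float [-1, 1]):
    #     x = torch.tensor([0, 128, 255], dtype=torch.uint8)
    #     x.float() / 127.5 - 1.0  # -> tensor([-1.0000, 0.0039, 1.0000])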
|
|
|
    def _augment_image_dim_inplace(self, data_batch: dict[str, Tensor], input_key: str | None = None) -> None:
|
input_key = self.input_image_key if input_key is None else input_key |
|
if input_key in data_batch: |
|
|
|
if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: |
|
assert ( |
|
data_batch[input_key].shape[2] == 1 |
|
), f"Image data is claimed be augmented while its shape is {data_batch[input_key].shape}" |
|
return |
|
else: |
|
data_batch[input_key] = rearrange(data_batch[input_key], "b c h w -> b c 1 h w").contiguous() |
|
data_batch[IS_PREPROCESSED_KEY] = True |
|
|
|
|
|
def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition: |
|
condition_kwargs = {} |
|
    for k, v in condition.to_dict().items():
        if isinstance(v, torch.Tensor):
            assert not v.requires_grad, f"{k} requires gradient; the current implementation does not support it"
        # Broadcast every field so non-tensor fields are preserved when the condition is rebuilt.
        condition_kwargs[k] = parallel.broadcast(v, to_tp=to_tp, to_cp=to_cp)
|
condition = type(condition)(**condition_kwargs) |
|
return condition |
|
|