# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from diffusers import EDMEulerScheduler
from megatron.core import parallel_state
from torch import Tensor

from cosmos_predict1.diffusion.conditioner import BaseVideoCondition
from cosmos_predict1.diffusion.module import parallel
from cosmos_predict1.diffusion.module.blocks import FourierFeatures
from cosmos_predict1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp
from cosmos_predict1.diffusion.module.pretrained_vae import BaseVAE
from cosmos_predict1.diffusion.training.utils.layer_control.peft_control_config_parser import LayerControlConfigParser
from cosmos_predict1.diffusion.training.utils.peft.peft import add_lora_layers, setup_lora_requires_grad
from cosmos_predict1.utils import log, misc
from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate


class DiffusionT2WModel(torch.nn.Module):
    """Text-to-world diffusion model that generates video frames from text descriptions.

    This model implements a diffusion-based approach for generating videos conditioned on text input.
    It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling,
    and classifier-free guidance.
    """
def __init__(self, config):
"""Initialize the diffusion model.
Args:
config: Configuration object containing model parameters and architecture settings
"""
super().__init__()
self.config = config
self.precision = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}[config.precision]
self.tensor_kwargs = {"device": "cuda", "dtype": self.precision}
log.debug(f"DiffusionModel: precision {self.precision}")

        # 1. Set data keys and data information.
self.sigma_data = config.sigma_data
self.state_shape = list(config.latent_shape)
self.setup_data_key()

        # 2. Set up diffusion processing and scaling (pre-conditioning), and the sampler.
self.scheduler = EDMEulerScheduler(sigma_max=80, sigma_min=0.0002, sigma_data=self.sigma_data)
self.tokenizer = None
self.model = None

    @property
    def net(self):
        return self.model.net

    @property
    def conditioner(self):
        return self.model.conditioner

    @property
    def logvar(self):
        return self.model.logvar

def set_up_tokenizer(self, tokenizer_dir: str):
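        """Instantiate the VAE tokenizer from the model config and load its weights from tokenizer_dir."""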
self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer)
self.tokenizer.load_weights(tokenizer_dir)
if hasattr(self.tokenizer, "reset_dtype"):
self.tokenizer.reset_dtype()

    @misc.timer("DiffusionModel: set_up_model")
def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format):
"""Initialize the core model components including network, conditioner and logvar."""
self.model = self.build_model()
if self.config.peft_control and self.config.peft_control.enabled:
log.info("Setting up LoRA layers")
peft_control_config_parser = LayerControlConfigParser(config=self.config.peft_control)
peft_control_config = peft_control_config_parser.parse()
add_lora_layers(self.model, peft_control_config)
num_lora_params = setup_lora_requires_grad(self.model)
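            # Inference-only setup: the whole model (including the injected LoRA layers) is
            # frozen below; num_lora_params is used only to verify that LoRA injection succeeded.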
self.model.requires_grad_(False)
if num_lora_params == 0:
raise ValueError("No LoRA parameters found. Please check the model configuration.")
self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs)

    def build_model(self) -> torch.nn.ModuleDict:
        """Construct the model's neural network components.

        Returns:
            ModuleDict containing the network, conditioner and logvar components
        """
config = self.config
net = lazy_instantiate(config.net)
conditioner = lazy_instantiate(config.conditioner)
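        # Noise-level log-variance head: a Fourier embedding of the noise level followed by
        # a linear projection to a scalar. It is typically used to weight the denoising loss
        # during training and is not referenced by the sampling loop below.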
logvar = torch.nn.Sequential(
FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False)
)
return torch.nn.ModuleDict(
{
"net": net,
"conditioner": conditioner,
"logvar": logvar,
}
)

    @torch.no_grad()
    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """Encode input state into latent representation using VAE.

        Args:
            state: Input tensor to encode

        Returns:
            Encoded latent representation scaled by sigma_data
        """
return self.tokenizer.encode(state) * self.sigma_data

    @torch.no_grad()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode latent representation back to pixel space using VAE.

        Args:
            latent: Latent tensor to decode

        Returns:
            Decoded tensor in pixel space
        """
return self.tokenizer.decode(latent / self.sigma_data)
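
    # Note: encode/decode form an approximate inverse pair, and the sigma_data scaling
    # applied in encode is undone in decode (variable names here are illustrative):
    #     latent = model.encode(video)   # video in pixel space
    #     recon = model.decode(latent)   # recon ≈ video, up to VAE reconstruction error
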
def setup_data_key(self) -> None:
"""Configure input data keys for video and image data."""
        self.input_data_key = self.config.input_data_key  # by default, the video key for the video diffusion model

    def generate_samples_from_batch(
self,
data_batch: dict,
guidance: float = 1.5,
seed: int = 1,
state_shape: tuple | None = None,
n_sample: int | None = 1,
is_negative_prompt: bool = False,
num_steps: int = 35,
) -> Tensor:
"""Generate samples from a data batch using diffusion sampling.
This function generates samples from either image or video data batches using diffusion sampling.
It handles both conditional and unconditional generation with classifier-free guidance.
Args:
data_batch (dict): Raw data batch from the training data loader
guidance (float, optional): Classifier-free guidance weight. Defaults to 1.5.
seed (int, optional): Random seed for reproducibility. Defaults to 1.
state_shape (tuple | None, optional): Shape of the state tensor. Uses self.state_shape if None. Defaults to None.
n_sample (int | None, optional): Number of samples to generate. Defaults to 1.
is_negative_prompt (bool, optional): Whether to use negative prompt for unconditional generation. Defaults to False.
num_steps (int, optional): Number of diffusion sampling steps. Defaults to 35.
Returns:
Tensor: Generated samples after diffusion sampling
"""
        condition, uncondition = self._get_conditions(data_batch, is_negative_prompt)

        self.scheduler.set_timesteps(num_steps)
        if state_shape is None:
            state_shape = self.state_shape
        if n_sample is None:
            n_sample = 1
        # Draw the initial noise at the scheduler's starting sigma from a seeded generator,
        # so that the same seed reproduces the same sample.
        generator = torch.Generator(device=self.tensor_kwargs["device"]).manual_seed(seed)
        xt = (
            torch.randn(size=(n_sample,) + tuple(state_shape), generator=generator, **self.tensor_kwargs)
            * self.scheduler.init_noise_sigma
        )
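        # Under context parallelism, split the latent along seq_dim=2 (the temporal axis
        # of the video latent) so each CP rank denoises only its own chunk of frames.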
to_cp = self.net.is_context_parallel_enabled
if to_cp:
xt = split_inputs_cp(x=xt, seq_dim=2, cp_group=self.net.cp_group)
for t in self.scheduler.timesteps:
xt = xt.to(**self.tensor_kwargs)
xt_scaled = self.scheduler.scale_model_input(xt, timestep=t)
# Predict the noise residual
t = t.to(**self.tensor_kwargs)
net_output_cond = self.net(x=xt_scaled, timesteps=t, **condition.to_dict())
net_output_uncond = self.net(x=xt_scaled, timesteps=t, **uncondition.to_dict())
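            # Classifier-free guidance, extrapolating from the conditional prediction:
            # cond + g * (cond - uncond) == uncond + (1 + g) * (cond - uncond), i.e. the
            # standard formulation with an effective guidance scale of 1 + guidance.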
net_output = net_output_cond + guidance * (net_output_cond - net_output_uncond)
# Compute the previous noisy sample x_t -> x_t-1
xt = self.scheduler.step(net_output, t, xt).prev_sample
samples = xt
if to_cp:
samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group)
return samples

    def _get_conditions(
self,
data_batch: dict,
is_negative_prompt: bool = False,
):
"""Get the conditions for the model.
Args:
data_batch: Input data dictionary
is_negative_prompt: Whether to use negative prompting
Returns:
condition: Input conditions
uncondition: Conditions removed/reduced to minimum (unconditioned)
"""
if is_negative_prompt:
condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
else:
condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
to_cp = self.net.is_context_parallel_enabled
# For inference, check if parallel_state is initialized
if parallel_state.is_initialized():
condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp)
uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp)
return condition, uncondition


def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition:
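    """Broadcast all tensor fields of a condition across tensor- and/or context-parallel
    ranks so that every model-parallel worker samples with identical conditioning."""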
condition_kwargs = {}
for k, v in condition.to_dict().items():
if isinstance(v, torch.Tensor):
            assert not v.requires_grad, f"{k} requires gradient; the current implementation does not support it"
condition_kwargs[k] = parallel.broadcast(v, to_tp=to_tp, to_cp=to_cp)
condition = type(condition)(**condition_kwargs)
return condition