# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, TypeVar

import numpy as np
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, StateDictType
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.nn.modules.module import _IncompatibleKeys

from cosmos_predict1.diffusion.functional.batch_ops import batch_mul
from cosmos_predict1.diffusion.module.blocks import FourierFeatures
from cosmos_predict1.diffusion.module.pretrained_vae import BaseVAE
from cosmos_predict1.diffusion.modules.denoiser_scaling import EDMScaling
from cosmos_predict1.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS, Sampler
from cosmos_predict1.diffusion.training.functional.loss import create_per_sample_loss_mask
from cosmos_predict1.diffusion.training.utils.fsdp_helper import apply_fsdp_checkpointing, hsdp_device_mesh
from cosmos_predict1.diffusion.training.utils.optim_instantiate import get_base_scheduler
from cosmos_predict1.diffusion.types import DenoisePrediction
from cosmos_predict1.utils import distributed, log, misc
from cosmos_predict1.utils.ema import FastEmaModelUpdater
from cosmos_predict1.utils.lazy_config import LazyDict
from cosmos_predict1.utils.lazy_config import instantiate as lazy_instantiate
from cosmos_predict1.utils.model import Model


@dataclass
class CosmosCondition:
    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    padding_mask: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        return {f.name: getattr(self, f.name) for f in fields(self)}


class DiffusionModel(Model):
    def __init__(self, config):
        super().__init__()

        self.config = config

        # how many samples have been processed
        self.sample_counter = 0
        self.precision = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[config.precision]
        self.tensor_kwargs = {"device": "cuda", "dtype": self.precision}
        log.warning(f"DiffusionModel: precision {self.precision}")
        # Timer passed to network to detect slow ranks.

        # 1. set data keys and data information
        self.sigma_data = config.sigma_data
        self.state_shape = list(config.latent_shape)
        self.setup_data_key()

        # 2. set up diffusion processing and scaling (pre-conditioning), sampler
        self.sde = lazy_instantiate(config.sde)
        self.sampler = Sampler()
        self.scaling = EDMScaling(self.sigma_data)

        # 3. vae
        with misc.timer("DiffusionModel: set_up_vae"):
            self.vae: BaseVAE = lazy_instantiate(config.vae)
            assert (
                self.vae.latent_ch == self.state_shape[0]
            ), f"latent_ch {self.vae.latent_ch} != state_shape {self.state_shape[0]}"

        # 4. Set up loss options, including loss masking, loss reduction and loss scaling
        self.loss_masking: Optional[Dict] = config.loss_masking
        self.loss_reduce = getattr(config, "loss_reduce", "mean")
        assert self.loss_reduce in ["mean", "sum"]
        self.loss_scale = getattr(config, "loss_scale", 1.0)
        log.critical(f"Using {self.loss_reduce} loss reduce with loss scale {self.loss_scale}")
        log.critical(f"Enable loss masking: {config.loss_mask_enabled}")

        # 5. diffusion neural networks part
        self.set_up_model()

    def setup_data_key(self) -> None:
        self.input_data_key = self.config.input_data_key

    def build_model(self) -> torch.nn.ModuleDict:
        config = self.config
        net = lazy_instantiate(config.net)
        conditioner = lazy_instantiate(config.conditioner)
        logvar = torch.nn.Sequential(
            FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False)
        )
        return torch.nn.ModuleDict(
            {
                "net": net,
                "conditioner": conditioner,
                "logvar": logvar,
            }
        )

    def set_up_model(self):
        config = self.config
        self.model = self.build_model()
        if config.ema.enabled:
            with misc.timer("DiffusionModel: instantiate ema"):
                config.ema.model = self.model
                self.model_ema = lazy_instantiate(config.ema)
                config.ema.model = None
        else:
            self.model_ema = None

    @property
    def net(self):
        return self.model.net

    @property
    def conditioner(self):
        return self.model.conditioner

    def on_before_zero_grad(
        self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
    ) -> None:
        """
        Update the EMA model.
        """
        if self.config.ema.enabled:
            self.model_ema.update_average(self.model, iteration)

    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        if self.config.ema.enabled:
            self.model_ema.to(dtype=torch.float32)
        if hasattr(self.vae, "reset_dtype"):
            self.vae.reset_dtype()
        self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs)

        if hasattr(self.config, "use_torch_compile") and self.config.use_torch_compile:  # compatible with old config
            if torch.__version__ < "2.3":
                log.warning(
                    "torch.compile in PyTorch versions older than 2.3 doesn't work well with activation checkpointing.\n"
                    "It's very likely there will be no significant speedup from torch.compile.\n"
                    "Please use at least the 24.04 PyTorch container."
                )
            # Increase the cache size. It's required because of the model size and dynamic input shapes resulting in
            # multiple different triton kernels. For 28 TransformerBlocks, the cache limit of 256 should be enough for
            # up to 9 different input shapes, as 28*9 < 256. If you have more blocks or input shapes, and you observe
            # graph breaks at each block (detectable with torch._dynamo.explain) or warnings about
            # exceeding the cache limit, you may want to increase this size.
            # Starting with the 24.05 PyTorch container, the default value is 256 anyway.
            # You can read more about it in the comments in the PyTorch source code under torch/_dynamo/cache_size.py.
            torch._dynamo.config.accumulated_cache_size_limit = 256
            # dynamic=False means that a separate kernel is created for each shape. It incurs higher compilation costs
            # at initial iterations, but can result in more specialized and efficient kernels.
            # dynamic=True currently throws errors in PyTorch 2.3.
            self.model.net = torch.compile(self.model.net, dynamic=False, disable=not self.config.use_torch_compile)
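
    # Illustrative sketch (editorial addition, not part of the original module): one way to check for the
    # graph breaks mentioned in the comment above, assuming a callable `net_fn` and matching `example_inputs`.
    # Note that the exact `torch._dynamo.explain` calling convention differs across PyTorch versions.
    #
    #     explanation = torch._dynamo.explain(net_fn)(*example_inputs)
    #     print(explanation.graph_break_count, explanation.break_reasons)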

    def compute_loss_with_epsilon_and_sigma(
        self,
        data_batch: dict[str, torch.Tensor],
        x0_from_data_batch: torch.Tensor,
        x0: torch.Tensor,
        condition: CosmosCondition,
        epsilon: torch.Tensor,
        sigma: torch.Tensor,
    ):
        """
        Compute the loss given epsilon and sigma.

        This method is responsible for computing the loss given epsilon and sigma. It involves:
        1. Adding noise to the input data using the SDE process.
        2. Passing the noisy data through the network to generate predictions.
        3. Computing the loss based on the difference between the predictions and the original data, \
            considering any configured loss weighting.

        Args:
            data_batch (dict): raw data batch drawn from the training data loader.
            x0_from_data_batch: raw image/video
            x0: image/video latent
            condition: text condition
            epsilon: noise
            sigma: noise level

        Returns:
            tuple: A tuple containing four elements:
                - dict: additional data used for debugging / logging / callbacks
                - Tensor 1: Kendall loss
                - Tensor 2: MSE loss
                - Tensor 3: EDM loss

        Raises:
            AssertionError: If the class is conditional, \
                but no number of classes is specified in the network configuration.

        Notes:
            - The method handles different types of conditioning
            - The method also supports Kendall's loss
        """
        # Get the mean and standard deviation of the marginal probability distribution.
        mean, std = self.sde.marginal_prob(x0, sigma)
        # Generate noisy observations
        xt = mean + batch_mul(std, epsilon)  # corrupted data
        # make prediction
        model_pred = self.denoise(xt, sigma, condition)
        # loss weights for different noise levels
        weights_per_sigma = self.get_per_sigma_loss_weights(sigma=sigma)
        # extra weight for each sample, for example, aesthetic weight, camera weight
        weights_per_sample = self.get_per_sample_weight(data_batch, x0_from_data_batch.shape[0])
        # extra loss mask for each sample, for example, human faces, hands
        loss_mask_per_sample = self.get_per_sample_loss_mask(data_batch, x0_from_data_batch.shape, x0.shape)
        pred_mse = (x0 - model_pred.x0) ** 2 * loss_mask_per_sample
        edm_loss = batch_mul(pred_mse, weights_per_sigma * weights_per_sample)
        if self.config.loss_add_logvar:
            kendall_loss = batch_mul(edm_loss, torch.exp(-model_pred.logvar).view(-1)).flatten(
                start_dim=1
            ) + model_pred.logvar.view(-1, 1)
        else:
            kendall_loss = edm_loss.flatten(start_dim=1)
        output_batch = {
            "x0": x0,
            "xt": xt,
            "sigma": sigma,
            "weights_per_sigma": weights_per_sigma,
            "weights_per_sample": weights_per_sample,
            "loss_mask_per_sample": loss_mask_per_sample,
            "condition": condition,
            "model_pred": model_pred,
            "mse_loss": pred_mse.mean(),
            "edm_loss": edm_loss.mean(),
        }
        return output_batch, kendall_loss, pred_mse, edm_loss
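
    # Note (editorial sketch, not part of the original module): with `loss_add_logvar` enabled, the objective
    # above follows the heteroscedastic (Kendall-style) uncertainty weighting
    #     kendall_loss = exp(-logvar) * edm_loss + logvar,
    # where `logvar` is predicted from the noise level by `self.model.logvar`. A scalar illustration:
    #
    #     edm, logvar = torch.tensor(2.0), torch.tensor(0.5)
    #     kendall = torch.exp(-logvar) * edm + logvar  # ~1.71; a large logvar down-weights hard samples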

    def training_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """
        Performs a single training step for the diffusion model.

        This method is responsible for executing one iteration of the model's training. It involves:
        1. Adding noise to the input data using the SDE process.
        2. Passing the noisy data through the network to generate predictions.
        3. Computing the loss based on the difference between the predictions and the original data, \
            considering any configured loss weighting.

        Args:
            data_batch (dict): raw data batch drawn from the training data loader.
            iteration (int): Current iteration number.

        Returns:
            tuple: A tuple containing two elements:
                - dict: additional data used for debugging / logging / callbacks
                - Tensor: The computed loss for the training step as a PyTorch Tensor.

        Raises:
            AssertionError: If the class is conditional, \
                but no number of classes is specified in the network configuration.

        Notes:
            - The method handles different types of conditioning
            - The method also supports Kendall's loss
        """
        # Get the input data to noise and denoise (image, video) and the corresponding conditioner.
        x0_from_data_batch, x0, condition = self.get_data_and_condition(data_batch)

        # Sample perturbation noise levels and N(0, 1) noise
        sigma, epsilon = self.draw_training_sigma_and_epsilon(x0.size(), condition)

        output_batch, kendall_loss, pred_mse, edm_loss = self.compute_loss_with_epsilon_and_sigma(
            data_batch, x0_from_data_batch, x0, condition, epsilon, sigma
        )

        if self.loss_reduce == "mean":
            kendall_loss = kendall_loss.mean() * self.loss_scale
        elif self.loss_reduce == "sum":
            kendall_loss = kendall_loss.sum(dim=1).mean() * self.loss_scale
        else:
            raise ValueError(f"Invalid loss_reduce: {self.loss_reduce}")

        return output_batch, kendall_loss

    def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction:
        """
        Performs denoising on the input noise data, noise level, and condition.

        Args:
            xt (torch.Tensor): The input noise data.
            sigma (torch.Tensor): The noise level.
            condition (CosmosCondition): conditional information, generated from self.conditioner

        Returns:
            DenoisePrediction: The denoised prediction; it includes the clean data prediction (x0), \
                the noise prediction (eps_pred) and an optional confidence (logvar).
        """
        if getattr(self.config, "use_dummy_temporal_dim", False):
            # When using a video DiT model for images, we need to use a dummy temporal dimension.
            xt = xt.unsqueeze(2)

        xt = xt.to(**self.tensor_kwargs)
        sigma = sigma.to(**self.tensor_kwargs)
        # get preconditioning coefficients for the network
        c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma)

        # forward pass through the network
        net_output = self.net(
            x=batch_mul(c_in, xt),  # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf
            timesteps=c_noise,  # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf
            **condition.to_dict(),
        )

        logvar = self.model.logvar(c_noise)
        x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output)

        # get noise prediction based on sde
        eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma)

        if getattr(self.config, "use_dummy_temporal_dim", False):
            x0_pred = x0_pred.squeeze(2)
            eps_pred = eps_pred.squeeze(2)

        return DenoisePrediction(x0_pred, eps_pred, logvar)
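
    # Reference sketch (editorial addition, not part of the original module): the preconditioning used above
    # is expected to follow EDM (Karras et al., 2022, Eq. 7). Under that assumption, for a given sigma and
    # sigma_data the coefficients would be roughly:
    #
    #     c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    #     c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    #     c_in = 1.0 / (sigma**2 + sigma_data**2) ** 0.5
    #     c_noise = 0.25 * torch.log(sigma)
    #
    # The exact definitions live in `EDMScaling`; this note is only to make the denoiser algebra easier to follow.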

    def encode(self, state: torch.Tensor) -> torch.Tensor:
        return self.vae.encode(state) * self.sigma_data

    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        return self.vae.decode(latent / self.sigma_data)
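
    # Note (editorial sketch, not part of the original module): latents are scaled by `sigma_data` at encode
    # time and unscaled at decode time, so the diffusion process sees latents whose standard deviation is
    # roughly `sigma_data`. The round trip is the identity up to the VAE's own reconstruction error:
    #
    #     latent = model.encode(frames)        # vae latent * sigma_data
    #     frames_rec = model.decode(latent)    # vae.decode(latent / sigma_data)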

    def draw_training_sigma_and_epsilon(self, x0_size: torch.Size, condition: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        del condition
        batch_size = x0_size[0]
        epsilon = torch.randn(x0_size, **self.tensor_kwargs)
        return self.sde.sample_t(batch_size).to(**self.tensor_kwargs), epsilon

    def get_data_and_condition(
        self, data_batch: dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, CosmosCondition]:
        """
        Process a data batch drawn from the data loader and return the data and condition used for the denoising task.

        Returns:
            raw_state (tensor): the image / video data that is fed to the vae
            latent_state (tensor): noise-free state, the vae latent state
            condition (CosmosCondition): condition information for conditional generation, generated from the conditioner
        """
        raw_state = data_batch[self.input_data_key]
        latent_state = self.encode(raw_state)
        condition = self.conditioner(data_batch)
        return raw_state, latent_state, condition

    def get_per_sample_weight(self, data_batch: dict[str, torch.Tensor], batch_size: int):
        r"""
        Extra weight for each sample, for example, aesthetic weight.

        Args:
            data_batch: raw data batch drawn from the training data loader.
            batch_size: int, the batch size of the input data
        """
        aesthetic_cfg = getattr(self.config, "aesthetic_finetuning", None)
        if (aesthetic_cfg is not None) and getattr(aesthetic_cfg, "enabled", False):
            sample_weight = data_batch["aesthetic_weight"]
        else:
            sample_weight = torch.ones(batch_size, **self.tensor_kwargs)

        camera_cfg = getattr(self.config, "camera_sample_weight", None)
        if (camera_cfg is not None) and getattr(camera_cfg, "enabled", False):
            sample_weight *= 1 + (data_batch["camera_attributes"][:, 1:].sum(dim=1) != 0) * (camera_cfg.weight - 1)
        return sample_weight

    def get_per_sample_loss_mask(self, data_batch, raw_x_shape, latent_x_shape):
        """
        Extra loss mask for each sample, for example, human faces, hands.

        Args:
            data_batch (dict): raw data batch drawn from the training data loader.
            raw_x_shape (tuple): shape of the input data. We need raw_x_shape for the necessary resize operation.
            latent_x_shape (tuple): shape of the latent data
        """
        if self.config.loss_mask_enabled:
            raw_x_shape = [raw_x_shape[0], 1, *raw_x_shape[2:]]
            weights = create_per_sample_loss_mask(
                self.loss_masking, data_batch, raw_x_shape, torch.get_default_dtype(), "cuda"
            )
            return F.interpolate(weights, size=latent_x_shape[2:], mode="bilinear")

        return 1.0

    def get_per_sigma_loss_weights(self, sigma: torch.Tensor):
        """
        Args:
            sigma (tensor): noise level

        Returns:
            loss weights per sigma noise level
        """
        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
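
    # Note (editorial sketch, not part of the original module): this is the EDM loss weighting
    # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2, which equals 1 / c_out(sigma)^2
    # under the preconditioning sketched after `denoise`. A quick numeric check with sigma_data = 0.5:
    #
    #     sigma = torch.tensor([0.5, 2.0])
    #     w = (sigma**2 + 0.5**2) / (sigma * 0.5) ** 2  # -> [8.0, 4.25]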

    def generate_samples(self, batch_size: int, condition: CosmosCondition) -> torch.Tensor:
        """
        Generate samples with the given condition. It is WITHOUT classifier-free guidance.

        Args:
            batch_size (int):
            condition (CosmosCondition): condition information generated from self.conditioner
        """
        x_sigma_max = torch.randn(batch_size, *self.state_shape, **self.tensor_kwargs) * self.sde.sigma_max

        def x0_fn(x, t):
            return self.denoise(x, t, condition).x0  # ODE function

        return self.sampler(x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max)

    def generate_cfg_samples(
        self, batch_size: int, condition: CosmosCondition, uncondition: CosmosCondition, guidance=1.5
    ) -> torch.Tensor:
        """
        Generate samples with classifier-free guidance.

        Args:
            batch_size (int):
            condition (CosmosCondition): condition information generated from self.conditioner
            uncondition (CosmosCondition): uncondition information, possibly generated from self.conditioner
        """
        x_sigma_max = torch.randn(batch_size, *self.state_shape, **self.tensor_kwargs) * self.sde.sigma_max

        def x0_fn(x, t):
            cond_x0 = self.denoise(x, t, condition).x0
            uncond_x0 = self.denoise(x, t, uncondition).x0
            return cond_x0 + guidance * (cond_x0 - uncond_x0)

        return self.sampler(x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max)

    def get_x0_fn_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        is_negative_prompt: bool = False,
    ) -> Callable:
        """
        Generates a callable function `x0_fn` based on the provided data batch and guidance factor.

        This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain
        conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising
        operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned
        states.

        Args:
        - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should
            align with the expectations of `self.conditioner`.
        - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative
            to the unconditioned state in the output. Defaults to 1.5.
        - is_negative_prompt (bool): use negative prompt t5 in uncondition if true

        Returns:
        - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and returns
            an x0 prediction.

        The returned function is suitable for use in scenarios where a denoised state is required based on both
        conditioned and unconditioned inputs, with an adjustable level of guidance influence.
        """
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
            cond_x0 = self.denoise(noise_x, sigma, condition).x0
            uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0
            return cond_x0 + guidance * (cond_x0 - uncond_x0)

        return x0_fn
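
    # Note (editorial sketch, not part of the original module): the guided prediction above is the usual
    # classifier-free-guidance combination expressed in x0 space,
    #     x0_guided = x0_uncond + (1 + guidance) * (x0_cond - x0_uncond),
    # rewritten as cond_x0 + guidance * (cond_x0 - uncond_x0). guidance = 0 recovers the purely conditional
    # prediction, and larger values push samples further toward the condition.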

    def generate_samples_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: Optional[Tuple] = None,
        n_sample: Optional[int] = None,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
        solver_option: COMMON_SOLVER_OPTIONS = "2ab",
    ) -> torch.Tensor:
        """
        Args:
            data_batch (dict): raw data batch drawn from the training data loader.
            guidance (float): guidance weight
            seed (int): random seed
            state_shape (tuple): shape of the state, defaults to self.state_shape if not provided
            n_sample (int): number of samples to generate
            is_negative_prompt (bool): use negative prompt t5 in uncondition if true
            num_steps (int): number of steps for the diffusion process
            solver_option (str): differential equation solver option, defaults to "2ab" (multistep solver)
        """
        x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt)
        batch_size = n_sample or data_batch[self.input_data_key].shape[0]
        state_shape = state_shape or self.state_shape
        x_sigma_max = (
            misc.arch_invariant_rand(
                (batch_size,) + tuple(state_shape),
                torch.float32,
                self.tensor_kwargs["device"],
                seed,
            )
            * self.sde.sigma_max
        )
        return self.sampler(
            x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max, num_steps=num_steps, solver_option=solver_option
        )
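
    # Usage sketch (editorial addition, not part of the original module), assuming `model` is an initialized
    # DiffusionModel and `data_batch` contains the conditioning inputs expected by its conditioner:
    #
    #     with torch.no_grad():
    #         latents = model.generate_samples_from_batch(data_batch, guidance=7.0, seed=0, num_steps=35)
    #         frames = model.decode(latents)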

    def validation_step(
        self, data: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """
        Current code does nothing.
        """
        return {}, torch.tensor(0).to(**self.tensor_kwargs)

    def forward(self, xt, t, condition: CosmosCondition):
        """
        Performs denoising on the input noise data, noise level, and condition.

        Args:
            xt (torch.Tensor): The input noise data.
            t (torch.Tensor): The noise level.
            condition (CosmosCondition): conditional information, generated from self.conditioner

        Returns:
            DenoisePrediction: The denoised prediction; it includes the clean data prediction (x0), \
                the noise prediction (eps_pred) and an optional confidence (logvar).
        """
        return self.denoise(xt, t, condition)

    def init_optimizer_scheduler(
        self, optimizer_config: LazyDict, scheduler_config: LazyDict
    ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
        """Creates the optimizer and scheduler for the model.

        Args:
            optimizer_config (LazyDict): The config object for the optimizer.
            scheduler_config (LazyDict): The config object for the scheduler.

        Returns:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
        """
        optimizer = lazy_instantiate(optimizer_config, model=self.model)
        scheduler = get_base_scheduler(optimizer, self, scheduler_config)
        return optimizer, scheduler

    def state_dict(self) -> Dict[str, Any]:
        """
        Returns the current state of the model as a dictionary.

        Returns:
            Dict: The current state of the model as a dictionary.
        """
        return {
            "model": self.model.state_dict(),
            "ema": self.model_ema.state_dict() if self.config.ema.enabled else None,
        }

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
        """
        Loads a state dictionary into the model and optionally its EMA counterpart.

        Unlike torch's strict=False mode, the non-strict path does not raise an error for mismatched state shapes;
        it only emits a warning.

        Parameters:
            state_dict (Mapping[str, Any]): A dictionary containing separate state dictionaries for the model and
                                            potentially for an EMA version of the model under the keys 'model' and
                                            'ema', respectively.
            strict (bool, optional): If True, the method will enforce that the keys in the state dict match exactly
                                     those in the model and EMA model (if applicable). Defaults to True.
            assign (bool, optional): If True and in strict mode, will assign the state dictionary directly rather than
                                     matching keys one-by-one. This is typically used when loading parts of state dicts
                                     or using customized loading procedures. Defaults to False.
        """
        if strict:
            # the converted tpsp checkpoint has "ema" and it is None
            if self.config.ema.enabled and state_dict["ema"] is not None:
                ema_results: _IncompatibleKeys = self.model_ema.load_state_dict(
                    state_dict["ema"], strict=strict, assign=assign
                )
            reg_results: _IncompatibleKeys = self.model.load_state_dict(
                state_dict["model"], strict=strict, assign=assign
            )
            if self.config.ema.enabled and state_dict["ema"] is not None:
                return _IncompatibleKeys(
                    ema_results.missing_keys + reg_results.missing_keys,
                    ema_results.unexpected_keys + reg_results.unexpected_keys,
                )
            return reg_results
        else:
            from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model

            log.critical("load model in non-strict mode")
            if "model" in state_dict:
                log.critical(non_strict_load_model(self.model, state_dict["model"]), rank0_only=False)
            else:
                log.critical(non_strict_load_model(self.model, state_dict), rank0_only=False)
            if self.config.ema.enabled and "ema" in state_dict and state_dict["ema"] is not None:
                log.critical("load ema model in non-strict mode")
                log.critical(non_strict_load_model(self.model_ema, state_dict["ema"]), rank0_only=False)

    def get_ckpt_postfix(self) -> Tuple[str, int, int]:
        """Get the checkpoint file postfix.

        Returns:
            postfix (str): The postfix of the checkpoint file.
            rank_to_save_ema (int): we do not save an EMA model on every rank; \
                EMA models with the same rate are saved only once
            total_ema_num (int)
        """
        total_ema_num = min(self.config.ema.num, distributed.get_world_size())
        rank = distributed.get_rank()
        if rank == 0:
            return "", 0, total_ema_num
        if self.config.ema.enabled:
            if rank < self.config.ema.num:
                return f"_RANK{rank}", rank, total_ema_num
        return "", 0, total_ema_num  # use rank 0 to save the checkpoint

    @contextmanager
    def ema_scope(self, context=None, is_cpu=False):
        if self.config.ema.enabled:
            self.model_ema.cache(self.model.parameters(), is_cpu=is_cpu)
            self.model_ema.copy_to(self.model)
            if context is not None:
                log.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.config.ema.enabled:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    log.info(f"{context}: Restored training weights")


T = TypeVar("T", bound=DiffusionModel)


def diffusion_fsdp_class_decorator(base_class: Type[T]) -> Type[T]:
    """
    Decorator for the FSDP class for the diffusion model, which handles the FSDP-specific logic for the diffusion model.
    """

    class FSDPClass(base_class):
        """
        Handle FSDP-specific logic for the diffusion model, including:
        - FSDP model initialization
        - FSDP model / optimizer saving and loading
        - Different from the original DiffusionModel, the impl of multi-rank EMA is a bit hacky. \
            We need to make sure the sharded model weights for the EMA and regular model are the same.
        """

        def __init__(self, config, fsdp_checkpointer: Any):
            self.fsdp_checkpointer = fsdp_checkpointer
            super().__init__(config)

        def set_up_model(self):
            config = self.config

            # 1. build FSDP sharding strategy and device_mesh
            strategy = {
                "full": ShardingStrategy.FULL_SHARD,
                "hybrid": ShardingStrategy.HYBRID_SHARD,
            }[config.fsdp.sharding_strategy]
            log.critical(f"Using {strategy} sharding strategy for FSDP")

            if config.fsdp.sharding_strategy == "hybrid":
                sharding_group_size = getattr(config.fsdp, "sharding_group_size", 8)
                device_mesh = hsdp_device_mesh(
                    sharding_group_size=sharding_group_size,
                )
                shard_group = device_mesh.get_group(mesh_dim="shard")
                replicate_group = device_mesh.get_group(mesh_dim="replicate")
                fsdp_process_group = (shard_group, replicate_group)
            else:
                device_mesh = hsdp_device_mesh(
                    sharding_group_size=distributed.get_world_size(),
                )
                shard_group = device_mesh.get_group(mesh_dim="shard")
                fsdp_process_group = shard_group

            # We piggyback the `device_mesh` onto megatron-core's `parallel_state` for global access.
            # This is not megatron-core's original API.
            parallel_state.fsdp_device_mesh = device_mesh

            def get_wrap_policy(_model):
                if not hasattr(_model.net, "fsdp_wrap_block_cls"):
                    raise ValueError(
                        "Network does not have an fsdp_wrap_block_cls attribute, please check the net definition"
                    )
                fsdp_blocks_cls = _model.net.fsdp_wrap_block_cls
                fsdp_blocks_cls = (
                    list(fsdp_blocks_cls) if isinstance(fsdp_blocks_cls, (list, tuple, set)) else [fsdp_blocks_cls]
                )
                log.critical(f"Using FSDP blocks {fsdp_blocks_cls}")

                log.critical(f"Using wrap policy {config.fsdp.policy}")
                if config.fsdp.policy == "size":
                    min_num_params = getattr(config.fsdp, "min_num_params", 100)
                    log.critical(f"Using {min_num_params} as the minimum number of parameters for the auto-wrap policy")
                    wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
                else:
                    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

                    wrap_policy = functools.partial(
                        transformer_auto_wrap_policy,
                        transformer_layer_cls=set(fsdp_blocks_cls),
                    )
                return wrap_policy
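
            # Note (editorial sketch, not part of the original module): `transformer_auto_wrap_policy` wraps every
            # submodule whose class is in `transformer_layer_cls` into its own FSDP unit, so parameters are gathered
            # one transformer block at a time instead of all at once, e.g. (assuming a hypothetical `MyDiTBlock`):
            #
            #     wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={MyDiTBlock})
            #     sharded = FSDP(module, auto_wrap_policy=wrap_policy)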

            # 2. build naive pytorch model and load weights if they exist
            replica_idx, shard_idx = device_mesh.get_coordinate()
            # 2.1 handle the ema case first, since float32 is more expensive
            if config.ema.enabled:
                with misc.timer("Creating PyTorch model and loading weights for ema"):
                    model_ema = self.build_model().float()
                    model_ema.cuda().eval().requires_grad_(False)
                    if distributed.get_rank() == 0:
                        # only load the model in rank0 to reduce network traffic
                        self.fsdp_checkpointer.load_model_during_init(model_ema, is_ema=True)
                    # sync ema model weights from rank0
                    with misc.timer("Sync model states for EMA model"):
                        #! this is IMPORTANT, see the following comment about the regular model for details
                        #! we broadcast the ema model first, since it is fp32 and costs more memory
                        distributed.sync_model_states(model_ema, device_mesh.get_group(mesh_dim="shard"))
                        torch.cuda.empty_cache()
                        distributed.sync_model_states(model_ema, device_mesh.get_group(mesh_dim="replicate"))
                        torch.cuda.empty_cache()
                    # for ema models with different rates, we download the model when necessary
                    if shard_idx == 0 and replica_idx > 0 and replica_idx < config.ema.num:
                        print("loading ema model in rank", replica_idx)
                        self.fsdp_checkpointer.load_model_during_init(
                            model_ema,
                            is_ema=True,
                            ema_id=replica_idx,
                        )
                        print("finish loading ema model in rank", replica_idx)

                # 2.1.2 create FSDP model for the ema model
                with misc.timer("Creating FSDP model for EMA model"):
                    self.model_ema = FSDP(
                        model_ema,
                        sync_module_states=True,  # it can reduce network traffic by only loading the model in rank0 and syncing
                        process_group=device_mesh.get_group(mesh_dim=1),
                        sharding_strategy=ShardingStrategy.FULL_SHARD,
                        auto_wrap_policy=get_wrap_policy(model_ema),
                        device_id=torch.cuda.current_device(),
                        limit_all_gathers=True,
                    )

                # extra ema model update logic on top of the model
                self.model_ema_worker = FastEmaModelUpdater()
                s = 0.1
                replica_idx, shard_idx = device_mesh.get_coordinate()
                divider = 2**replica_idx if replica_idx < config.ema.num else 1
                if replica_idx < config.ema.num:
                    if shard_idx == 0:
                        print(f"EMA: rank {replica_idx}, rate {config.ema.rate / divider}")
                    s = config.ema.rate / divider
                self.ema_exp_coefficient = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max()

                torch.cuda.empty_cache()

            # 2.2 handle the regular model
            with misc.timer("Creating PyTorch model and loading weights for regular model"):
                model = self.build_model().cuda().to(**self.tensor_kwargs)
                if distributed.get_rank() == 0:
                    # only load the model in rank0 to reduce network traffic and sync later
                    self.fsdp_checkpointer.load_model_during_init(model, is_ema=False)

                #! overwrite the forward method so that it will invoke the FSDP-specific pre- and post-forward sharding logic
                model.forward = super().training_step
                #! this is IMPORTANT, though the following two lines are identical to sync_module_states=True in FSDP
                #! we do it twice so that the following line can warm up and avoid OOM in aws 128+ node settings
                #! qsh hypothesizes that it is due to the overhead of initializing nccl network communication;
                #! without it, peak mem: reg_model + ema_model + FSDP overhead + nccl communication initialization overhead
                #! with it, peak mem: reg_model + ema_model + FSDP overhead
                #! it is tricky, but it works!
                with misc.timer("Sync model states for regular model"):
                    distributed.sync_model_states(model, device_mesh.get_group(mesh_dim="shard"))
                    torch.cuda.empty_cache()
                    distributed.sync_model_states(model, device_mesh.get_group(mesh_dim="replicate"))
                    torch.cuda.empty_cache()

            with misc.timer("Creating FSDP model"):
                self.model = FSDP(
                    model.to(**self.tensor_kwargs),
                    sync_module_states=True,  # it can reduce network traffic by only loading the model in rank0 and syncing
                    sharding_strategy=strategy,
                    auto_wrap_policy=get_wrap_policy(model),
                    process_group=fsdp_process_group,
                    limit_all_gathers=True,
                )

                if self.config.fsdp.checkpoint:
                    fsdp_blocks_cls = model.net.fsdp_wrap_block_cls
                    fsdp_blocks_cls = (
                        list(fsdp_blocks_cls)
                        if isinstance(fsdp_blocks_cls, (list, tuple, set))
                        else [fsdp_blocks_cls]
                    )
                    log.critical(f"Applying FSDP checkpointing with FSDP blocks: {fsdp_blocks_cls}")
                    apply_fsdp_checkpointing(self.model, list_block_cls=fsdp_blocks_cls)

            torch.cuda.empty_cache()

        def on_before_zero_grad(
            self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
        ) -> None:
            del scheduler, optimizer

            if self.config.ema.enabled:
                # calculate beta for EMA update
                if iteration == 0:
                    beta = 0.0
                else:
                    i = iteration + 1
                    beta = (1 - 1 / i) ** (self.ema_exp_coefficient + 1)
                self.model_ema_worker.update_average(self.model, self.model_ema, beta=beta)
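
        # Note (editorial sketch, not part of the original module): this matches the power-function EMA schedule,
        # beta_i = (1 - 1/i) ** (gamma + 1), where gamma (`ema_exp_coefficient`) is derived in `set_up_model`
        # from the relative EMA width s = config.ema.rate / 2**replica_idx via the polynomial root solve.
        # A small numeric illustration:
        #
        #     gamma = 6.94  # roughly what s = 0.1 yields
        #     betas = [(1 - 1 / i) ** (gamma + 1) for i in (2, 100, 10000)]  # ~[0.004, 0.923, 0.9992]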

        def training_step(
            self, data_batch: Dict[str, torch.Tensor], iteration: int
        ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
            # ! Important!!!
            # ! make sure the training step is the same as the forward method (training_step in the super class)
            # ! this is necessary to trigger the FSDP-specific pre- and post-forward sharding logic
            return self.model(data_batch, iteration)

        def state_dict(self) -> Dict:
            raise NotImplementedError(
                "FSDPDiffModel does not support state_dict, use state_dict_model and FSDPCheckpointer"
            )

        def state_dict_model(self) -> Dict:
            with FSDP.summon_full_params(self.model):
                pass
            with FSDP.state_dict_type(
                self.model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            ):
                model_state = self.model.state_dict()
            if self.config.ema.enabled:
                with FSDP.summon_full_params(self.model_ema):
                    pass
                with FSDP.state_dict_type(
                    self.model_ema,
                    StateDictType.FULL_STATE_DICT,
                    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
                ):
                    ema_model_state = self.model_ema.state_dict()
            else:
                ema_model_state = None
            return {
                "model": model_state,
                "ema": ema_model_state,
            }

        def load_state_dict(self, state_dict: Dict, strict: bool = True, assign: bool = False) -> None:
            raise NotImplementedError("FSDPDiffModel does not support load_state_dict, use FSDPCheckpointer")

        def init_optimizer_scheduler(
            self, optimizer_config: LazyDict, scheduler_config: LazyDict
        ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
            optimizer, scheduler = super().init_optimizer_scheduler(optimizer_config, scheduler_config)
            self.fsdp_checkpointer.load_optim_scheduler_during_init(
                self.model,
                optimizer,
                scheduler,
            )
            return optimizer, scheduler

        @contextmanager
        def ema_scope(self, context=None, is_cpu=False):
            if self.config.ema.enabled:
                self.model_ema_worker.cache(self.model.parameters(), is_cpu=is_cpu)
                self.model_ema_worker.copy_to(src_model=self.model_ema, tgt_model=self.model)
                if context is not None:
                    log.info(f"{context}: Switched to EMA weights")
            try:
                yield None
            finally:
                if self.config.ema.enabled:
                    self.model_ema_worker.restore(self.model.parameters())
                    if context is not None:
                        log.info(f"{context}: Restored training weights")

        def get_ckpt_postfix(self) -> Tuple[str, int, int, int]:
            """Get the checkpoint file postfix. See FSDPCheckpointer for more details.

            Returns:
                postfix (str): The postfix of the checkpoint file.
                replicate_idx, shard_idx (int): the current GPU's replicate_idx and shard_idx in FSDP; \
                    we do not save an EMA model on every GPU, \
                    EMA models with the same rate are saved only once
                total_ema_num (int)
            """
            mesh_shape = parallel_state.fsdp_device_mesh.shape
            total_ema_num = min(self.config.ema.num, mesh_shape[0])
            replicate_idx, shard_idx = parallel_state.fsdp_device_mesh.get_coordinate()
            if replicate_idx == 0:
                return "", 0, shard_idx, total_ema_num
            if self.config.ema.enabled:
                if replicate_idx < self.config.ema.num:
                    return f"_RANK{replicate_idx}", replicate_idx, shard_idx, total_ema_num
            return "", replicate_idx, shard_idx, total_ema_num

    return FSDPClass


@diffusion_fsdp_class_decorator
class FSDPDiffusionModel(DiffusionModel):
    pass