# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import signal

import torch
import torch.distributed as dist
import torch.utils.data
from megatron.core import parallel_state

from cosmos_predict1.utils import callback, distributed, ema, log, misc
from cosmos_predict1.utils.checkpointer import Checkpointer
from cosmos_predict1.utils.lazy_config import LazyConfig, instantiate
from cosmos_predict1.utils.model import Model


class Trainer:
    """The base trainer class.

    All trainers should inherit Trainer. It contains the basic functionality for model training
    (particularly suited for large-scale training), including data parallelism (DDP/FSDP), model
    weight averaging (EMA), and mixed-precision training (fp16/bf16).

    Attributes:
        checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states.
        training_timer (misc.Timer): Timer object to time code blocks and functions.
    """
    def __init__(self, config):
        """Constructor of the trainer.

        Args:
            config (Config): The config object for the codebase.
        """
        super().__init__()
        self.config = config
        # Set up the distributed computing environment.
        with misc.timer("init_distributed"):
            distributed.init()
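            # Note: distributed.init() is assumed to set up the default process group (typically
            # the NCCL backend) and bind this process to its local GPU; see
            # cosmos_predict1.utils.distributed for the exact behavior.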
        # Set up parallel states.
        if hasattr(config.model, "context_parallel_size"):
            if config.model_parallel.context_parallel_size > 1:
                raise ValueError(
                    "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. "
                    "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size."
                )
            else:
                log.critical(
                    "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead."
                )
                config.model_parallel.context_parallel_size = config.model.context_parallel_size
        parallel_state.initialize_model_parallel(
            pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
            tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
            context_parallel_size=config.model_parallel.context_parallel_size,
        )
        # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism.
        # It is not part of the original `parallel_state` API, so we need to set it manually.
        parallel_state.sequence_parallel = config.model_parallel.sequence_parallel
        if parallel_state.sequence_parallel:
            os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
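        # Example (hypothetical values) of the `model_parallel` config section consumed above:
        #     model_parallel:
        #         pipeline_model_parallel_size: 1
        #         tensor_model_parallel_size: 1
        #         context_parallel_size: 2
        #         sequence_parallel: False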
        # Create the local job directory, save the config file, and pipe to a local log.
        if distributed.is_rank0():
            os.makedirs(config.job.path_local, exist_ok=True)
            # Save the config as .pkl for reproducibility.
            LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl")
            # Save the config as .yaml for reading or parsing experiment hyperparameters.
            LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
        dist.barrier()
        log.init_loguru_file(f"{config.job.path_local}/stdout.log")
        if distributed.is_rank0():
            # Print important environment variables and the effective config.
            log.info("Config:\n" + config.pretty_print(use_color=True))
            misc.print_environ_variables(["TORCH_HOME", "OUTPUT_ROOT"])
        # Set the random seed. If multi-GPU, different ranks are set with different seeds.
        misc.set_random_seed(seed=config.trainer.seed, by_rank=True)
        # Initialize cuDNN.
        torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
        torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
        # Floating-point precision settings.
        torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
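        # TF32 keeps fp32 tensor layouts but rounds matmul/convolution inputs to a 10-bit mantissa
        # on Ampere and newer GPUs, trading a small amount of precision for significant speedups.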
        # Initialize the callback functions.
        self.callbacks = callback.CallBackGroup(config=config, trainer=self)
        # Initialize the model checkpointer.
        if config.checkpoint.type is None:
            self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        else:
            self.checkpointer: Checkpointer = instantiate(
                config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks
            )
        # Initialize the timer for speed benchmarking.
        self.training_timer = misc.TrainingTimer()
        # Raise a TimeoutError if a training step takes longer than timeout_period seconds.
        signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period))  # type: ignore
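        # The alarm itself is (re)armed with signal.alarm(timeout_period) after every successful
        # training step inside train(), so only a hung step triggers the handler.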

    def train(
        self,
        model: Model,
        dataloader_train: torch.utils.data.DataLoader,
        dataloader_val: torch.utils.data.DataLoader,
    ) -> None:
        """The training function.

        Args:
            model (Model): The PyTorch model.
            dataloader_train (torch.utils.data.DataLoader): The training data loader.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
        """
        # Leaving this for backward compatibility for now, but we can think about moving this to model.on_train_start for all models.
        model = model.to("cuda", memory_format=self.config.trainer.memory_format)  # type: ignore
        model.on_train_start(self.config.trainer.memory_format)
        # Initialize the optimizer, scheduler, and grad_scaler.
        self.callbacks.on_optimizer_init_start()
        optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
        grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
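        # GradScaler guards fp16 training against gradient underflow; when grad_scaler_args sets
        # enabled=False (e.g. for bf16 or fp32 runs), scale()/step()/update() become pass-throughs.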
        self.callbacks.on_optimizer_init_end()
        # Load the model checkpoint and get the starting iteration number.
        iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
        grad_accum_iter = 0
        log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        if self.config.trainer.distributed_parallelism == "ddp":
            # Create a DDP model wrapper.
            model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
        elif self.config.trainer.distributed_parallelism == "fsdp":
            model_ddp = model
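            # For FSDP, the model is assumed to manage its own sharded wrapping internally
            # (e.g. in its constructor or on_train_start), so `model_ddp` simply aliases `model`.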
        else:
            raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        log.info("Starting training...")
        self.callbacks.on_train_start(model, iteration=iteration)
        # Initial validation.
        if self.config.trainer.run_validation and iteration == 0:
            self.validate(model, dataloader_val, iteration=iteration)
        _end_training = False
        while True:
            dataloader_train_iter = iter(dataloader_train)
            while True:
                self.callbacks.on_before_dataloading(iteration)
                with self.training_timer("dataloader_train"):
                    try:
                        data_batch = next(dataloader_train_iter)
                        for k in data_batch.keys():
                            if torch.is_tensor(data_batch[k]):
                                data_batch[k] = data_batch[k].cuda()
                    except StopIteration:
                        break
                self.callbacks.on_after_dataloading(iteration)
                # If max_iter is reached, exit the training loop.
                if iteration >= self.config.trainer.max_iter:
                    _end_training = True
                    break
                # Move all tensors in the data batch to GPU device.
                data_batch = misc.to(data_batch, device="cuda")
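                # misc.to is assumed to recurse into nested containers; tensors already moved to
                # the GPU inside the dataloading block above are left as-is.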
                # The actual training step.
                self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
                if not model.training:
                    model_ddp.train()
                assert model_ddp.training, "model_ddp is not in training mode."
                assert model.training, "model is not in training mode."
                output_batch, loss, grad_accum_iter = self.training_step(
                    model_ddp,
                    optimizer,
                    scheduler,
                    grad_scaler,
                    data_batch,
                    iteration=iteration,
                    grad_accum_iter=grad_accum_iter,
                )
                # Do the following when an actual optimizer (update) step has been made.
                iteration += 1
                # Save checkpoint.
                if iteration % self.config.checkpoint.save_iter == 0:
                    async_saving = getattr(self.config.checkpoint, "async_saving", True)
                    self.checkpointer.save(
                        model, optimizer, scheduler, grad_scaler, iteration=iteration, async_saving=async_saving
                    )
                self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
                # Validation.
                if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
                    self.validate(model, dataloader_val, iteration=iteration)
                # This iteration is successful; reset the timeout signal.
                signal.alarm(self.config.trainer.timeout_period)
            if _end_training:
                break
        log.success("Done with training.")
        if iteration % self.config.checkpoint.save_iter != 0:
            async_saving = getattr(self.config.checkpoint, "async_saving", True)
            self.checkpointer.save(
                model, optimizer, scheduler, grad_scaler, iteration=iteration, async_saving=async_saving
            )
        self.callbacks.on_train_end(model, iteration=iteration)
        self.checkpointer.finalize()
        distributed.barrier()
        self.callbacks.on_app_end()

    def training_step(
        self,
        model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        data: dict[str, torch.Tensor],
        iteration: int = 0,
        grad_accum_iter: int = 0,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
        """The training step.

        Args:
            model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or the bare
                module, depending on whether distributed training is enabled or not.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.
            grad_accum_iter (int): Current index within the gradient accumulation window.

        Returns:
            output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
            loss (torch.Tensor): The total loss of the training data batch.
            grad_accum_iter (int): Updated index within the gradient accumulation window (reset to 0 after an optimizer step).
        """
        # Only let DDP sync gradients at the last iteration of the gradient accumulation window.
        with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
            with self.training_timer("forward"):
                output_batch, loss = model_ddp.training_step(data, iteration)
            self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
            with self.training_timer("backward"):
                loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
                loss_scaled.backward()
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_after_backward()
                else:
                    model_ddp.on_after_backward()
            self.callbacks.on_after_backward(model_ddp, iteration=iteration)
        grad_accum_iter += 1
        if grad_accum_iter == self.config.trainer.grad_accum_iter:
            with self.training_timer("optimizer_step"):
                self.callbacks.on_before_optimizer_step(
                    model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
                )
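                # grad_scaler.step skips the optimizer update if the scaled gradients contain
                # infs/NaNs, and grad_scaler.update then adjusts the loss scale for the next step.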
                grad_scaler.step(optimizer)
                grad_scaler.update()
                scheduler.step()
                self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                else:
                    model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                optimizer.zero_grad(set_to_none=True)
            grad_accum_iter = 0
        return output_batch, loss, grad_accum_iter

    def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
        """Validate on the full validation dataset.

        Args:
            model (Model): The PyTorch model.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
            iteration (int): Current iteration number.
        """
        self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
        model.eval()
        # Evaluate on the full validation set.
        with ema.ema_scope(model, enabled=model.config.ema.enabled):
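            # ema_scope is assumed to temporarily swap the EMA-averaged weights into the model for
            # evaluation (when enabled) and restore the raw training weights when the block exits.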
            for val_iter, data_batch in enumerate(dataloader_val):
                if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
                    break
                data_batch = misc.to(data_batch, device="cuda")
                self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
                output_batch, loss = model.validation_step(data_batch, iteration)
                self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
        self.callbacks.on_validation_end(model, iteration=iteration)
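

# A minimal usage sketch (hypothetical config path, loader call, and dataloader field names; the
# actual launch scripts in this repo may wire things up differently). It only relies on objects
# already imported at the top of this file:
#
#     config = LazyConfig.load("projects/example/config.py")   # assumed loader API, hypothetical path
#     trainer = Trainer(config)
#     model = instantiate(config.model)                        # assumes config.model is a lazy-instantiable spec
#     dataloader_train = instantiate(config.dataloader_train)  # hypothetical config field
#     dataloader_val = instantiate(config.dataloader_val)      # hypothetical config field
#     trainer.train(model, dataloader_train, dataloader_val)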