# Hosting-page residue (not code): roll-ai's picture / "Upload 381 files" / commit b6af722 (verified).
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import signal
import torch
import torch.distributed as dist
import torch.utils.data
from megatron.core import parallel_state
from cosmos_predict1.utils import callback, distributed, ema, log, misc
from cosmos_predict1.utils.checkpointer import Checkpointer
from cosmos_predict1.utils.lazy_config import LazyConfig, instantiate
from cosmos_predict1.utils.model import Model
class Trainer:
    """The base trainer class.

    All trainers should inherit Trainer. It contains the basic functionality for model training
    (particularly suited for large-scale training), including data parallel (DDP/FSDP), model weight average (EMA),
    mixed-precision training (fp16/bf16).

    Attributes:
        config (Config): The config object the trainer was constructed with.
        callbacks (callback.CallBackGroup): Callbacks invoked at each stage of training and validation.
        checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states.
        training_timer (misc.TrainingTimer): Timer object to time code blocks and functions.
    """

    def __init__(self, config):
        """Constructor of the trainer.

        Besides storing ``config``, this has process-wide side effects: it initializes the distributed
        environment and the Megatron model-parallel groups, writes ``config.pkl`` / ``config.yaml`` into
        ``config.job.path_local`` (rank 0 only), seeds the RNGs per rank, sets cuDNN / TF32 flags, and installs
        a SIGALRM handler that ``train()`` uses as a per-step timeout.

        Args:
            config (Config): The config object for the codebase.

        Raises:
            ValueError: If the deprecated ``config.model.context_parallel_size`` is present while
                ``config.model_parallel.context_parallel_size`` is also set (> 1).
        """
        super().__init__()
        self.config = config
        # Set up the distributed computing environment.
        with misc.timer("init_distributed"):
            distributed.init()
        # Set up parallel states. The deprecated `config.model.context_parallel_size` is still honored,
        # but only when the new `config.model_parallel.context_parallel_size` has not been set (> 1).
        if hasattr(config.model, "context_parallel_size"):
            if config.model_parallel.context_parallel_size > 1:
                raise ValueError(
                    "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. "
                    "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size."
                )
            log.critical(
                "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead."
            )
            config.model_parallel.context_parallel_size = config.model.context_parallel_size
        parallel_state.initialize_model_parallel(
            pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
            tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
            context_parallel_size=config.model_parallel.context_parallel_size,
        )
        # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism.
        # It is not part of the original `parallel_state` API, so we need to set it manually.
        parallel_state.sequence_parallel = config.model_parallel.sequence_parallel
        if parallel_state.sequence_parallel:
            os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
        # Create the local job directory, save the config file, and pipe to a local log.
        if distributed.is_rank0():
            os.makedirs(config.job.path_local, exist_ok=True)
            # Save the config as .pkl for reproducibility.
            LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl")
            # Save the config as .yaml for reading or parsing experiment hyperparameters.
            LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
        # Every rank waits here so the job directory exists before anyone opens the log file in it.
        dist.barrier()
        log.init_loguru_file(f"{config.job.path_local}/stdout.log")
        if distributed.is_rank0():
            # Print important environment variables and the effective config.
            log.info("Config:\n" + config.pretty_print(use_color=True))
            misc.print_environ_variables(["TORCH_HOME", "OUTPUT_ROOT"])
        # Set the random seed. If multi-GPU, different ranks are set with different seeds.
        misc.set_random_seed(seed=config.trainer.seed, by_rank=True)
        # Initialize cuDNN.
        torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
        torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
        # Floating-point precision settings.
        torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
        # Initialize the callback functions.
        self.callbacks = callback.CallBackGroup(config=config, trainer=self)
        # Initialize the model checkpointer (a custom class may be supplied via `config.checkpoint.type`).
        if config.checkpoint.type is None:
            self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        else:
            self.checkpointer: Checkpointer = instantiate(
                config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks
            )
        # Initialize the timer for speed benchmarking.
        self.training_timer = misc.TrainingTimer()
        # Send a TimeoutError if a training step takes over timeout_period seconds.
        # The alarm itself is (re-)armed in `train()` after every successful step.
        signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period))  # type: ignore

    def _unwrap_model(self, model_ddp: torch.nn.Module) -> torch.nn.Module:
        """Return the underlying model: `.module` under DDP, otherwise `model_ddp` itself (e.g. FSDP)."""
        if self.config.trainer.distributed_parallelism == "ddp":
            return model_ddp.module
        return model_ddp

    def _save_checkpoint(self, model, optimizer, scheduler, grad_scaler, iteration: int) -> None:
        """Save a checkpoint; `config.checkpoint.async_saving` defaults to True when the field is absent."""
        async_saving = getattr(self.config.checkpoint, "async_saving", True)
        self.checkpointer.save(
            model, optimizer, scheduler, grad_scaler, iteration=iteration, async_saving=async_saving
        )

    def train(
        self,
        model: Model,
        dataloader_train: torch.utils.data.DataLoader,
        dataloader_val: torch.utils.data.DataLoader,
    ) -> None:
        """The training function.

        Moves the model to CUDA, builds the optimizer / scheduler / grad scaler, obtains the starting iteration
        from ``self.checkpointer.load``, and then iterates over ``dataloader_train`` pass after pass until
        ``config.trainer.max_iter`` is reached. A checkpoint is saved every ``config.checkpoint.save_iter``
        iterations (plus once at the end if the final iteration is not a multiple of it), and validation runs every
        ``config.trainer.validation_iter`` iterations when ``config.trainer.run_validation`` is set.

        Args:
            model (Model): The PyTorch model.
            dataloader_train (torch.utils.data.DataLoader): The training data loader.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.

        Raises:
            ValueError: If ``config.trainer.distributed_parallelism`` is neither ``"ddp"`` nor ``"fsdp"``.
            RuntimeError: If a full pass over ``dataloader_train`` yields no batches (training could never
                progress and would otherwise loop forever).
        """
        # Leaving this for backward compatibility for now, but we can think about moving this to model.on_train_start for all models.
        model = model.to("cuda", memory_format=self.config.trainer.memory_format)  # type: ignore
        model.on_train_start(self.config.trainer.memory_format)
        # Initialize the optimizer, scheduler, and grad_scaler.
        self.callbacks.on_optimizer_init_start()
        optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
        grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
        self.callbacks.on_optimizer_init_end()
        # Load the model checkpoint and get the starting iteration number.
        iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
        grad_accum_iter = 0
        log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        if self.config.trainer.distributed_parallelism == "ddp":
            # Create a DDP model wrapper.
            model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
        elif self.config.trainer.distributed_parallelism == "fsdp":
            model_ddp = model
        else:
            raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        log.info("Starting training...")
        self.callbacks.on_train_start(model, iteration=iteration)
        # Initial validation.
        if self.config.trainer.run_validation and iteration == 0:
            self.validate(model, dataloader_val, iteration=iteration)
        _end_training = False
        while True:
            dataloader_train_iter = iter(dataloader_train)
            num_batches_this_pass = 0
            while True:
                self.callbacks.on_before_dataloading(iteration)
                with self.training_timer("dataloader_train"):
                    # Only `next()` can signal the end of the pass; keep the `try` scoped to it.
                    try:
                        data_batch = next(dataloader_train_iter)
                    except StopIteration:
                        break
                    num_batches_this_pass += 1
                    for key, value in data_batch.items():
                        if torch.is_tensor(value):
                            data_batch[key] = value.cuda()
                self.callbacks.on_after_dataloading(iteration)
                # If max_iter is reached, exit the training loop.
                if iteration >= self.config.trainer.max_iter:
                    _end_training = True
                    break
                # Move all tensors in the data batch to GPU device (also handles non-top-level tensors).
                data_batch = misc.to(data_batch, device="cuda")
                # The actual training step.
                self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
                # Validation leaves the model in eval mode; switch back before stepping.
                if not model.training:
                    model_ddp.train()
                assert model_ddp.training, "model_ddp is not in training mode."
                assert model.training, "model is not in training mode."
                output_batch, loss, grad_accum_iter = self.training_step(
                    model_ddp,
                    optimizer,
                    scheduler,
                    grad_scaler,
                    data_batch,
                    iteration=iteration,
                    grad_accum_iter=grad_accum_iter,
                )
                # NOTE: `iteration` advances once per training_step call (i.e. per micro-batch), so with
                # config.trainer.grad_accum_iter > 1 it is not the number of optimizer updates.
                iteration += 1
                # Save checkpoint.
                if iteration % self.config.checkpoint.save_iter == 0:
                    self._save_checkpoint(model, optimizer, scheduler, grad_scaler, iteration)
                self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
                # Validation.
                if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
                    self.validate(model, dataloader_val, iteration=iteration)
                # This iteration is successful; reset the timeout signal.
                signal.alarm(self.config.trainer.timeout_period)
            if _end_training:
                break
            if num_batches_this_pass == 0:
                # Re-iterating an empty (or exhausted, non-restartable) loader would spin forever.
                raise RuntimeError(
                    f"The training dataloader yielded no batches (iteration {iteration}, "
                    f"max_iter {self.config.trainer.max_iter}); cannot make progress."
                )
        log.success("Done with training.")
        # Save the final state unless the last iteration already triggered a periodic save.
        if iteration % self.config.checkpoint.save_iter != 0:
            self._save_checkpoint(model, optimizer, scheduler, grad_scaler, iteration)
        self.callbacks.on_train_end(model, iteration=iteration)
        self.checkpointer.finalize()
        distributed.barrier()
        self.callbacks.on_app_end()
        # Training is over: cancel the alarm armed after the last step so it cannot raise a spurious
        # TimeoutError in whatever code runs after train() returns.
        signal.alarm(0)

    def training_step(
        self,
        model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        data: dict[str, torch.Tensor],
        iteration: int = 0,
        grad_accum_iter: int = 0,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
        """The training step.

        Runs forward + backward on one micro-batch and, once ``config.trainer.grad_accum_iter`` micro-batches have
        been accumulated, performs the optimizer / grad-scaler / scheduler step and zeroes the gradients.

        Args:
            model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare
                module, depending on whether distributed training is enabled or not.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.
            grad_accum_iter (int): 0-based index of this micro-batch within the current gradient-accumulation window.

        Returns:
            output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
            loss (torch.Tensor): The total loss of the training data batch (before accumulation scaling).
            grad_accum_iter (int): The updated accumulation index; reset to 0 after an optimizer step.
        """
        model = self._unwrap_model(model_ddp)
        num_accum = self.config.trainer.grad_accum_iter
        # Only let DDP sync gradient at the last iteration of the gradient accumulation window
        with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == num_accum - 1):
            with self.training_timer("forward"):
                output_batch, loss = model_ddp.training_step(data, iteration)
            self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
            with self.training_timer("backward"):
                # Divide so the accumulated gradient equals the average over the window.
                loss_scaled = grad_scaler.scale(loss / num_accum)
                loss_scaled.backward()
                model.on_after_backward()
            self.callbacks.on_after_backward(model_ddp, iteration=iteration)
        grad_accum_iter += 1
        if grad_accum_iter == num_accum:
            with self.training_timer("optimizer_step"):
                self.callbacks.on_before_optimizer_step(
                    model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
                )
                grad_scaler.step(optimizer)
                grad_scaler.update()
                scheduler.step()
                self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
                model.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                optimizer.zero_grad(set_to_none=True)
            grad_accum_iter = 0
        return output_batch, loss, grad_accum_iter

    @torch.no_grad()
    def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
        """Validate on the full validation dataset.

        Evaluation runs under ``ema.ema_scope`` (active when ``model.config.ema.enabled``) and stops early after
        ``config.trainer.max_val_iter`` batches when that limit is set. The model is left in eval mode;
        ``train()`` switches it back to training mode before the next step.

        Args:
            model (Model): The PyTorch model.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
            iteration (int): Current iteration number.
        """
        self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
        model.eval()
        max_val_iter = self.config.trainer.max_val_iter
        # Evaluate on the full validation set.
        with ema.ema_scope(model, enabled=model.config.ema.enabled):
            for val_iter, data_batch in enumerate(dataloader_val):
                if max_val_iter is not None and val_iter >= max_val_iter:
                    break
                data_batch = misc.to(data_batch, device="cuda")
                self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
                output_batch, loss = model.validation_step(data_batch, iteration)
                self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
        self.callbacks.on_validation_end(model, iteration=iteration)