# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import signal

import torch
import torch.distributed as dist
import torch.utils.data
from megatron.core import parallel_state

from cosmos_predict1.checkpointer.tp import Checkpointer as TensorParallelCheckpointer
from cosmos_predict1.utils import distributed, ema, log, misc
from cosmos_predict1.utils.checkpointer import Checkpointer
from cosmos_predict1.utils.fsdp_checkpointer import FSDPCheckpointer
from cosmos_predict1.utils.model import Model
from cosmos_predict1.utils.trainer import Trainer


class Trainer(Trainer):
    def __init__(self, config):
        super().__init__(config)
        if config.trainer.distributed_parallelism == "ddp":
            if parallel_state.get_tensor_model_parallel_world_size() > 1:
                self.checkpointer = TensorParallelCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)
                log.critical("Using Tensor Parallelism Checkpointer")
            else:
                self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        elif config.trainer.distributed_parallelism == "fsdp":
            self.checkpointer = FSDPCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        else:
            raise ValueError(f"Unsupported distributed parallelism: {config.trainer.distributed_parallelism}")

    # Modified from the original trainer to log the average loss (averaged across all devices
    # and gradient accumulation steps).
    def train(
        self,
        model: Model,
        dataloader_train: torch.utils.data.DataLoader,
        dataloader_val: torch.utils.data.DataLoader,
    ) -> None:
        """The training function.

        Args:
            model (Model): The PyTorch model.
            dataloader_train (torch.utils.data.DataLoader): The training data loader.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
        """
        # Leaving this for backward compatibility for now, but we can think about moving this
        # to model.on_train_start for all models.
        model = model.to("cuda", memory_format=self.config.trainer.memory_format)  # type: ignore
        log.info(f"Model Architecture:\n {model}")
        model.on_train_start(self.config.trainer.memory_format)
        # Initialize the optimizer and scheduler.
        self.callbacks.on_optimizer_init_start()
        optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
        grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
        self.callbacks.on_optimizer_init_end()
        # Load the model checkpoint and get the starting iteration number.
        iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
        # Set the scheduler to the current iteration.
        scheduler.last_epoch = iteration
        scheduler._step_count = iteration + 1
        grad_accum_iter = 0
        log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        if self.config.trainer.distributed_parallelism == "ddp":
            # Create a DDP model wrapper.
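            # parallel_model_wrapper is expected to return the model wrapped in
            # distributed.DistributedDataParallel (see the type hint on training_step),
            # configured from self.config.trainer.ddp; model_ddp.module is then the bare model.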
            model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
        elif self.config.trainer.distributed_parallelism == "fsdp":
            model_ddp = model
        else:
            raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        log.info("Starting training...")
        self.callbacks.on_train_start(model, iteration=iteration)
        # Initial validation.
        if self.config.trainer.run_validation and iteration == 0:
            self.validate(model, dataloader_val, iteration=iteration)
        _end_training = False
        self.callbacks.on_before_dataloading(iteration)
        accumulated_loss = 0.0
        while True:
            dataloader_train_iter = iter(dataloader_train)
            while True:
                self.callbacks.on_before_dataloading(iteration)
                try:
                    data_batch = next(dataloader_train_iter)
                except StopIteration:
                    break
                self.callbacks.on_after_dataloading(iteration)
                # If max_iter is reached, exit the training loop.
                if iteration >= self.config.trainer.max_iter:
                    _end_training = True
                    break
                # Move all tensors in the data batch to GPU device.
                data_batch = misc.to(data_batch, device="cuda")
                # The actual training step.
                self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
                model_ddp.train()
                output_batch, loss, grad_accum_iter = self.training_step(
                    model_ddp,
                    optimizer,
                    scheduler,
                    grad_scaler,
                    data_batch,
                    iteration=iteration,
                    grad_accum_iter=grad_accum_iter,
                )
                # Accumulate loss.
                accumulated_loss += loss.detach()
                # If the gradients are still being accumulated, continue to load the next training batch.
                if grad_accum_iter != 0:
                    if self.enable_one_logger:
                        # Callback for skipped OneLoggerCallback.on_training_step_end().
                        self.one_logger.on_train_batch_end(set_barrier=False)
                    continue
                # Do the following when an actual optimizer (update) step has been made.
                iteration += 1
                # Average loss over accumulation steps.
                grad_accum_avg_loss = accumulated_loss / self.config.trainer.grad_accum_iter
                # Average loss across all devices.
                device_avg_loss = grad_accum_avg_loss.clone()
                dist.all_reduce(device_avg_loss, op=dist.ReduceOp.SUM)
                device_avg_loss /= dist.get_world_size()
                # Reset accumulation variables.
                accumulated_loss = 0.0
                self.callbacks.on_training_step_end(
                    model, data_batch, output_batch, device_avg_loss, iteration=iteration
                )
                # self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
                # Validation.
                if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
                    self.validate(model, dataloader_val, iteration=iteration)
                # Save checkpoint.
                if iteration % self.config.checkpoint.save_iter == 0:
                    self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
                # This iteration is successful; reset the timeout signal.
                signal.alarm(self.config.trainer.timeout_period)
            if _end_training:
                break
        log.success("Done with training.")
        self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
        self.callbacks.on_train_end(model, iteration=iteration)
        self.checkpointer.finalize()
        distributed.barrier()
        self.callbacks.on_app_end()

    def training_step(
        self,
        model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        data: dict[str, torch.Tensor],
        iteration: int = 0,
        grad_accum_iter: int = 0,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
        """The training step.
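
        A single call runs one forward/backward pass. Gradients are accumulated across calls;
        the gradient scaler, optimizer, and scheduler are stepped only once grad_accum_iter
        reaches config.trainer.grad_accum_iter, after which the counter is reset to 0.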

        Args:
            model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model, either wrapped in DDP or
                the bare module, depending on whether distributed training is enabled.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.
            grad_accum_iter (int): Number of gradient accumulation steps completed so far in the current window.

        Returns:
            output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
            loss (torch.Tensor): The total loss of the training data batch.
            grad_accum_iter (int): The updated gradient accumulation counter (0 after an optimizer step).
        """
        # Only let DDP sync gradients at the last iteration of the gradient accumulation window.
        with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
            with self.training_timer("forward"):
                output_batch, loss = model_ddp.training_step(data, iteration)
            self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
            with self.training_timer("backward"):
                loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
                loss_scaled.backward()
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_after_backward()
                else:
                    model_ddp.on_after_backward()
            self.callbacks.on_after_backward(model_ddp, iteration=iteration)
        grad_accum_iter += 1
        if grad_accum_iter == self.config.trainer.grad_accum_iter:
            with self.training_timer("optimizer_step"):
                self.callbacks.on_before_optimizer_step(
                    model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
                )
                grad_scaler.step(optimizer)
                grad_scaler.update()
                scheduler.step()
                self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                else:
                    model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                optimizer.zero_grad(set_to_none=True)
            grad_accum_iter = 0
        return output_batch, loss, grad_accum_iter

    @torch.no_grad()
    def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
        """Validate on the full validation dataset.

        Args:
            model (Model): The PyTorch model.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
            iteration (int): Current iteration number.
        """
        self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
        model.eval()
        # Evaluate on the full validation set.
        with ema.ema_scope(model, enabled=getattr(model.config.ema, "enabled", False)):
            for val_iter, data_batch in enumerate(dataloader_val):
                if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
                    break
                data_batch = misc.to(data_batch, device="cuda")
                self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
                output_batch, loss = model.validation_step(data_batch, iteration)
                self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
        self.callbacks.on_validation_end(model, iteration=iteration)
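

# Example usage (hypothetical sketch, assuming a config object that exposes the fields referenced
# above, e.g. config.trainer, config.checkpoint, config.optimizer, config.scheduler, config.job):
#
#     trainer = Trainer(config)
#     trainer.train(model, dataloader_train, dataloader_val)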