import functools
import os
import signal

import torch
import torch.distributed as dist
import torch.utils.data
from megatron.core import parallel_state

from cosmos_transfer1.utils import callback, distributed, ema, log, misc
from cosmos_transfer1.utils.checkpointer import Checkpointer
from cosmos_transfer1.utils.lazy_config import LazyConfig, instantiate
from cosmos_transfer1.utils.model import Model


class Trainer:
    """The base trainer class.

    All trainers should inherit from Trainer. It contains the basic functionality for model training
    (particularly suited for large-scale training), including data parallelism (DDP/FSDP), model weight
    averaging (EMA), and mixed-precision training (fp16/bf16).

    Attributes:
        checkpointer (Checkpointer): Checkpointer object to save/load model weights and optimizer states.
        training_timer (misc.TrainingTimer): Timer object to time code blocks and functions.
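
    Example:
        A minimal usage sketch (assuming ``config`` is the experiment config used by this codebase, with the
        ``trainer``, ``model_parallel``, ``job``, ``checkpoint``, ``optimizer``, and ``scheduler`` sections
        referenced below, and that the model and data loaders are built elsewhere)::

            trainer = Trainer(config)
            trainer.train(model, dataloader_train, dataloader_val)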
    """

    def __init__(self, config):
        """Constructor of the trainer.

        Args:
            config (Config): The config object for the codebase.
        """
        super().__init__()
        self.config = config
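        # Set up the distributed computing environment.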
        with misc.timer("init_distributed"):
            distributed.init()
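        # Handle the deprecated config.model.context_parallel_size option.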
        if hasattr(config.model, "context_parallel_size"):
            if config.model_parallel.context_parallel_size > 1:
                raise ValueError(
                    "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. "
                    "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size."
                )
            else:
                log.critical(
                    "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead."
                )
                config.model_parallel.context_parallel_size = config.model.context_parallel_size
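        # Initialize the model-parallel process groups (pipeline / tensor / context parallelism).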
        parallel_state.initialize_model_parallel(
            pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
            tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
            context_parallel_size=config.model_parallel.context_parallel_size,
        )
        parallel_state.sequence_parallel = config.model_parallel.sequence_parallel
        if parallel_state.sequence_parallel:
            os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
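
        # Create the local job directory, save the config files, and initialize the logger (config is saved on rank 0 only).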
        if distributed.is_rank0():
            os.makedirs(config.job.path_local, exist_ok=True)
            LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl")
            LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
        dist.barrier()
        log.init_loguru_file(f"{config.job.path_local}/stdout.log")
        if distributed.is_rank0():
            log.info("Config:\n" + config.pretty_print(use_color=True))
            misc.print_environ_variables(["OUTPUT_ROOT"])
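        # Set the random seed, offset by rank so each process draws different samples.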
        misc.set_random_seed(seed=config.trainer.seed, by_rank=True)
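        # Configure cuDNN determinism/benchmarking and enable TF32 matmuls.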
        torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
        torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
        torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
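        # Initialize the callback functions.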
        self.callbacks = callback.CallBackGroup(config=config, trainer=self)
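        # Initialize the checkpointer: the default Checkpointer, or a custom class given by config.checkpoint.type.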
        if config.checkpoint.type is None:
            self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        else:
            self.checkpointer: Checkpointer = instantiate(
                config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks
            )
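        # Initialize the timer used to time training code blocks.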
        self.training_timer = misc.TrainingTimer()
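        # Install the SIGALRM handler; signal.alarm() in train() re-arms the timeout every training iteration.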
        signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period))

    def train(
        self,
        model: Model,
        dataloader_train: torch.utils.data.DataLoader,
        dataloader_val: torch.utils.data.DataLoader,
    ) -> None:
        """The training function.

        Args:
            model (Model): The PyTorch model.
            dataloader_train (torch.utils.data.DataLoader): The training data loader.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
        """
        model = model.to("cuda", memory_format=self.config.trainer.memory_format)
        model.on_train_start(self.config.trainer.memory_format)
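
        # Initialize the optimizer, scheduler, and gradient scaler (for mixed-precision training).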
        self.callbacks.on_optimizer_init_start()
        optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
        grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
        self.callbacks.on_optimizer_init_end()
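
        # Load the latest checkpoint (if any) and resume from its iteration number.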
        iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
        grad_accum_iter = 0
        log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        if self.config.trainer.distributed_parallelism == "ddp":
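            # Wrap the bare model with a DDP wrapper; in FSDP mode the model is used as-is.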
            model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
        elif self.config.trainer.distributed_parallelism == "fsdp":
            model_ddp = model
        else:
            raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
        log.info("Starting training...")
        self.callbacks.on_train_start(model, iteration=iteration)
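
        # Run an initial validation pass before the first training step.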
        if self.config.trainer.run_validation and iteration == 0:
            self.validate(model, dataloader_val, iteration=iteration)
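        # Main training loop; the outer loop re-creates the dataloader iterator whenever the dataset is exhausted.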
        _end_training = False
        while True:
            dataloader_train_iter = iter(dataloader_train)
            while True:
                self.callbacks.on_before_dataloading(iteration)
                with self.training_timer("dataloader_train"):
                    try:
                        data_batch = next(dataloader_train_iter)
                    except StopIteration:
                        break
                self.callbacks.on_after_dataloading(iteration)
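                # If max_iter is reached, exit the training loop.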
                if iteration >= self.config.trainer.max_iter:
                    _end_training = True
                    break
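                # Move all tensors in the data batch to the GPU.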
                data_batch = misc.to(data_batch, device="cuda")
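                # Perform a single training step (gradient accumulation is handled inside training_step).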
                self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
                if not model.training:
                    model_ddp.train()
                assert model_ddp.training, "model_ddp is not in training mode."
                assert model.training, "model is not in training mode."
                output_batch, loss, grad_accum_iter = self.training_step(
                    model_ddp,
                    optimizer,
                    scheduler,
                    grad_scaler,
                    data_batch,
                    iteration=iteration,
                    grad_accum_iter=grad_accum_iter,
                )
                iteration += 1
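
                # Save a checkpoint and run validation at the configured intervals.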
                if iteration % self.config.checkpoint.save_iter == 0:
                    self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
                self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
                if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
                    self.validate(model, dataloader_val, iteration=iteration)
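
                # Re-arm the watchdog: SIGALRM fires only if the next iteration takes longer than timeout_period seconds.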
                signal.alarm(self.config.trainer.timeout_period)
            if _end_training:
                break
        log.success("Done with training.")
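        # Write a final checkpoint unless one was already saved at this iteration.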
        if iteration % self.config.checkpoint.save_iter != 0:
            self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
        self.callbacks.on_train_end(model, iteration=iteration)
        self.checkpointer.finalize()
        log.info("Cleaning up distributed environment..")
        distributed.barrier()
        log.info("Cleaning up distributed environment.. Done!")
        self.callbacks.on_app_end()

    def training_step(
        self,
        model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        data: dict[str, torch.Tensor],
        iteration: int = 0,
        grad_accum_iter: int = 0,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
        """The training step.

        Args:
            model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or the
                bare module, depending on whether distributed training is enabled.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed-precision training).
            data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.
            grad_accum_iter (int): Number of gradient accumulation iterations completed so far.

        Returns:
            output_batch (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
            loss (torch.Tensor): The total loss of the training data batch.
            grad_accum_iter (int): The updated gradient accumulation counter.
        """
        with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
            with self.training_timer("forward"):
                output_batch, loss = model_ddp.training_step(data, iteration)
            self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
            with self.training_timer("backward"):
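                # Scale the loss for AMP and divide by grad_accum_iter so accumulated gradients are averaged.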
                loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
                loss_scaled.backward()
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_after_backward()
                else:
                    model_ddp.on_after_backward()
            self.callbacks.on_after_backward(model_ddp, iteration=iteration)
        grad_accum_iter += 1
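        # Once a full accumulation cycle is complete, step the optimizer/scheduler and zero the gradients.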
        if grad_accum_iter == self.config.trainer.grad_accum_iter:
            with self.training_timer("optimizer_step"):
                self.callbacks.on_before_optimizer_step(
                    model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
                )
                grad_scaler.step(optimizer)
                grad_scaler.update()
                scheduler.step()
                self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                else:
                    model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                optimizer.zero_grad(set_to_none=True)
            grad_accum_iter = 0
        return output_batch, loss, grad_accum_iter

    @torch.no_grad()
    def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
        """Validate on the full validation dataset.

        Args:
            model (Model): The PyTorch model.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
            iteration (int): Current iteration number.
        """
        self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
        model.eval()
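        # Run validation inside the EMA scope so the averaged weights are used when EMA is enabled.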
        with ema.ema_scope(model, enabled=model.config.ema.enabled):
            for val_iter, data_batch in enumerate(dataloader_val):
                if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
                    break
                data_batch = misc.to(data_batch, device="cuda")
                self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
                output_batch, loss = model.validation_step(data_batch, iteration)
                self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
        self.callbacks.on_validation_end(model, iteration=iteration)