from typing import Any, Set

import torch

from cosmos_transfer1.checkpointer.ddp_checkpointer import StateDictItemPath
from cosmos_transfer1.checkpointer.tp_checkpointer import Checkpointer as TPCheckpointer
from cosmos_transfer1.diffusion.training.models.model import DiffusionModel
from cosmos_transfer1.utils import distributed, log, misc
from cosmos_transfer1.utils.easy_io import easy_io


class Checkpointer(TPCheckpointer):
    def load_broadcast_state_dict(
        self, checkpoint_path: str, model: DiffusionModel, resume_keys: Set
    ) -> dict[str, Any]:
""" |
|
Load state_dict and broadcast efficiently. |
|
|
|
This method optimizes checkpoint loading for distributed training for improved |
|
connection speed and reliability. |
|
|
|
The main steps are: |
|
1. Retrieve TP-rank-specific checkpoints for each GPU of DDP-rank 0 |
|
and CP-rank 0. |
|
2. Each rank loads its corresponding checkpoint either from a local cache or |
|
receives it via broadcast. |
|
|
|
This approach ensures that each MP (Model Parallelism) rank loads its specific |
|
part of the model, which is crucial for scenarios where different parts of the |
|
model are distributed across multiple GPUs. |
|
|
|
The method supports both Tensor Parallelism (TP) and standard Data Parallel (DP) |
|
training. For TP, each rank can efficiently load its specific checkpoint from S3. |
|
For standard DDP without TP, the default broadcast mechanism is used. |
|
|
|
Args: |
|
checkpoint_path (str): The base path of the checkpoint in S3. |
|
model (DiffusionModel): The model being loaded. |
|
resume_keys (Set): Set of keys to resume from the checkpoint. |
|
|
|
Returns: |
|
dict[str, Any]: A dictionary containing the loaded state for each resumed key. |
|
|
|
Note: |
|
This implementation has been tested and optimized for 4K GPU training jobs, |
|
showing significant improvements in connection speed and overall efficiency. |
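
        Example:
            Hypothetical call site (the path, keys, and variable names below are
            illustrative only, not defined by this module)::

                resume_keys = {"model", "optim"}
                loaded = checkpointer.load_broadcast_state_dict(
                    "s3://bucket/project/checkpoints/iter_000010000.pt",
                    model,
                    resume_keys,
                )
                # loaded["model"] now holds this rank's model state_dict.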
        """
        state_dict = {}
        # Sort the keys so that every rank walks the checkpoints in the same order.
        sorted_resume_keys = sorted(resume_keys)
        for key in sorted_resume_keys:
            # Each resumed key (e.g. "model") has its own type-postfixed checkpoint file.
            _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
            _state_dict = easy_io.load(_ckpt_path, weights_only=False)
            state_dict[key] = _state_dict
            self.print(f"Loaded checkpoint from: {_ckpt_path}")
        # Wait until every rank has finished loading before resuming training.
        distributed.barrier()
        return state_dict

    @misc.timer("checkpoint saving")
    def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None:
        """
        Similar to the original _save_worker, but with the following changes:

        * fast_backend=False to avoid high CPU usage
        """
        try:
            for key, item in state_dict.items():
                self.print(f"Saving {key} to {item.save_path}")
                try:
                    easy_io.dump(
                        item.state_dict,
                        item.save_path,
                        fast_backend=False,  # avoid high CPU usage (see docstring)
                    )
                    self.print(f"Saved {key} to {item.save_path}")
                except Exception as e:
                    self.print(f"Failed to save {key} to {item.save_path}: {str(e)}")
                    raise

            # Wait for all model-parallel ranks to finish writing their shards.
            if self.mp_world_size > 1:
                torch.distributed.barrier(group=self.mp_gloo_pg)

            # Only one writer per checkpoint updates the "latest checkpoint" pointer file.
            if self.mp_rank == 0 and self.rank_dp_w_cp == 0:
                self._write_latest_checkpoint_file(checkpoint_file)

            if distributed.get_rank() == 0:
                # Global rank 0 additionally records how much data the model has seen.
                if "trained_data_record" in state_dict["model"].state_dict:
                    self._write_trained_data_record(
                        checkpoint_file, state_dict["model"].state_dict["trained_data_record"]
                    )

            # e.g. "iter_000010000.pt" -> 10000
            iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
            self.callbacks.on_save_checkpoint_success(iteration=iteration)
        except Exception as e:
            log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose)