harry900000's picture
add cosmos-tranfer1/ into repo
226c7c9
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Set
import torch
from cosmos_transfer1.checkpointer.ddp_checkpointer import StateDictItemPath
from cosmos_transfer1.checkpointer.tp_checkpointer import Checkpointer as TPCheckpointer
from cosmos_transfer1.diffusion.training.models.model import DiffusionModel
from cosmos_transfer1.utils import distributed, log, misc
from cosmos_transfer1.utils.easy_io import easy_io
class Checkpointer(TPCheckpointer):
    """TP checkpointer variant that customizes checkpoint loading and the save worker."""

    def load_broadcast_state_dict(
        self, checkpoint_path: str, model: DiffusionModel, resume_keys: Set
    ) -> dict[str, Any]:
        """
        Load the checkpoint state for every key in ``resume_keys``.

        Despite the method name, this implementation does not broadcast anything:
        every rank resolves its own per-key path with
        ``add_type_postfix_to_checkpoint_path`` (presumably including the
        TP-rank-specific suffix handled by the base class; verify there) and loads
        it directly through ``easy_io.load``. A distributed barrier follows each
        key, so ranks progress through the keys in lockstep.

        Args:
            checkpoint_path (str): Base checkpoint path; it is postfixed per key.
            model (DiffusionModel): Model passed to the path-postfix helper.
            resume_keys (Set): Keys to load from the checkpoint.

        Returns:
            dict[str, Any]: Mapping from each resumed key to its loaded object.
        """
        state_dict = {}
        # Sorted so every rank visits the keys (and their barriers) in the same order.
        for key in sorted(resume_keys):
            ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
            # NOTE: weights_only=False permits arbitrary pickled objects (assuming
            # easy_io forwards it to torch.load); only load trusted checkpoints.
            state_dict[key] = easy_io.load(ckpt_path, weights_only=False)
            self.print(f"Loaded checkpoint from: {ckpt_path}")
            distributed.barrier()
        return state_dict

    @misc.timer("checkpoint saving")
    def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None:
        """
        Write every state-dict item to storage and finalize the checkpoint.

        This is similar to the base-class ``_save_worker``. The ``fast_backend=False``
        option for ``easy_io.dump`` is currently commented out, so the default
        backend is used.

        Steps:
            1. Dump each ``item.state_dict`` to ``item.save_path``.
            2. If model parallelism is enabled, wait on a barrier over the MP gloo group.
            3. On MP rank 0 that is also DP(+CP) rank 0, update the latest-checkpoint file.
               On global rank 0, also write the trained data record when the "model"
               item carries one.
            4. Notify callbacks with the iteration parsed from ``checkpoint_file``.

        Failures are logged, not raised, so a failed save does not crash the caller.
        In that case ``on_save_checkpoint_success`` is not invoked.

        Args:
            state_dict: Maps a key (e.g. "model") to a StateDictItemPath holding the
                state dict and its destination path.
            checkpoint_file: Checkpoint file name, expected to look like "iter_<N>.pt".
            rank: Unused in this implementation.
        """
        try:
            for key, item in state_dict.items():
                self.print(f"Saving {key} to {item.save_path}")
                try:
                    easy_io.dump(
                        item.state_dict,
                        item.save_path,
                        # fast_backend=False,  # too cpu heavy
                    )
                    self.print(f"Saved {key} to {item.save_path}")
                except Exception as e:
                    self.print(f"Failed to save {key} to {item.save_path}: {e}")
                    raise  # Re-raise the exception after logging

            # Wait for every rank in the model-parallel group to finish writing its
            # shard before the latest-checkpoint marker is updated below.
            if self.mp_world_size > 1:
                torch.distributed.barrier(group=self.mp_gloo_pg)

            # Only rank 0 of the MP group and rank 0 of DP (with CP) updates latest_checkpoint.txt.
            if self.mp_rank == 0 and self.rank_dp_w_cp == 0:
                self._write_latest_checkpoint_file(checkpoint_file)

            # Only global rank 0 saves trained_data_record. A missing "model" item must
            # not turn an otherwise successful save into a reported failure.
            model_item = state_dict.get("model")
            if (
                distributed.get_rank() == 0
                and model_item is not None
                and "trained_data_record" in model_item.state_dict
            ):
                self._write_trained_data_record(checkpoint_file, model_item.state_dict["trained_data_record"])

            iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
            self.callbacks.on_save_checkpoint_success(iteration=iteration)
        except Exception as e:  # noqa: BLE001
            log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose)