# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, Set

import torch

from cosmos_predict1.checkpointer.ddp import Checkpointer as DDPCheckpointer
from cosmos_predict1.utils import distributed, log
from cosmos_predict1.utils.model import Model


class Checkpointer(DDPCheckpointer):
    """
    Checkpointer class for PEFT in distributed training. It behaves like the DDP checkpointer,
    except that `broadcast_via_filesystem=False` is not supported and a pre-trained model
    checkpoint can be loaded without any type postfix.

    Note:
        - Fully Sharded Data Parallelism (FSDP) is not supported by this checkpointer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not self.broadcast_via_filesystem:
            raise ValueError("self.broadcast_via_filesystem=False is not implemented for PEFT checkpointer.")

    def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str:
        """
        Override the `add_type_postfix_to_checkpoint_path` method of the base class (DDP checkpointer)
        so that a pre-trained model checkpoint can be loaded without any type postfix.
        """
        checkpoint_path = super().add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
        checkpoint_path = checkpoint_path.replace("model_model.pt", "model.pt")
        return checkpoint_path

    def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]:
        """
        Load state_dict and broadcast for the PEFT checkpointer.

        This function is identical to `load_broadcast_state_dict` of the base class (DDP checkpointer),
        except that `broadcast_via_filesystem=False` is not supported.

        Args:
            checkpoint_path (str): The base path of the checkpoint.
            model (Model): The model being loaded.
            resume_keys (Set): Set of keys to resume from the checkpoint.

        Returns:
            dict[str, Any]: A dictionary containing the loaded state for each resumed key.
        """
        state_dict = {}
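        # Sort the keys so that every rank iterates them in the same deterministic order below.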
        sorted_resume_keys = sorted(resume_keys)

        # Step 1: Download checkpoints for every GPU of DDP-rank 0 and CP-rank 0.
        if self.rank_dp_w_cp == 0:
            for key in sorted_resume_keys:
                _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
                local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path))
                if os.path.exists(local_cache_path):
                    # If the local checkpoint exists, we can directly load it
                    self.print(f"Checkpoint is already in local cache: {local_cache_path}. Loading...")
                    _state_dict = torch.load(
                        local_cache_path, map_location=lambda storage, loc: storage, weights_only=False
                    )
                else:
                    # Pre-trained model is not in local cache, so we need to load it from the checkpoint path
                    self.print(f"Loading checkpoint from: {_ckpt_path}")
                    _state_dict = torch.load(_ckpt_path, map_location=lambda storage, loc: storage, weights_only=False)
                state_dict[key] = _state_dict

        # Ensure all ranks wait for the download to complete
        distributed.barrier()

        # Step 2: Broadcast checkpoint data
        log.info(
            "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.",
            rank0_only=True,
        )
        for key in sorted_resume_keys:
            if self.broadcast_via_filesystem:
                # Load the checkpoint from the local filesystem for other ranks
                if self.rank_dp_w_cp != 0:
                    _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
                    local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path))
                    if os.path.exists(local_cache_path):
                        self.print(f"Loading checkpoint from: {local_cache_path}")
                        state_dict[key] = torch.load(
                            local_cache_path, map_location=lambda storage, loc: storage, weights_only=False
                        )
                    else:
                        self.print(f"Loading checkpoint from: {_ckpt_path}")
                        state_dict[key] = torch.load(
                            _ckpt_path, map_location=lambda storage, loc: storage, weights_only=False
                        )
            else:
                raise ValueError("self.broadcast_via_filesystem=False is not implemented for PEFT checkpointer.")

        return state_dict
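

# Minimal standalone sketch (illustration only, not part of this module's API) of the
# "broadcast via filesystem" pattern that load_broadcast_state_dict() implements above; the
# function name `load_via_filesystem` and the arguments `remote_path`/`cache_dir` are hypothetical:
#
#   import torch.distributed as dist
#
#   def load_via_filesystem(remote_path: str, cache_dir: str) -> dict:
#       local_path = os.path.join(cache_dir, os.path.basename(remote_path))
#       if dist.get_rank() == 0 and not os.path.exists(local_path):
#           # The source rank materialises the checkpoint in a cache directory visible to all ranks.
#           torch.save(torch.load(remote_path, map_location="cpu"), local_path)
#       dist.barrier()  # every rank waits until the cached file exists
#       return torch.load(local_path, map_location="cpu")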