# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Set

import torch

from cosmos_predict1.checkpointer.ddp import Checkpointer as DDPCheckpointer
from cosmos_predict1.utils import distributed, log
from cosmos_predict1.utils.model import Model


class Checkpointer(DDPCheckpointer):
    """
    Checkpointer class for PEFT in distributed training. This class is similar to the DDP checkpointer,
    with the exception that the `broadcast_via_filesystem` functionality is not supported, and it supports
    loading pre-trained model without any postfix.

    Note:
    - Fully Sharded Data Parallelism (FSDP) is not supported by this checkpointer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
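        # Only the filesystem-based broadcast path is implemented for PEFT: non-source ranks
        # re-read the checkpoint from disk in `load_broadcast_state_dict` instead of receiving
        # tensors through a collective broadcast.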
        if not self.broadcast_via_filesystem:
            raise ValueError("self.broadcast_via_filesystem=False is not implemented for PEFT checkpointer.")

    def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str:
        """
        Overwrite the `add_type_postfix_to_checkpoint_path` function of the base class (DDP checkpointer)
        to load pre-trained model without any postfix.
        """
        checkpoint_path = super().add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
        checkpoint_path = checkpoint_path.replace("model_model.pt", "model.pt")
        return checkpoint_path

    def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]:
        """
        Load state_dict and broadcast for PEFT checkpointer.

        This method mirrors `load_broadcast_state_dict` of the base class (DDP checkpointer),
        with the exception that `broadcast_via_filesystem=False` is not supported.

        Args:
            checkpoint_path (str): The base path of the checkpoint.
            model (Model): The model being loaded.
            resume_keys (Set): Set of keys to resume from the checkpoint.

        Returns:
            dict[str, Any]: A dictionary containing the loaded state for each resumed key.
        """
        state_dict = {}
        sorted_resume_keys = sorted(resume_keys)
        # Step 1: Download checkpoints for every GPU of DDP-rank 0 and CP-rank 0.
        if self.rank_dp_w_cp == 0:
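            # `rank_dp_w_cp == 0` selects the source rank of each data-parallel group
            # (with context parallelism folded in); only these ranks read from disk in this step.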
            for key in sorted_resume_keys:
                _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
                local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path))
                if os.path.exists(local_cache_path):
                    # If the local checkpoint exists, we can directly load it
                    self.print(f"Checkpoint is already in local cache: {local_cache_path}. Loading...")
                    _state_dict = torch.load(
                        local_cache_path, map_location=lambda storage, loc: storage, weights_only=False
                    )
                else:
                    # The checkpoint is not in the local cache, so load it from the original checkpoint path
                    self.print(f"Loading checkpoint from: {_ckpt_path}")
                    _state_dict = torch.load(_ckpt_path, map_location=lambda storage, loc: storage, weights_only=False)
                state_dict[key] = _state_dict

        # Ensure all ranks wait for the download to complete
        distributed.barrier()

        # Step 2: Broadcast checkpoint data
        log.info(
            "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.",
            rank0_only=True,
        )
        for key in sorted_resume_keys:
            if self.broadcast_via_filesystem:
                # Load the checkpoint from the local filesystem for other ranks
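                # (the source rank already holds `state_dict[key]` in memory from Step 1,
                # so only non-source ranks re-read the checkpoint file here)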
                if self.rank_dp_w_cp != 0:
                    _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model)
                    local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path))
                    if os.path.exists(local_cache_path):
                        self.print(f"Loading checkpoint from: {local_cache_path}")
                        state_dict[key] = torch.load(
                            local_cache_path, map_location=lambda storage, loc: storage, weights_only=False
                        )
                    else:
                        self.print(f"Loading checkpoint from: {_ckpt_path}")
                        state_dict[key] = torch.load(
                            _ckpt_path, map_location=lambda storage, loc: storage, weights_only=False
                        )

            else:
                raise ValueError("self.broadcast_via_filesystem=False is not implemented for PEFT checkpointer.")

        return state_dict