File size: 19,094 Bytes
b6af722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import gc
import os
import threading

import torch
from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

from cosmos_predict1.utils import callback, distributed, log, misc
from cosmos_predict1.utils.config import CheckpointConfig, JobConfig
from cosmos_predict1.utils.easy_io import easy_io
from cosmos_predict1.utils.fsdp_optim_fix import scatter_full_optim_state_dict
from cosmos_predict1.utils.model import Model


class FSDPCheckpointer:
    """The checkpointer class. Supports checkpoint saving/loading to local disk."""

    def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup):
        """Constructor of the checkpointer.

        Args:
            config_checkpoint (CheckpointConfig): The config object for the checkpointer.
        """
        # Set the callback functions.
        self.callbacks = callbacks
        self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints"
        self.strict_resume = config_checkpoint.strict_resume
        self.load_path = config_checkpoint.load_path
        self.load_training_state = config_checkpoint.load_training_state
        self.save_thread = None
        self.config_checkpoint = config_checkpoint

    def _load_ckpt_file_during_init(self):
        """Decide which checkpoint to load and whether the training state should be resumed.

        The choice is made in this order:
        1. The file named in latest_checkpoint.txt under the local checkpoint directory (resume).
        2. ``self.load_path``; the training state is resumed only if ``self.load_training_state`` is set.
        3. No checkpoint: train from scratch.

        Returns:
            tuple: ``(checkpoint_path, resume)``. ``checkpoint_path`` is None when training from scratch.
        """
        latest_name = self._read_latest_checkpoint_file()

        # 1. Resume training from latest_checkpoint.txt under the same name.
        if latest_name is not None:
            resolved_path = os.path.join(self.checkpoint_dir_local, latest_name)
            log.critical(f"[Checkpoint] Found latest checkpoint file: {latest_name}")
            log.critical(f"[Checkpoint] Loading from local path: {resolved_path}")
            log.critical("[Checkpoint] Will resume full training state (model, optimizer, scheduler)")
            return resolved_path, True

        # 3. Randomly initialize the model parameters and train from scratch.
        if not self.load_path:
            log.critical("[Checkpoint] No checkpoint path specified")
            log.critical("[Checkpoint] Starting fresh training with random initialization")
            return None, False

        # 2. Load the module weights specified by the configured load path.
        configured_path = self.load_path
        with_training_state = self.load_training_state
        log.critical(f"[Checkpoint] Using specified checkpoint path: {configured_path}")
        if with_training_state:
            log.critical("[Checkpoint] Will load complete training state (model, optimizer, scheduler)")
        else:
            log.critical("[Checkpoint] Will load model weights only (no optimizer/scheduler state)")
        return configured_path, with_training_state

    @misc.timer("FSDP.load_model_during_init")
    def load_model_during_init(self, model, is_ema=False, ema_id: int = 0):
        """Load the regular or EMA weights into ``model`` from the resolved checkpoint.

        With ``X.pt`` being the path returned by ``_load_ckpt_file_during_init``, the file is chosen as:
        1. ``X_reg_model.pt`` (``X_ema_model.pt`` when ``is_ema``); if it is missing, ``X.pt`` itself.
        2. When ``is_ema`` and ``ema_id > 0``: ``X_RANK{ema_id}_ema_model.pt`` if it exists, otherwise
           the choice from step 1 is kept.
        Nothing is loaded when no checkpoint path is resolved.

        Args:
            model: The module that receives the weights.
            is_ema (bool): Load the EMA weights instead of the regular ones.
            ema_id (int): Index of the EMA copy; a value above 0 requires ``is_ema=True``.

        Raises:
            FileNotFoundError: If the selected checkpoint file does not exist.
        """
        assert is_ema or not (ema_id > 0), "ema_id should be used with is_ema=True"

        base_path, _ = self._load_ckpt_file_during_init()
        if base_path is None:
            log.info(f"is_ema={is_ema} model is not found and loaded.")
            return

        tag = "ema" if is_ema else "reg"
        # NOTE(review): str.replace rewrites every ".pt" occurrence in the path, not only the extension.
        selected_path = base_path.replace(".pt", f"_{tag}_model.pt")
        if not os.path.exists(selected_path):
            # Starting from the release checkpoint.
            selected_path = base_path
            log.warning(f"is_ema={is_ema} model is not found. Loading from {selected_path}")

        if is_ema and ema_id > 0:
            # Prefer the per-rank EMA file when it is present on disk.
            ranked_path = base_path.replace(".pt", f"_RANK{ema_id}.pt").replace(".pt", f"_{tag}_model.pt")
            if self._check_checkpoint_exists(ranked_path, is_raise=False):
                selected_path = ranked_path
            else:
                print(
                    f"{distributed.get_rank()}: Checkpoint not found: {ranked_path} "
                    f"(fallback to {selected_path})"
                )

        self._check_checkpoint_exists(selected_path)

        log.info(f"Loading checkpoint (local): {selected_path}")
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        weights = torch.load(selected_path, map_location=lambda storage, loc: storage, weights_only=False)
        log.success(f"Complete loading checkpoint (local): {selected_path}")
        log.info("- Loading the model...")
        if not self.strict_resume:
            log.critical("\t Using non-strict model")
            from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model

            log.info(non_strict_load_model(model, weights))
        else:
            log.info(model.load_state_dict(weights, strict=self.strict_resume))
        log.info("-finish model loading")

    @misc.timer("FSDP.load_optim_scheduler_during_init")
    def load_optim_scheduler_during_init(self, fsdp_model, optimizer, scheduler):
        """Restore the optimizer state from the ``_optim.pt`` file that accompanies the checkpoint.

        Nothing is loaded unless a checkpoint path is resolved and the training state is being
        resumed. Rank 0 reads the full optimizer state from disk, the other ranks pass ``None``;
        the result of ``scatter_full_optim_state_dict`` is then loaded into ``optimizer``.

        The scheduler state is deliberately not restored here (the statements that used to follow
        the early ``return`` were unreachable and have been removed). ``load`` sets
        ``scheduler.last_epoch`` from the resumed iteration instead.

        Args:
            fsdp_model: The FSDP-wrapped module the optimizer state belongs to.
            optimizer: The optimizer that receives the sharded state.
            scheduler: Unused; kept so the call signature stays unchanged.

        Raises:
            FileNotFoundError: If resuming and the ``_optim.pt`` file does not exist.
        """
        checkpoint_path, resume = self._load_ckpt_file_during_init()
        log.critical(f"Loading optimizer and scheduler: {checkpoint_path} (resume: {resume})")
        if checkpoint_path is None or not resume:
            return

        # NOTE(review): str.replace rewrites every ".pt" occurrence in the path, not only the extension.
        optim_path = checkpoint_path.replace(".pt", "_optim.pt")
        self._check_checkpoint_exists(optim_path)
        if distributed.get_rank() == 0:
            log.info(f"Loading checkpoint (local): {optim_path}")
            # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            state_dict = torch.load(optim_path, map_location=lambda storage, loc: storage, weights_only=False)
            log.success(f"Complete loading checkpoint (local): {optim_path}")
            log.info("- Loading the optimizer (FSDP scatter)...")
        else:
            # Only rank 0 reads the file; the other ranks contribute nothing to the scatter call.
            state_dict = {
                "optimizer": None,
                "scheduler": None,
            }
        distributed.barrier()
        sharded_optimizer_state_dict = scatter_full_optim_state_dict(  # <---- FSDP
            state_dict["optimizer"],
            fsdp_model,
        )
        log.info("- Loading the optimizer (FSDP load_state_dict)...")
        log.info(optimizer.load_state_dict(sharded_optimizer_state_dict))
        log.critical("Skip loading the scheduler...")

    @misc.timer("FSDP get_optim_scheduler_state")
    def get_optim_scheduler_state(self, optim, fsdp_model, scheduler):
        """Collect the optimizer and scheduler state dicts that go into a checkpoint.

        The optimizer state is obtained with ``FSDP.full_optim_state_dict`` inside a
        ``FULL_STATE_DICT`` context configured with ``offload_to_cpu=True`` and ``rank0_only=True``.

        Args:
            optim: The optimizer attached to ``fsdp_model``.
            fsdp_model: The FSDP-wrapped module.
            scheduler: The scheduler whose ``state_dict()`` is stored alongside.

        Returns:
            dict: ``{"optimizer": <full optimizer state>, "scheduler": <scheduler state>}``.
        """
        model_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        optim_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, model_config, optim_config):
            full_optim_state = FSDP.full_optim_state_dict(fsdp_model, optim)
        return dict(optimizer=full_optim_state, scheduler=scheduler.state_dict())

    def save(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int,
        async_saving: bool = True,
    ) -> None:
        """Save network weights, optimizer parameters, scheduler parameters to a checkpoint.

        Which rank writes what is decided from ``model.get_ckpt_postfix()``:
        - replicate 0 / shard 0 writes everything;
        - other replicates below ``total_ema_num`` with shard 0 write only the EMA weights;
        - every other rank returns without writing (and without the end-of-save callback).

        Args:
            model (Model): The PyTorch model.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
                Not used by this method.
            iteration (int): Current iteration number.
            async_saving (bool): Write the files from a background thread (default) instead of
                the calling thread.
        """
        self.callbacks.on_save_checkpoint_start(model, iteration)

        # These run on every rank, before the rank-dependent early returns below.
        model_states = model.state_dict_model()
        optim_states = self.get_optim_scheduler_state(optimizer, model.model, scheduler)
        torch.cuda.empty_cache()
        meta_state = {"iteration": iteration}
        self.callbacks.on_save_checkpoint(model, state_dict=meta_state)

        postfix, replicate_idx, shard_idx, total_ema_num = model.get_ckpt_postfix()
        if shard_idx != 0:
            return
        if replicate_idx != 0:
            if not replicate_idx < total_ema_num:
                return
            # Only save ema on this rank.
            model_states["model"] = None
            optim_states = None
            meta_state = None
        # Otherwise this is replicate 0 / shard 0 (rank0): save whole.

        checkpoint_file = f"iter_{iteration:09}{postfix}.pt"
        if not async_saving:
            torch.cuda.empty_cache()
            # Run the checkpoint saver in the current thread.
            self._save_worker_local(model_states, optim_states, meta_state, checkpoint_file, distributed.get_rank())
            log.info("checkpoint saved within the main thread")
            # Drop the local references so the collected states can be reclaimed right away.
            del model_states, optim_states, meta_state
            gc.collect()
            torch.cuda.empty_cache()
        else:
            # Wait for previous saver thread to end.
            if self.save_thread:
                self.save_thread.join()
            # Run the checkpoint saver in a separate thread.
            worker_args = (model_states, optim_states, meta_state, checkpoint_file, distributed.get_rank())
            self.save_thread = threading.Thread(target=self._save_worker_local, daemon=False, args=worker_args)
            self.save_thread.start()
            log.info("checkpoint saving from an async thread")
        self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)

    @misc.timer("checkpoint saving (local)")
    def _save_worker_local(
        self,
        model_state_dict: dict,
        optim_scheduler_state_dict: dict | None,
        state_dict: dict | None,
        checkpoint_file: str,
        rank: int = 0,
    ) -> None:
        """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training).

        For a checkpoint file ``X.pt`` the following files are written, each only if its payload
        is not None: ``X_reg_model.pt``, ``X_ema_model.pt``, ``X_optim.pt`` and ``X.pt`` itself.
        Rank 0 additionally updates latest_checkpoint.txt. Any exception raised while saving is
        logged and swallowed, so a failed save does not propagate to the caller.

        Args:
            model_state_dict (dict): Holds the regular weights under ``"model"`` and the EMA
                weights under ``"ema"``; either entry may be None.
            optim_scheduler_state_dict (dict | None): The optimizer/scheduler state, or None to skip it.
            state_dict (dict | None): The remaining training state (e.g. the iteration), or None to skip it.
            checkpoint_file (str): The file name of the model checkpoint, ``iter_<digits>[postfix].pt``.
            rank (int): Rank of the calling process (default: 0).
        """
        import re

        checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file)
        os.makedirs(self.checkpoint_dir_local, exist_ok=True)
        try:
            reg_weights, ema_weights = model_state_dict["model"], model_state_dict["ema"]
            if reg_weights is not None:
                torch.save(reg_weights, checkpoint_path.replace(".pt", "_reg_model.pt"))
            if ema_weights is not None:
                torch.save(ema_weights, checkpoint_path.replace(".pt", "_ema_model.pt"))
            if optim_scheduler_state_dict is not None:
                torch.save(optim_scheduler_state_dict, checkpoint_path.replace(".pt", "_optim.pt"))
            if state_dict is not None:
                torch.save(state_dict, checkpoint_path)
            if rank == 0:
                self._write_latest_checkpoint_file(checkpoint_file)
            log.success(f"Saved checkpoint (local): {checkpoint_path}")
            # Parse only the digits right after "iter_": the name may carry a postfix such as
            # "_RANK1" before ".pt", which made the previous int(...) conversion raise and a
            # successful save get reported as a failure.
            iteration_match = re.match(r"iter_(\d+)", checkpoint_file)
            if iteration_match is None:
                raise ValueError(f"Cannot parse the iteration from checkpoint file name: {checkpoint_file}")
            iteration = int(iteration_match.group(1))
            self.callbacks.on_save_checkpoint_success(iteration=iteration)
        except Exception as e:  # noqa: BLE001
            log.exception(f"Checkpoint failed to save (local): {e}")

    @misc.timer("checkpoint loading")
    def load(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer | None = None,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
        grad_scaler: torch.amp.GradScaler | None = None,
    ) -> int:
        """Load the training state stored in the checkpoint file and return the start iteration.

        The checkpoint is chosen by ``_load_ckpt_file_during_init``:
        1. Resume from the file named in latest_checkpoint.txt, if there is one.
        2. Otherwise use ``config_checkpoint.load_path``; the iteration is only taken from the
           checkpoint when ``config_checkpoint.load_training_state`` is set.
        3. Otherwise train from scratch, starting at iteration 0.

        The loaded state dict is handed to the ``on_load_checkpoint`` callback. If a scheduler is
        given, its ``last_epoch`` is set to the returned iteration.

        Args:
            model (Model): The PyTorch model; only passed on to the callbacks here.
            optimizer (torch.optim.Optimizer | None): Unused by this method (default: None).
            scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
            grad_scaler (torch.amp.GradScaler | None): Unused by this method (default: None).

        Returns:
            iteration (int): the iteration number to start/resume from.

        Raises:
            FileNotFoundError: If the chosen checkpoint file does not exist.
        """
        self.callbacks.on_load_checkpoint_start(model)

        del optimizer, grad_scaler
        ckpt_path, resume = self._load_ckpt_file_during_init()
        start_iteration = 0
        if ckpt_path is None:
            log.info("Training from scratch.")
        else:
            self._check_checkpoint_exists(ckpt_path)
            log.info(f"Loading checkpoint (local): {ckpt_path}")
            # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            loaded_state = torch.load(ckpt_path, map_location=lambda storage, loc: storage, weights_only=False)
            log.success(f"Complete loading checkpoint (local): {ckpt_path}")
            self.callbacks.on_load_checkpoint(model, state_dict=loaded_state)
            if resume:
                start_iteration = loaded_state["iteration"]
            log.success("Done with loading the checkpoint.")
        torch.cuda.empty_cache()

        self.callbacks.on_load_checkpoint_end(model)

        # The scheduler state itself is not loaded; it is positioned at the start iteration instead.
        if scheduler is not None:
            scheduler.last_epoch = start_iteration
            log.critical(f"resume scheduler from {start_iteration}", rank0_only=False)

        return start_iteration

    def _read_latest_checkpoint_file(self) -> str | None:
        """Get the file name of the latest saved checkpoint. If it doesn't exist, return None.

        The name is read from latest_checkpoint.txt in the local checkpoint directory, with
        surrounding whitespace stripped.

        Returns:
            checkpoint_file (str | None): file name of the latest saved checkpoint.
        """
        checkpoint_file = None
        latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
        if os.path.isfile(latest_path):
            # Use a context manager so the file handle is closed deterministically.
            with open(latest_path) as file:
                checkpoint_file = file.read().strip()
        if checkpoint_file is None:
            log.warning(f"Latest ckpt file not found: {latest_path}")
        else:
            log.info(f"Found latest checkpoint: {checkpoint_file}")
        return checkpoint_file

    def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
        """Track the file name of the latest saved checkpoint.

        Args:
            checkpoint_file (str): file name of the latest saved checkpoint.
        """
        content = f"{checkpoint_file}\n"
        latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
        with open(latest_path, "w") as file:
            file.write(content)

    def _check_checkpoint_exists(self, checkpoint_path: str, is_raise: bool = True) -> None:
        """If the file checkpoint_path does not exist, raise an error.

        Args:
            checkpoint_path (str): full path to the checkpoint.
        """
        if not os.path.exists(checkpoint_path):
            if is_raise:
                raise FileNotFoundError(f"File not found (local): {checkpoint_path}")
            return False
        return True

    def finalize(self) -> None:
        """Finalize the checkpointer."""
        if self.save_thread:
            self.save_thread.join()


class FSDPInferenceCheckpointer:
    def __init__(
        self,
        ckpt_path: str,
        strict_resume: bool = True,
    ):
        self.ckpt_path = ckpt_path
        self.strict_resume = strict_resume

    @misc.timer("FSDPInferenceCheckpointer.load_model_during_init")
    def load_model_during_init(self, model, is_ema=False, ema_id: int = 0):
        """Load the model weights stored at ``self.ckpt_path`` into ``model``.

        Args:
            model: The module that receives the weights.
            is_ema (bool): EMA weights are not supported here; when set, a warning is logged and
                nothing is loaded.
            ema_id (int): Ignored; accepted to keep the signature of the training checkpointer.

        Raises:
            FileNotFoundError: If ``self.ckpt_path`` does not exist.
        """
        del ema_id
        if is_ema:
            log.warning("EMA model is not supported in inference mode.")
            return
        # Raise explicitly instead of asserting: asserts are stripped under ``python -O``, and the
        # training checkpointer reports a missing file with FileNotFoundError as well.
        if not easy_io.exists(self.ckpt_path):
            raise FileNotFoundError(f"File not found: {self.ckpt_path}")
        log.info(f"Loading from {self.ckpt_path}")
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        state_dict = torch.load(self.ckpt_path, map_location=lambda storage, loc: storage, weights_only=False)
        if self.strict_resume:
            log.info(model.load_state_dict(state_dict, strict=self.strict_resume))
        else:
            log.critical("\t Using non-strict model")
            from cosmos_predict1.diffusion.training.utils.checkpointer import non_strict_load_model

            log.info(non_strict_load_model(model, state_dict))
        log.info("-finish model loading")

    def load_optim_scheduler_during_init(self, *args, **kwargs):
        """
        We do not do load in inference mode. The function is here to maintain the same interface to avoid errors.
        """
        pass

    def save(self, *args, **kwargs):
        """
        We do not save anything in inference mode. The function is here to maintain the same interface to avoid errors.
        """
        pass

    def load(self, *args, **kwargs):
        """
        We do not do load in inference mode. The function is here to maintain the same interface to avoid errors.
        """
        return 0