# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
import threading
from typing import List, NamedTuple, Tuple

import torch

from cosmos_predict1.utils import distributed, log, misc
from cosmos_predict1.utils.checkpointer import Checkpointer as BaseCheckpointer
from cosmos_predict1.utils.model import Model

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 11):
    from torch.ao import quantization
    from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase


class _IncompatibleKeys(
    NamedTuple(
        "IncompatibleKeys",
        [
            ("missing_keys", List[str]),
            ("unexpected_keys", List[str]),
            ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
        ],
    )
):
    pass


class MultiRankCheckpointer(BaseCheckpointer):
    def save(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int,
    ) -> None:
        """Save network weights, optimizer parameters, and scheduler parameters to a checkpoint.

        Args:
            model (Model): The PyTorch model.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            iteration (int): Current iteration number.
        """
        # checkpoint_file = f"iter_{iteration:09}.pt"
        postfix, _, total_ema_num = model.get_ckpt_postfix()
        checkpoint_file = f"iter_{iteration:09}{postfix}.pt"
        # Each of the first `total_ema_num` ranks writes its own checkpoint file.
        save_ranks = list(range(total_ema_num))
        for _rank in save_ranks:
            if distributed.get_rank() == _rank:
                state_dict = dict(
                    model=model.state_dict(),
                    optimizer=optimizer.state_dict(),
                    scheduler=scheduler.state_dict(),
                    grad_scaler=grad_scaler.state_dict(),
                    iteration=iteration,
                )
                state_dict = misc.to(state_dict, device="cpu")
                self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
                # Wait for the previous saver thread to end.
                if self.save_thread:
                    self.save_thread.join()
                # Run the checkpoint saver in a separate thread.
                self.save_thread = threading.Thread(
                    target=self._save_worker_local,
                    daemon=False,
                    args=(state_dict, checkpoint_file, distributed.get_rank()),
                )
                self.save_thread.start()
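    # Illustrative note (not from the original source): each saving rank writes a file whose name
    # includes the rank-dependent postfix from model.get_ckpt_postfix(). For example, with
    # total_ema_num == 2 and hypothetical postfixes "_ema0" / "_ema1", iteration 5000 would yield
    # "iter_000005000_ema0.pt" and "iter_000005000_ema1.pt".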
    @misc.timer("checkpoint loading")
    def load(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer | None = None,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
        grad_scaler: torch.amp.GradScaler | None = None,
    ) -> int:
        """Load network weights and optimizer states from a checkpoint in a single process.

        The priority of the checkpoint loading logic is:
        1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name.
        2. If no latest checkpoint is found, load the model weights specified by config_checkpoint.path.
           - This is typically used for inference mode.
           - If config_checkpoint.load_optimizer_state is True, also load the optimizer and scheduler states.
        3. If none of the above, randomly initialize the model parameters and train from scratch.

        Args:
            model (Model): The PyTorch model.
            optimizer (torch.optim.Optimizer | None): The model optimizer (default: None).
            scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
            grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training).

        Returns:
            iteration (int): The iteration number to start/resume from.
        """
        latest_checkpoint_file = self._read_latest_checkpoint_file()
        if latest_checkpoint_file is not None:
            # Different from the base checkpointer, this supports multi-EMA.
            postfix, _, total_ema_num = model.get_ckpt_postfix()
            latest_checkpoint_file = latest_checkpoint_file.replace(".pt", f"{postfix}.pt")
            # 1. Resume training from latest_checkpoint.txt under the same name.
            checkpoint_dir = self.checkpoint_dir_local
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
            resume = True
        else:
            if self.load_path:
                # 2. Load the module weights specified by config_checkpoint.path.
                checkpoint_path = self.load_path
                # Different from the base checkpointer, this supports multi-EMA.
                postfix, _, total_ema_num = model.get_ckpt_postfix()
                checkpoint_path = checkpoint_path.replace(".pt", f"{postfix}.pt")
                resume = self.load_training_state
            else:
                # 3. Randomly initialize the model parameters and train from scratch.
                checkpoint_path = None
                resume = False
        # Load checkpoint.
        if checkpoint_path is not None:
            self._check_checkpoint_exists(checkpoint_path)
            log.info(f"Loading checkpoint (local): {checkpoint_path}")
            state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
            log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
            self.callbacks.on_load_checkpoint(model, state_dict=state_dict)
            # Load the state dicts.
            log.info("- Loading the model...")
            log.critical(model.load_state_dict(state_dict["model"], strict=self.strict_resume))
            if resume:
                iteration = state_dict["iteration"]
                assert optimizer and scheduler
                log.info("- Loading the optimizer...")
                optimizer.load_state_dict(state_dict["optimizer"])
                log.info("- Loading the scheduler...")
                scheduler.load_state_dict(state_dict["scheduler"])
                scheduler.last_epoch = iteration
                log.info("- Loading the gradient scaler...")
                grad_scaler.load_state_dict(state_dict["grad_scaler"])
                log.success(f"Done with loading the checkpoint (iteration {iteration}).")
            else:
                iteration = 0
                log.success("Done with loading the checkpoint.")
        else:
            # Checkpoint not found and not specified. We will train everything from scratch.
            iteration = 0
            log.info("Training from scratch.")
        torch.cuda.empty_cache()
        return iteration
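
# Illustrative usage sketch (not part of the original module): how a training loop might drive
# this checkpointer. The constructor arguments and config fields shown here (`config.checkpoint`,
# `config.job`, `save_iter`, `max_iter`) are hypothetical placeholders.
#
#     checkpointer = MultiRankCheckpointer(config.checkpoint, config.job, callbacks=callbacks)
#     start_iter = checkpointer.load(model, optimizer, scheduler, grad_scaler)
#     for iteration in range(start_iter, config.trainer.max_iter):
#         ...  # forward / backward / optimizer step
#         if (iteration + 1) % config.checkpoint.save_iter == 0:
#             checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration + 1)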


# https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py
def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
    # Workaround for https://github.com/pytorch/pytorch/issues/24139
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            if "_extra_state" in k:  # Key introduced by TransformerEngine for FP8
                log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
                continue
            model_param = model_state_dict[k]
            # Allow mismatch for uninitialized parameters
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
                continue
            if not isinstance(model_param, torch.Tensor):
                raise ValueError(
                    f"Found non-tensor parameter {k} in the model "
                    f"(model type: {type(model_param)}, checkpoint type: {type(checkpoint_state_dict[k])}). "
                    "Please check whether this key is safe to skip."
                )
            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:
                    # Handle the special case of quantization per channel observers,
                    # where buffer shape mismatches are expected.
                    def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
                        # foo.bar.param_or_buffer_name -> [foo, bar]
                        key_parts = key.split(".")[:-1]
                        cur_module = model
                        for key_part in key_parts:
                            cur_module = getattr(cur_module, key_part)
                        return cur_module

                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
                        # Do not remove modules with expected shape mismatches from the
                        # state_dict loading; they have special logic in
                        # _load_from_state_dict to handle the mismatches.
                        continue
                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)
    incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
    # Remove keys with the "_extra_state" suffix, which are non-parameter items introduced by
    # TransformerEngine for FP8 handling.
    missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
    unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
    return _IncompatibleKeys(
        missing_keys=missing_keys,
        unexpected_keys=unexpected_keys,
        incorrect_shapes=incorrect_shapes,
    )
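
# Illustrative usage sketch (not part of the original module): applying non_strict_load_model to a
# checkpoint saved by the checkpointer above. `net` and "ckpt.pt" are hypothetical placeholders.
#
#     state_dict = torch.load("ckpt.pt", map_location="cpu")
#     result = non_strict_load_model(net, state_dict["model"])
#     log.warning(f"Missing keys: {result.missing_keys}")
#     log.warning(f"Unexpected keys: {result.unexpected_keys}")
#     log.warning(f"Shape mismatches (key, ckpt shape, model shape): {result.incorrect_shapes}")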