# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
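"""Multi-rank checkpointer with multi-EMA support.

Checkpoint files carry a rank-specific postfix (from Model.get_ckpt_postfix) so that
multiple EMA weight copies can be stored in separate files, one per rank. The module
also provides non_strict_load_model, a tolerant state-dict loader adapted from fvcore
that reports missing keys, unexpected keys, and shape mismatches instead of raising.
"""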
from __future__ import annotations

import os
import threading
from typing import List, NamedTuple, Tuple

import torch

from cosmos_predict1.utils import distributed, log, misc
from cosmos_predict1.utils.checkpointer import Checkpointer as BaseCheckpointer
from cosmos_predict1.utils.model import Model

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 11):
    from torch.ao import quantization
    from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase


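# Return type of non_strict_load_model: like the result of torch.nn.Module.load_state_dict,
# but also records (key, checkpoint_shape, model_shape) for shape-mismatched entries.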
class _IncompatibleKeys(
    NamedTuple(
        "IncompatibleKeys",
        [
            ("missing_keys", List[str]),
            ("unexpected_keys", List[str]),
            ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
        ],
    )
):
    pass


class MultiRankCheckpointer(BaseCheckpointer):
    def save(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int,
    ) -> None:
        """Save network weights, optimizer parameters, and scheduler parameters to a checkpoint.

        Args:
            model (Model): The PyTorch model.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            iteration (int): Current iteration number.
        """
        postfix, _, total_ema_num = model.get_ckpt_postfix()
        checkpoint_file = f"iter_{iteration:09}{postfix}.pt"
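        # One checkpoint file per EMA copy: only ranks [0, total_ema_num) save, each to a
        # file distinguished by the rank-specific postfix from model.get_ckpt_postfix().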
        save_ranks = list(range(total_ema_num))
        for _rank in save_ranks:
            if distributed.get_rank() == _rank:
                state_dict = dict(
                    model=model.state_dict(),
                    optimizer=optimizer.state_dict(),
                    scheduler=scheduler.state_dict(),
                    grad_scaler=grad_scaler.state_dict(),
                    iteration=iteration,
                )
                state_dict = misc.to(state_dict, device="cpu")
                self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
                # Wait for the previous saver thread to end.
                if self.save_thread:
                    self.save_thread.join()
                # Run the checkpoint saver in a separate thread.
                self.save_thread = threading.Thread(
                    target=self._save_worker_local,
                    daemon=False,
                    args=(state_dict, checkpoint_file, distributed.get_rank()),
                )
                self.save_thread.start()

    def load(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer | None = None,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
        grad_scaler: torch.amp.GradScaler | None = None,
    ) -> int:
        """Load network weights and optimizer states from a checkpoint in a single process.

        The priority of the checkpoint loading logic is:
        1. Attempt to resume training by looking for latest_checkpoint.txt under the same name.
        2. If no latest checkpoint is found, load the model weights specified by config_checkpoint.path.
           - This is typically used for inference mode.
           - If config_checkpoint.load_optimizer_state is True, also load the optimizer and scheduler states.
        3. Otherwise, randomly initialize the model parameters and train from scratch.

        Args:
            model (Model): The PyTorch model.
            optimizer (torch.optim.Optimizer | None): The model optimizer (default: None).
            scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
            grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training).

        Returns:
            iteration (int): The iteration number to start/resume from.
        """
        latest_checkpoint_file = self._read_latest_checkpoint_file()
        if latest_checkpoint_file is not None:
            # Unlike the base checkpointer, this supports multi-EMA checkpoints.
            postfix, _, total_ema_num = model.get_ckpt_postfix()
            latest_checkpoint_file = latest_checkpoint_file.replace(".pt", f"{postfix}.pt")
            # 1. Resume training from latest_checkpoint.txt under the same name.
            checkpoint_dir = self.checkpoint_dir_local
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
            resume = True
        else:
            if self.load_path:
                # 2. Load the model weights specified by config_checkpoint.path.
                checkpoint_path = self.load_path
                # Unlike the base checkpointer, this supports multi-EMA checkpoints.
                postfix, _, total_ema_num = model.get_ckpt_postfix()
                checkpoint_path = checkpoint_path.replace(".pt", f"{postfix}.pt")
                resume = self.load_training_state
            else:
                # 3. Randomly initialize the model parameters and train from scratch.
                checkpoint_path = None
                resume = False
        # Load checkpoint.
        if checkpoint_path is not None:
            self._check_checkpoint_exists(checkpoint_path)
            log.info(f"Loading checkpoint (local): {checkpoint_path}")
            state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
            log.success(f"Completed loading checkpoint (local): {checkpoint_path}")
            self.callbacks.on_load_checkpoint(model, state_dict=state_dict)
            # Load the state dicts.
            log.info("- Loading the model...")
            log.critical(model.load_state_dict(state_dict["model"], strict=self.strict_resume))
            if resume:
                iteration = state_dict["iteration"]
                assert optimizer and scheduler
                log.info("- Loading the optimizer...")
                optimizer.load_state_dict(state_dict["optimizer"])
                log.info("- Loading the scheduler...")
                scheduler.load_state_dict(state_dict["scheduler"])
                scheduler.last_epoch = iteration
                log.info("- Loading the gradient scaler...")
                grad_scaler.load_state_dict(state_dict["grad_scaler"])
                log.success(f"Done with loading the checkpoint (iteration {iteration}).")
            else:
                iteration = 0
                log.success("Done with loading the checkpoint.")
        else:
            # Checkpoint not found and not specified. We will train everything from scratch.
            iteration = 0
            log.info("Training from scratch.")
        torch.cuda.empty_cache()
        return iteration
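
# Example usage (sketch, not part of the library): the constructor arguments come from the
# base Checkpointer, so the names below are illustrative assumptions only.
#
#     checkpointer = MultiRankCheckpointer(config_checkpoint, config_job, callbacks=callbacks)
#     start_iteration = checkpointer.load(model, optimizer, scheduler, grad_scaler)
#     for iteration in range(start_iteration, max_iteration):
#         ...  # one training step
#         checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration)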


# Adapted from:
# https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py
def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
    # Workaround for https://github.com/pytorch/pytorch/issues/24139
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            if "_extra_state" in k:  # Key introduced by TransformerEngine for FP8
                log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
                continue
            model_param = model_state_dict[k]
            # Allow mismatch for uninitialized parameters
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
                continue
            if not isinstance(model_param, torch.Tensor):
                raise ValueError(
                    f"Found non-tensor parameter {k} in the model "
                    f"(model type: {type(model_param)}, checkpoint type: {type(checkpoint_state_dict[k])}); "
                    "please check whether this key is safe to skip."
                )
            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:
                    # Handle the special case of quantization per-channel observers,
                    # where buffer shape mismatches are expected.
                    def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
                        # foo.bar.param_or_buffer_name -> [foo, bar]
                        key_parts = key.split(".")[:-1]
                        cur_module = model
                        for key_part in key_parts:
                            cur_module = getattr(cur_module, key_part)
                        return cur_module

                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
                        # Do not remove modules with expected shape mismatches from the
                        # state_dict loading; they have special logic in
                        # _load_from_state_dict to handle the mismatches.
                        continue
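
                # Record the mismatch and drop the key so that load_state_dict(strict=False)
                # below does not raise on the incompatible shape.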
                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)
    incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
    # Remove keys with the "_extra_state" suffix, which are non-parameter items
    # introduced by TransformerEngine for FP8 handling.
    missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
    unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
    return _IncompatibleKeys(
        missing_keys=missing_keys,
        unexpected_keys=unexpected_keys,
        incorrect_shapes=incorrect_shapes,
    )
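
# Example (sketch): load a checkpoint saved by MultiRankCheckpointer.save above and inspect
# what could not be matched; the "model" key follows the save() state_dict layout.
#
#     state_dict = torch.load(checkpoint_path, map_location="cpu")
#     result = non_strict_load_model(model, state_dict["model"])
#     log.warning(f"missing: {result.missing_keys}")
#     log.warning(f"unexpected: {result.unexpected_keys}")
#     log.warning(f"shape mismatches: {result.incorrect_shapes}")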