# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
import threading
from typing import List, NamedTuple, Tuple

import torch

from cosmos_predict1.utils import distributed, log, misc
from cosmos_predict1.utils.checkpointer import Checkpointer as BaseCheckpointer
from cosmos_predict1.utils.model import Model

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
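
# FakeQuantizeBase / ObserverBase moved to the torch.ao.quantization namespace in newer
# PyTorch releases, so they are imported from whichever namespace is available. They are
# only needed by non_strict_load_model() below, to tolerate the expected buffer shape
# mismatches of per-channel quantization observers.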
if TORCH_VERSION >= (1, 11):
    from torch.ao import quantization
    from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase


class _IncompatibleKeys(
    NamedTuple(
        "IncompatibleKeys",
        [
            ("missing_keys", List[str]),
            ("unexpected_keys", List[str]),
            ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
        ],
    )
):
    pass
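
# For reference: an _IncompatibleKeys instance is an ordinary NamedTuple, e.g.
# _IncompatibleKeys(missing_keys=["head.weight"], unexpected_keys=[],
#                   incorrect_shapes=[("head.weight", (10, 512), (20, 512))]),
# where each incorrect_shapes entry is (key, checkpoint_shape, model_shape), as
# produced by non_strict_load_model() below.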


class MultiRankCheckpointer(BaseCheckpointer):
    def save(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int,
    ) -> None:
        """Save network weights, optimizer parameters, scheduler parameters to a checkpoint.

        Args:
            model (Model): The PyTorch model.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            iteration (int): Current iteration number.
        """
        # checkpoint_file = f"iter_{iteration:09}.pt"
        postfix, _, total_ema_num = model.get_ckpt_postfix()
        checkpoint_file = f"iter_{iteration:09}{postfix}.pt"
        save_ranks = list(range(total_ema_num))
        for _rank in save_ranks:
            if distributed.get_rank() == _rank:
                state_dict = dict(
                    model=model.state_dict(),
                    optimizer=optimizer.state_dict(),
                    scheduler=scheduler.state_dict(),
                    grad_scaler=grad_scaler.state_dict(),
                    iteration=iteration,
                )
                state_dict = misc.to(state_dict, device="cpu")
                self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
                # Wait for previous saver thread to end.
                if self.save_thread:
                    self.save_thread.join()
                # Run the checkpoint saver in a separate thread.
                self.save_thread = threading.Thread(
                    target=self._save_worker_local,
                    daemon=False,
                    args=(state_dict, checkpoint_file, distributed.get_rank()),
                )
                self.save_thread.start()
@misc.timer("checkpoint loading")
def load(
self,
model: Model,
optimizer: torch.optim.Optimizer | None = None,
scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
grad_scaler: torch.amp.GradScaler | None = None,
) -> int:
"""Load network weights and optimizer states from a checkpoint in a single process.
The priority of the checkpoint loading logic is:
1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name.
2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path.
- This is typically used for inference mode.
- If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states.
3. If none of the above, randomly initialize the model parameters and train from scratch.
Args:
model (Model): The PyTorch model.
optimizer (torch.optim.Optimizer | None): The model optimizer (default: None).
scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training).
Returns:
iteration (int): the iteration number to start/resume from.
"""
        latest_checkpoint_file = self._read_latest_checkpoint_file()
        if latest_checkpoint_file is not None:
            # Different from the base checkpointer, this supports multi-EMA.
            postfix, _, total_ema_num = model.get_ckpt_postfix()
            latest_checkpoint_file = latest_checkpoint_file.replace(".pt", f"{postfix}.pt")
            # 1. Resume training from latest_checkpoint.txt under the same name.
            checkpoint_dir = self.checkpoint_dir_local
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
            resume = True
        else:
            if self.load_path:
                # 2. Load the module weights specified by config_checkpoint.path.
                checkpoint_path = self.load_path
                # Different from the base checkpointer, this supports multi-EMA.
                postfix, _, total_ema_num = model.get_ckpt_postfix()
                checkpoint_path = checkpoint_path.replace(".pt", f"{postfix}.pt")
                resume = self.load_training_state
            else:
                # 3. Randomly initialize the model parameters and train from scratch.
                checkpoint_path = None
                resume = False
        # Load checkpoint.
        if checkpoint_path is not None:
            self._check_checkpoint_exists(checkpoint_path)
            log.info(f"Loading checkpoint (local): {checkpoint_path}")
            state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
            log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
            self.callbacks.on_load_checkpoint(model, state_dict=state_dict)
            # Load the state dicts.
            log.info("- Loading the model...")
            log.critical(model.load_state_dict(state_dict["model"], strict=self.strict_resume))
            if resume:
                iteration = state_dict["iteration"]
                assert optimizer and scheduler
                log.info("- Loading the optimizer...")
                optimizer.load_state_dict(state_dict["optimizer"])
                log.info("- Loading the scheduler...")
                scheduler.load_state_dict(state_dict["scheduler"])
                scheduler.last_epoch = iteration
                log.info("- Loading the gradient scaler...")
                grad_scaler.load_state_dict(state_dict["grad_scaler"])
                log.success(f"Done with loading the checkpoint (iteration {iteration}).")
            else:
                iteration = 0
                log.success("Done with loading the checkpoint.")
        else:
            # Checkpoint not found and not specified. We will train everything from scratch.
            iteration = 0
            log.info("Training from scratch.")
        torch.cuda.empty_cache()
        return iteration
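

# --- Illustrative sketch (not part of the library API) ----------------------
# The save()/load() methods above derive per-rank checkpoint names by appending
# the postfix returned from model.get_ckpt_postfix() to the base
# "iter_{iteration:09}.pt" pattern. The helper below only mirrors that f-string
# for illustration; the actual postfix string (e.g. one per EMA replica) is
# defined by the Model implementation and is assumed here.
def _example_checkpoint_filename(iteration: int, postfix: str = "") -> str:
    """Return the checkpoint file name MultiRankCheckpointer would use for a given postfix."""
    return f"iter_{iteration:09}{postfix}.pt"


# e.g. _example_checkpoint_filename(5000) == "iter_000005000.pt"
#      _example_checkpoint_filename(5000, "_EMA1") == "iter_000005000_EMA1.pt" (hypothetical postfix)

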
# https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py
def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
    # workaround https://github.com/pytorch/pytorch/issues/24139
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            if "_extra_state" in k:  # Key introduced by TransformerEngine for FP8
                log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
                continue
            model_param = model_state_dict[k]
            # Allow mismatch for uninitialized parameters
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
                continue
            if not isinstance(model_param, torch.Tensor):
                raise ValueError(
                    f"Found non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, "
                    "please check whether this key is safe to skip."
                )
            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:
                    # Handle the special case of quantization per channel observers,
                    # where buffer shape mismatches are expected.
                    def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
                        # foo.bar.param_or_buffer_name -> [foo, bar]
                        key_parts = key.split(".")[:-1]
                        cur_module = model
                        for key_part in key_parts:
                            cur_module = getattr(cur_module, key_part)
                        return cur_module

                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
                        # Do not remove modules with expected shape mismatches from the
                        # state_dict loading; they have special logic in
                        # _load_from_state_dict to handle the mismatch.
                        continue
                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)
    incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
    # Remove keys with "_extra_state" suffix, which are non-parameter items
    # introduced by TransformerEngine for FP8 handling
    missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
    unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
    return _IncompatibleKeys(
        missing_keys=missing_keys,
        unexpected_keys=unexpected_keys,
        incorrect_shapes=incorrect_shapes,
    )
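

# --- Illustrative usage sketch (assumption: executed as a script with this package's
# dependencies importable) ---
# A minimal demonstration of non_strict_load_model(): the checkpoint below has one
# tensor with a mismatched shape and one key the model does not know about, so the
# returned _IncompatibleKeys lists the shape mismatch and the unexpected key while
# every compatible tensor is still loaded.
if __name__ == "__main__":
    import torch.nn as nn

    demo_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    demo_ckpt = {k: v.clone() for k, v in demo_model.state_dict().items()}
    demo_ckpt["0.weight"] = torch.zeros(16, 4)  # wrong shape: the model expects (8, 4)
    demo_ckpt["extra.bias"] = torch.zeros(3)  # key not present in the model
    result = non_strict_load_model(demo_model, demo_ckpt)
    print("missing keys:", result.missing_keys)
    print("unexpected keys:", result.unexpected_keys)
    print("incorrect shapes:", result.incorrect_shapes)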