# cosmos_predict1/utils/fsdp_optim_fix.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# isort: skip_file
"""
torch 2.2 has bugs in loading optimizer states for FSDP in hybrid mode
torch impl uses state.rank and dist.rank() inconsistently
The file fix the bugs. Verified it works for hybrid mode and fullly sharded mode
Please use the `scatter_full_optim_state_dict` in the code to replace the corresponding function in torch 2.2
"""
import copy
import warnings
from typing import Any, Dict, Iterable, List, Optional, Union
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp._debug_utils import SimpleProfiler
from torch.distributed.fsdp._optim_utils import (
_flatten_optim_state,
_FSDPState,
_get_fqn_to_fsdp_param_info,
_get_param_to_fqns,
_OptimStateKey,
_PosDimTensorInfo,
_shard_orig_param_state,
tree_map_only,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import _rekey_sharded_optim_state_dict


def _broadcast_processed_state(
fsdp_state: _FSDPState,
optim_state: Dict[str, Any],
group: Optional[dist.ProcessGroup],
) -> Dict[str, Any]:
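    """
    Broadcast the processed optimizer state from global rank 0 to all ranks.

    Scalar (zero-dim) tensors are moved to CPU and sent as-is, while
    positive-dimension tensors are replaced with ``_PosDimTensorInfo``
    (shape/dtype only) so that ``broadcast_object_list`` ships a lightweight
    skeleton; the actual tensor data is broadcast later by ``_broadcast_state``.
    """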
objects: List[Any] = [None]
if fsdp_state.rank == 0:
objects[0] = tree_map_only(
torch.Tensor,
lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype),
optim_state,
)
dist.broadcast_object_list(objects, src=0, group=group)
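    # Use the global rank to decide who already holds the full state: under hybrid
    # sharding, ``fsdp_state.rank`` is the rank within the sharding group (so several
    # ranks see rank 0 there), but only global rank 0 has the full ``optim_state``.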
if dist.get_rank() == 0:
return optim_state
else:
return objects[0]


def _broadcast_state(fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]) -> Any:
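    """
    Broadcast a single optimizer state value from global rank 0.

    Non-tensor and zero-dim tensor states are returned unchanged (they already
    arrived via ``_broadcast_processed_state``); positive-dimension tensors are
    materialized on ``fsdp_state.compute_device`` and broadcast from global
    rank 0 to the other ranks.
    """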
if dist.get_rank() == 0:
if not isinstance(state, torch.Tensor) or state.dim() == 0:
return state
tensor = state.to(fsdp_state.compute_device)
else:
if isinstance(state, torch.Tensor):
            assert state.dim() == 0, (
                "For non-zero ranks, a tensor state should have zero dimension, "
                f"but got a state with shape {tuple(state.shape)}."
            )
return state
elif not isinstance(state, _PosDimTensorInfo):
return state
tensor = torch.zeros(state.shape, dtype=state.dtype, device=fsdp_state.compute_device)
dist.broadcast(tensor, src=0, group=group)
return tensor


def _flatten_optim_state_dict(
optim_state_dict: Dict[str, Any],
model: nn.Module,
use_orig_params: bool = False,
optim: Optional[torch.optim.Optimizer] = None,
rank0_only: bool = False,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
Flattens the full optimizer state dict, still keying by unflattened parameter
names.
If ``use_orig_params`` is True, each rank will have all FSDP-managed
parameters but some of these parameters may be empty due to the sharding.
For a regular optim.Optimizer, states for those empty parameters will
not be initialized. So, when aggregating the FQNs across ranks, no assert
will be raised on a rank even if it does not have all the states -- it is
valid and FSDP know how to aggregate them. However, FSDP has to ignore
handling those parameters that are not managed by FSDP and do not exist on
the local rank -- it is managed by other parallelism and FSDP does not
know ho to handle/aggregate them.
Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to
flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require
all the states even if the corresponding parameters are empty. To this end,
``optim`` will be used to to get the initial state of the empty parameters.
``optim`` should only be non-None if the ``optim` is KeyedOptimizer or
NamedOptimizer.
Returns:
Dict[str, Any]: The flattened optimizer state dict.
"""
SimpleProfiler.reset()
unflat_osd = optim_state_dict
if "state" not in unflat_osd and not rank0_only:
        raise ValueError('`optim_state_dict` must have the key "state" to be a valid optimizer state dict')
param_to_fqns = _get_param_to_fqns(model)
fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)
fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state
    # Broadcast unflat_osd without non-scalar tensors if rank0_only is True.
if rank0_only:
unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group)
# Construct the "state" part
flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {}
unflat_osd_state = unflat_osd["state"]
all_state_keys = set(unflat_osd_state.keys())
for param, fqns in param_to_fqns.items():
fqn = fqns[0]
if fqn not in unflat_osd_state:
continue
all_state_keys.difference_update(fqns)
if rank0_only:
for fqn in fqns:
if not unflat_osd_state[fqn]:
continue
for state_name in unflat_osd_state[fqn].keys():
unflat_osd_state[fqn][state_name] = _broadcast_state(
fsdp_state, unflat_osd_state[fqn][state_name], group=group
)
fqn = fqns[0]
if fqn in fqn_to_fsdp_param_info:
fsdp_param_info = fqn_to_fsdp_param_info[fqn]
if use_orig_params:
with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING):
flat_state = _shard_orig_param_state(
fsdp_param_info,
fqn,
unflat_osd_state[fqn],
)
else:
flat_state = _flatten_optim_state(
fsdp_param_info,
unflat_osd_state,
fqns,
)
key = _OptimStateKey(tuple(fqns), True)
            # Only include non-empty states, as expected by
            # `torch.optim.Optimizer`, unless the optimizer is KeyedOptimizer
            # or NamedOptimizer.
if flat_state:
flat_osd_state[key] = flat_state
elif use_orig_params:
assert len(fqns) == 1, f"use_orig_params is True but there are multiple FQNs, {fqns}."
if optim is not None: # NamedOptimizer or KeyedOptimizer case.
state = optim.state.get(param, None) # type: ignore[call-overload]
if state is not None:
flat_osd_state[key] = copy.deepcopy(state)
else:
warnings.warn(f"optim_state[{key}] is not on rank{fsdp_state.rank}.")
else:
raise RuntimeError(f"The state of {key} is empty. This should happen when " "use_orig_params=True.")
else: # do not flatten non-FSDP parameters' states
assert len(fqns) == 1
key = _OptimStateKey(tuple(fqns), False)
flat_osd_state[key] = copy.copy(unflat_osd_state[fqn])
if rank0_only:
for fqn in fqns:
if not unflat_osd_state[fqn]:
continue
for state_name, param_state in list(unflat_osd_state[fqn].items()):
if fsdp_state.rank > 0:
                        # Dereference the tensor so that PyTorch can collect the memory.
del unflat_osd_state[fqn][state_name]
else:
                        # Move the tensor in the original osd back to CPU so that the
                        # original osd is unaffected.
unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][state_name].cpu()
# Handle user-defined state, states that are not associated with parameters.
for key in all_state_keys:
user_state = unflat_osd_state[key]
if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params:
user_state = _broadcast_state(fsdp_state, user_state, group=group)
flat_osd_state[key] = copy.copy(user_state)
SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ")
# Construct the "param_groups" part -- copy as is since it will be
# rekeyed later according to the target rank's optimizer
# Only copy param_groups if it exists in unflat_osd
if "param_groups" in unflat_osd:
flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"])
return {"state": flat_osd_state, "param_groups": flat_osd_param_groups}
else:
return {"state": flat_osd_state}


def _optim_state_dict_to_load_impl(
optim_state_dict: Dict[str, Any],
model: torch.nn.Module,
optim_input: Optional[
Union[
List[Dict[str, Any]],
Iterable[torch.nn.Parameter],
]
] = None,
optim: Optional[torch.optim.Optimizer] = None,
full_state_dict: bool = True,
rank0_only: bool = False,
is_named_optimizer: bool = False,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
The internal API that is used by all the load optim_state_dict implementations.
Given model, optim, and the saved optim_state_dict, this API adds the FSDP
internal information and internal sharding to the optim_state_dict.
"""
if full_state_dict:
FullyShardedDataParallel._warn_optim_input(optim_input)
using_optim_input = FullyShardedDataParallel._is_using_optim_input(
optim_input,
optim,
)
else:
using_optim_input = False
assert optim_input is None and not rank0_only
use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[0]._use_orig_params
assert all(
use_orig_params == m._use_orig_params for m in FullyShardedDataParallel.fsdp_modules(model)
), "Not all FSDP modules have the same _use_orig_params value"
if rank0_only and dist.get_rank(group) > 0:
optim_state_dict = {}
sharded_osd = _flatten_optim_state_dict(
optim_state_dict,
model=model,
use_orig_params=use_orig_params,
optim=(optim if is_named_optimizer else None),
rank0_only=rank0_only,
group=group,
)
return _rekey_sharded_optim_state_dict(
sharded_osd,
model=model,
optim=optim,
optim_input=optim_input,
using_optim_input=using_optim_input,
is_named_optimizer=is_named_optimizer,
)


def scatter_full_optim_state_dict(
full_optim_state_dict: Optional[Dict[str, Any]],
model: torch.nn.Module,
optim_input: Optional[
Union[
List[Dict[str, Any]],
Iterable[torch.nn.Parameter],
]
] = None,
optim: Optional[torch.optim.Optimizer] = None,
group: Optional[Any] = None,
) -> Dict[str, Any]:
"""
Scatters the full optimizer state dict from rank 0 to all other ranks,
returning the sharded optimizer state dict on each rank. The return
value is the same as :meth:`shard_full_optim_state_dict`, and on rank
0, the first argument should be the return value of
:meth:`full_optim_state_dict`.

    Example::

>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)

    .. note:: Both :meth:`shard_full_optim_state_dict` and
:meth:`scatter_full_optim_state_dict` may be used to get the
sharded optimizer state dict to load. Assuming that the full
optimizer state dict resides in CPU memory, the former requires
each rank to have the full dict in CPU memory, where each rank
individually shards the dict without any communication, while the
latter requires only rank 0 to have the full dict in CPU memory,
where rank 0 moves each shard to GPU memory (for NCCL) and
communicates it to ranks appropriately. Hence, the former has
higher aggregate CPU memory cost, while the latter has higher
communication cost.

    Args:
full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state
dict corresponding to the unflattened parameters and holding
the full non-sharded optimizer state if on rank 0; the argument
is ignored on nonzero ranks.
model (torch.nn.Module): Root module (which may or may not be a
:class:`FullyShardedDataParallel` instance) whose parameters
correspond to the optimizer state in ``full_optim_state_dict``.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
Input passed into the optimizer representing either a
:class:`list` of parameter groups or an iterable of parameters;
if ``None``, then this method assumes the input was
``model.parameters()``. This argument is deprecated, and there
is no need to pass it in anymore. (Default: ``None``)
optim (Optional[torch.optim.Optimizer]): Optimizer that will load
the state dict returned by this method. This is the preferred
argument to use over ``optim_input``. (Default: ``None``)
group (dist.ProcessGroup): Model's process group or ``None`` if
using the default process group. (Default: ``None``)

    Returns:
Dict[str, Any]: The full optimizer state dict now remapped to
flattened parameters instead of unflattened parameters and
restricted to only include this rank's part of the optimizer state.
"""
FullyShardedDataParallel._warn_legacy_optim_state_dict("scatter_full_optim_state_dict", "optim_state_dict_to_load")
return _optim_state_dict_to_load_impl(
optim_state_dict=full_optim_state_dict,
model=model,
optim_input=optim_input,
optim=optim,
full_state_dict=True,
rank0_only=True,
is_named_optimizer=False,
group=group,
)
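

# One way to apply the fix (a sketch, intentionally not executed at import time):
# replace the torch 2.2 static method with the patched function above, assuming the
# surrounding training code keeps calling ``FSDP.scatter_full_optim_state_dict``:
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   FSDP.scatter_full_optim_state_dict = staticmethod(scatter_full_optim_state_dict)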