# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict
from copy import deepcopy
from itertools import chain
from math import sqrt
from typing import Any, DefaultDict, Dict, Hashable, Iterable, List, Optional, Tuple, Union

import qoptim_cuda
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing_extensions import TypeAlias

StateDict: TypeAlias = Dict[str, Any]

convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2}
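
# Both optimizer moments are stored in one of the two FP8 formats mapped above:
# E4M3 (torch.float8_e4m3fn) trades dynamic range for precision, while E5M2
# (torch.float8_e5m2) keeps a wider range. Each state tensor is quantized per
# group of `qargs.qgroup_size` elements, with one scale factor in the
# parameter's dtype (and, when expansion is enabled, one expansion factor and
# one sqrt-min/max factor) kept per group.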


class CoatAdamW(Optimizer):
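    """AdamW variant that keeps its optimizer states in FP8.

    The first- and second-order moments (``exp_avg`` / ``exp_avg_sq``) are
    stored as FP8 tensors together with per-group scale factors; the actual
    update is performed by the fused CUDA kernels in ``qoptim_cuda``. ``qargs``
    is expected to provide the quantization configuration read below:
    ``first_order_bit`` / ``second_order_bit`` ("E4M3" or "E5M2"),
    ``first_order_expansion`` / ``second_order_expansion``, ``qgroup_size``
    and ``expand_min``.
    """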

    def __init__(
        self,
        qargs,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
        *,
        fused: Optional[bool] = None,
    ):
        self.qargs = qargs
        assert (
            self.qargs.first_order_expansion == self.qargs.second_order_expansion
        ), "first_order_expansion and second_order_expansion must be configured identically"
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            fused=fused,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = torch.tensor(step_val, dtype=torch.float32)

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        amsgrad,
        use_expansion,
        exp_avgs,
        scale_exp_avgs,
        expand_exp_avgs,
        sqrt_minmax_exp_avgs,
        exp_avg_sqs,
        scale_exp_avg_sqs,
        expand_exp_avg_sqs,
        sqrt_minmax_exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
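        # Flatten the per-parameter optimizer state into the lists passed in by
        # step(). State is created lazily the first time a parameter has a
        # gradient: FP8 moment tensors shaped like the parameter, plus one
        # scale entry (and, with expansion enabled, one expansion entry and one
        # sqrt-min/max entry) per quantization group of `qgroup_size` elements.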
        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamW does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                # Keep `step` on CPU; kernel launches are costly on CUDA and XLA.
                state["step"] = torch.tensor(0.0)

                # torch.float8_e4m3fn or torch.float8_e5m2, depending on qargs
                first_order_dtype = convert_str_to_fp8[self.qargs.first_order_bit]
                second_order_dtype = convert_str_to_fp8[self.qargs.second_order_bit]
                # One scale (and expansion) entry per quantization group
                scale_shape = (p.numel() + self.qargs.qgroup_size - 1) // self.qargs.qgroup_size

                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p, dtype=first_order_dtype, memory_format=torch.preserve_format)
                state["scale_exp_avg"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype)
                if use_expansion:
                    state["expand_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
                    state["sqrt_minmax_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)

                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p, dtype=second_order_dtype, memory_format=torch.preserve_format)
                state["scale_exp_avg_sq"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype)
                if use_expansion:
                    state["expand_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
                    state["sqrt_minmax_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)

                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

            exp_avgs.append(state["exp_avg"])
            scale_exp_avgs.append(state["scale_exp_avg"])
            if use_expansion:
                expand_exp_avgs.append(state["expand_exp_avg"])
                sqrt_minmax_exp_avgs.append(state["sqrt_minmax_exp_avg"])

            exp_avg_sqs.append(state["exp_avg_sq"])
            scale_exp_avg_sqs.append(state["scale_exp_avg_sq"])
            if use_expansion:
                expand_exp_avg_sqs.append(state["expand_exp_avg_sq"])
                sqrt_minmax_exp_avg_sqs.append(state["sqrt_minmax_exp_avg_sq"])

            if group["amsgrad"]:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])
            state_steps.append(state["step"])

    def load_state_dict(self, state_dict: StateDict) -> None:
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # shallow copy, to be consistent with module API
        state_dict = state_dict.copy()

        for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
            hook_result = pre_hook(self, state_dict)
            if hook_result is not None:
                state_dict = hook_result

        # Validate the state_dict
        groups = self.param_groups

        # Deepcopy as we write into saved_groups later to update state
        saved_groups = deepcopy(state_dict["param_groups"])

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of parameter groups")
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError(
                "loaded state dict contains a parameter group that doesn't match the size of optimizer's group"
            )

        # Update the state
        id_map = dict(
            zip(
                chain.from_iterable(g["params"] for g in saved_groups),
                chain.from_iterable(g["params"] for g in groups),
            )
        )

        def _cast(param, value, param_id=None, param_groups=None, key=None):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                return CoatAdamW._process_value_according_to_param_policy(param, value, param_id, param_groups, key)
            elif isinstance(value, dict):
                return {
                    k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()
                }
            elif isinstance(value, Iterable):
                return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)  # type: ignore[call-arg]
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
            new_group["params"] = group["params"]
            return new_group

        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({"state": state, "param_groups": param_groups})

        for post_hook in self._optimizer_load_state_dict_post_hooks.values():
            post_hook(self)

    @staticmethod
    def _process_value_according_to_param_policy(
        param: torch.Tensor,
        value: torch.Tensor,
        param_id: int,
        param_groups: List[Dict[Any, Any]],
        key: Hashable = None,
    ) -> torch.Tensor:
        # Floating-point types are a bit special here. They are the only ones
        # that are assumed to always match the type of params.
        # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
        # UNLESS fused or capturable, see note [special device hosting for step]
        fused = False
        capturable = False
        assert param_groups is not None
        for pg in param_groups:
            if param_id in pg["params"]:
                fused = pg["fused"] if "fused" in pg else False
                capturable = pg["capturable"] if "capturable" in pg else False
                break

        if key == "step":
            if capturable or fused:
                return value.to(dtype=torch.float32, device=param.device)
            else:
                return value
        else:
            # Unlike stock AdamW, the FP8 optimizer states must keep their
            # dtype, so only the device is changed here.
            assert value.dtype in [torch.float8_e4m3fn, torch.float8_e5m2, torch.float32]
            return value.to(device=param.device)  # do not cast optimizer states
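
    # step() below gathers flat per-group lists of parameters, gradients and
    # quantized state via _init_group() and forwards them to the functional
    # Coatadamw() helper defined after this class.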
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            scale_exp_avgs = []
            expand_exp_avgs = []
            sqrt_minmax_exp_avgs = []
            exp_avg_sqs = []
            scale_exp_avg_sqs = []
            expand_exp_avg_sqs = []
            sqrt_minmax_exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            use_expansion = self.qargs.first_order_expansion in ["expansion", "true"]
            beta1, beta2 = group["betas"]

            self._init_group(
                group,
                params_with_grad,
                grads,
                amsgrad,
                use_expansion,
                exp_avgs,
                scale_exp_avgs,
                expand_exp_avgs,
                sqrt_minmax_exp_avgs,
                exp_avg_sqs,
                scale_exp_avg_sqs,
                expand_exp_avg_sqs,
                sqrt_minmax_exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            Coatadamw(
                self.qargs,
                params_with_grad,
                grads,
                exp_avgs,
                scale_exp_avgs,
                expand_exp_avgs,
                sqrt_minmax_exp_avgs,
                exp_avg_sqs,
                scale_exp_avg_sqs,
                expand_exp_avg_sqs,
                sqrt_minmax_exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                use_expansion=use_expansion,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                qgroup_size=self.qargs.qgroup_size,
                expand_min=self.qargs.expand_min,
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss
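

# Minimal usage sketch (illustration only; `QArgs` and its field values below
# are hypothetical stand-ins for the quantization config object this file
# receives as `qargs` -- the field names are the ones read above):
#
#     from dataclasses import dataclass
#
#     @dataclass
#     class QArgs:
#         first_order_bit: str = "E4M3"
#         second_order_bit: str = "E4M3"
#         first_order_expansion: str = "expansion"
#         second_order_expansion: str = "expansion"
#         qgroup_size: int = 128
#         expand_min: int = 16
#
#     optimizer = CoatAdamW(QArgs(), model.parameters(), lr=1e-3, weight_decay=1e-2)
#     loss = model(batch)
#     loss.backward()
#     optimizer.step()
#     optimizer.zero_grad()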


def Coatadamw(
    qargs,
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    scale_exp_avgs: List[Tensor],
    expand_exp_avgs: List[Tensor],
    sqrt_minmax_exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    scale_exp_avg_sqs: List[Tensor],
    expand_exp_avg_sqs: List[Tensor],
    sqrt_minmax_exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    amsgrad: bool,
    use_expansion: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    qgroup_size: int,
    expand_min: int,
):
    r"""Functional API that performs the FP8 AdamW computation used by :class:`CoatAdamW`.

    See :class:`~torch.optim.AdamW` for details on the underlying algorithm.
    """
    if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    # Only the single-tensor implementation is provided; the `fused` flag is
    # accepted for API compatibility but does not select a different code path.
    func = _single_tensor_Coatadamw

    func(
        qargs,
        params,
        grads,
        exp_avgs,
        scale_exp_avgs,
        expand_exp_avgs,
        sqrt_minmax_exp_avgs,
        exp_avg_sqs,
        scale_exp_avg_sqs,
        expand_exp_avg_sqs,
        sqrt_minmax_exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        use_expansion=use_expansion,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        qgroup_size=qgroup_size,
        expand_min=expand_min,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )


def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript type inference
    if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
        return x.sqrt()
    else:
        return sqrt(x)


def _single_tensor_Coatadamw(
    qargs,
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    scale_exp_avgs: List[Tensor],
    expand_exp_avgs: List[Tensor],
    sqrt_minmax_exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    scale_exp_avg_sqs: List[Tensor],
    expand_exp_avg_sqs: List[Tensor],
    sqrt_minmax_exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    use_expansion: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    qgroup_size: int,
    expand_min: int,
):
    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i]
        # First-order state (FP8 moment and its per-group scale)
        exp_avg = exp_avgs[i]
        scale_exp_avg = scale_exp_avgs[i]
        # Second-order state (FP8 moment and its per-group scale)
        exp_avg_sq = exp_avg_sqs[i]
        scale_exp_avg_sq = scale_exp_avg_sqs[i]
        step_t = state_steps[i]

        # update step
        step_t += 1
        step = int(step_t.item())

        # Perform the optimizer step with the fused CUDA kernels
        if use_expansion:
            expand_exp_avg = expand_exp_avgs[i]
            sqrt_minmax_exp_avg = sqrt_minmax_exp_avgs[i]
            expand_exp_avg_sq = expand_exp_avg_sqs[i]
            sqrt_minmax_exp_avg_sq = sqrt_minmax_exp_avg_sqs[i]
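            # Fused CUDA kernel from `qoptim_cuda` (the expansion variant of
            # this extension): expected to update `param`, the FP8 moments and
            # their per-group scale/expansion factors in place for this step.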
            qoptim_cuda.fp8_adamw_expand_step(
                param,
                grad,
                exp_avg,
                scale_exp_avg,
                expand_exp_avg,
                sqrt_minmax_exp_avg,
                exp_avg_sq,
                scale_exp_avg_sq,
                expand_exp_avg_sq,
                sqrt_minmax_exp_avg_sq,
                beta1,
                beta2,
                lr,
                weight_decay,
                eps,
                step,
                qgroup_size,
                expand_min,
            )
        else:
            qoptim_cuda.fp8_adamw_step(
                param,
                grad,
                exp_avg,
                scale_exp_avg,
                exp_avg_sq,
                scale_exp_avg_sq,
                beta1,
                beta2,
                lr,
                weight_decay,
                eps,
                step,
                qgroup_size,
            )