# Copyright (c) 2025 NVIDIA CORPORATION. Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, defaultdict
from copy import deepcopy
from itertools import chain
from math import sqrt
from typing import Any, DefaultDict, Dict, Hashable, Iterable, List, Optional, Tuple, Union

import qoptim_cuda
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing_extensions import ParamSpec, Self, TypeAlias

StateDict: TypeAlias = Dict[str, Any]

convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2}


class CoatAdamW(Optimizer):
    def __init__(
        self,
        qargs,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
        *,
        fused: Optional[bool] = None,
    ):
        self.qargs = qargs
        assert self.qargs.first_order_expansion == self.qargs.second_order_expansion

        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            fused=fused,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            fused = group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = torch.tensor(step_val, dtype=torch.float32)

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        amsgrad,
        use_expansion,
        exp_avgs,
        scale_exp_avgs,
        expand_exp_avgs,
        sqrt_minmax_exp_avgs,
        exp_avg_sqs,
        scale_exp_avg_sqs,
        expand_exp_avg_sqs,
        sqrt_minmax_exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamW does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                # `step` is hosted on CPU because kernel launches are costly on CUDA and XLA.
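                # Quantized-state layout created below: the first and second moments
                # (`exp_avg`, `exp_avg_sq`) are stored in FP8, with one floating-point
                # scale per `qgroup_size` elements (`scale_*`); when expansion is
                # enabled, additional per-group factors (`expand_*`, `sqrt_minmax_*`)
                # are allocated for dynamic-range expansion.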
state["step"] = torch.tensor(0.0) # Should be torch.float8_e4m3fn first_order_dtype = convert_str_to_fp8[self.qargs.first_order_bit] second_order_dtype = convert_str_to_fp8[self.qargs.second_order_bit] scale_shape = (p.numel() + self.qargs.qgroup_size - 1) // self.qargs.qgroup_size # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(p, dtype=first_order_dtype, memory_format=torch.preserve_format) state["scale_exp_avg"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype) if use_expansion: state["expand_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) state["sqrt_minmax_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p, dtype=second_order_dtype, memory_format=torch.preserve_format) state["scale_exp_avg_sq"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype) if use_expansion: state["expand_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) state["sqrt_minmax_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros(p, memory_format=torch.preserve_format) exp_avgs.append(state["exp_avg"]) scale_exp_avgs.append(state["scale_exp_avg"]) if use_expansion: expand_exp_avgs.append(state["expand_exp_avg"]) sqrt_minmax_exp_avgs.append(state["sqrt_minmax_exp_avg"]) exp_avg_sqs.append(state["exp_avg_sq"]) scale_exp_avg_sqs.append(state["scale_exp_avg_sq"]) if use_expansion: expand_exp_avg_sqs.append(state["expand_exp_avg_sq"]) sqrt_minmax_exp_avg_sqs.append(state["sqrt_minmax_exp_avg_sq"]) if group["amsgrad"]: max_exp_avg_sqs.append(state["max_exp_avg_sq"]) state_steps.append(state["step"]) @torch._disable_dynamo def load_state_dict(self, state_dict: StateDict) -> None: r"""Loads the optimizer state. Args: state_dict (dict): optimizer state. Should be an object returned from a call to :meth:`state_dict`. 
""" # shallow copy, to be consistent with module API state_dict = state_dict.copy() for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): hook_result = pre_hook(self, state_dict) if hook_result is not None: state_dict = hook_result # Validate the state_dict groups = self.param_groups # Deepcopy as we write into saved_groups later to update state saved_groups = deepcopy(state_dict["param_groups"]) if len(groups) != len(saved_groups): raise ValueError("loaded state dict has a different number of " "parameter groups") param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): raise ValueError( "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" ) # Update the state id_map = dict( zip( chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups) ) ) def _cast(param, value, param_id=None, param_groups=None, key=None): r"""Make a deep copy of value, casting all tensors to device of param.""" if isinstance(value, torch.Tensor): return CoatAdamW._process_value_according_to_param_policy(param, value, param_id, param_groups, key) elif isinstance(value, dict): return { k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items() } elif isinstance(value, Iterable): return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] else: return value # Copy state assigned to params (and cast tensors to appropriate types). # State that is not assigned to params is copied as is (needed for # backward compatibility). state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) for k, v in state_dict["state"].items(): if k in id_map: param = id_map[k] state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"]) else: state[k] = v # Update parameter groups, setting their 'params' value def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]: new_group["params"] = group["params"] return new_group param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] self.__setstate__({"state": state, "param_groups": param_groups}) for post_hook in self._optimizer_load_state_dict_post_hooks.values(): post_hook(self) @staticmethod def _process_value_according_to_param_policy( param: torch.Tensor, value: torch.Tensor, param_id: int, param_groups: List[Dict[Any, Any]], key: Hashable = None, ) -> torch.Tensor: # Floating-point types are a bit special here. They are the only ones # that are assumed to always match the type of params. 
        # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
        # UNLESS fused or capturable, see note [special device hosting for step]
        fused = False
        capturable = False
        assert param_groups is not None
        for pg in param_groups:
            if param_id in pg["params"]:
                fused = pg["fused"] if "fused" in pg else False
                capturable = pg["capturable"] if "capturable" in pg else False
                break

        if key == "step":
            if capturable or fused:
                return value.to(dtype=torch.float32, device=param.device)
            else:
                return value
        else:
            assert value.dtype in [torch.float8_e4m3fn, torch.float8_e5m2, torch.float32]
            return value.to(device=param.device)  # do not cast optimizer states
            # if param.is_floating_point():
            #     return value.to(dtype=param.dtype, device=param.device)
            # else:
            #     return value.to(device=param.device)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            scale_exp_avgs = []
            expand_exp_avgs = []
            sqrt_minmax_exp_avgs = []
            exp_avg_sqs = []
            scale_exp_avg_sqs = []
            expand_exp_avg_sqs = []
            sqrt_minmax_exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            use_expansion = self.qargs.first_order_expansion in ["expansion", "true"]
            beta1, beta2 = group["betas"]

            self._init_group(
                group,
                params_with_grad,
                grads,
                amsgrad,
                use_expansion,
                exp_avgs,
                scale_exp_avgs,
                expand_exp_avgs,
                sqrt_minmax_exp_avgs,
                exp_avg_sqs,
                scale_exp_avg_sqs,
                expand_exp_avg_sqs,
                sqrt_minmax_exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            Coatadamw(
                self.qargs,
                params_with_grad,
                grads,
                exp_avgs,
                scale_exp_avgs,
                expand_exp_avgs,
                sqrt_minmax_exp_avgs,
                exp_avg_sqs,
                scale_exp_avg_sqs,
                expand_exp_avg_sqs,
                sqrt_minmax_exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                use_expansion=use_expansion,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                qgroup_size=self.qargs.qgroup_size,
                expand_min=self.qargs.expand_min,
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss


def Coatadamw(
    qargs,
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    scale_exp_avgs: List[Tensor],
    expand_exp_avgs: List[Tensor],
    sqrt_minmax_exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    scale_exp_avg_sqs: List[Tensor],
    expand_exp_avg_sqs: List[Tensor],
    sqrt_minmax_exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    amsgrad: bool,
    use_expansion: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    qgroup_size: int,
    expand_min: int,
):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.
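
    Compared with the stock functional AdamW, the moment buffers in ``exp_avgs``
    and ``exp_avg_sqs`` are FP8 tensors; ``scale_*`` holds their per-group
    quantization scales and, when ``use_expansion`` is set, ``expand_*`` and
    ``sqrt_minmax_*`` hold the per-group dynamic-range expansion factors
    consumed by the ``qoptim_cuda`` kernels.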
""" if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") func = _single_tensor_Coatadamw func( qargs, params, grads, exp_avgs, scale_exp_avgs, expand_exp_avgs, sqrt_minmax_exp_avgs, exp_avg_sqs, scale_exp_avg_sqs, expand_exp_avg_sqs, sqrt_minmax_exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad=amsgrad, use_expansion=use_expansion, beta1=beta1, beta2=beta2, lr=lr, weight_decay=weight_decay, eps=eps, qgroup_size=qgroup_size, expand_min=expand_min, grad_scale=grad_scale, found_inf=found_inf, ) def _dispatch_sqrt(x: float): # float annotation is needed because of torchscript type inference if not torch.jit.is_scripting() and isinstance(x, torch.Tensor): return x.sqrt() else: return sqrt(x) def _single_tensor_Coatadamw( qargs, params: List[Tensor], grads: List[Tensor], exp_avgs: List[Tensor], scale_exp_avgs: List[Tensor], expand_exp_avgs: List[Tensor], sqrt_minmax_exp_avgs: List[Tensor], exp_avg_sqs: List[Tensor], scale_exp_avg_sqs: List[Tensor], expand_exp_avg_sqs: List[Tensor], sqrt_minmax_exp_avg_sqs: List[Tensor], max_exp_avg_sqs: List[Tensor], state_steps: List[Tensor], grad_scale: Optional[Tensor], found_inf: Optional[Tensor], *, amsgrad: bool, use_expansion: bool, beta1: float, beta2: float, lr: Union[Tensor, float], weight_decay: float, eps: float, qgroup_size: int, expand_min: int, ): assert grad_scale is None and found_inf is None if torch.jit.is_scripting(): # this assert is due to JIT being dumb and not realizing that the ops below # have overloads to handle both float and Tensor lrs, so we just assert it's # a float since most people using JIT are using floats assert isinstance(lr, float) for i, param in enumerate(params): grad = grads[i] # First order exp_avg = exp_avgs[i] scale_exp_avg = scale_exp_avgs[i] # Second order exp_avg_sq = exp_avg_sqs[i] scale_exp_avg_sq = scale_exp_avg_sqs[i] step_t = state_steps[i] # print(len(exp_avg.unique()), len(exp_avg_sq.unique())) # print(f"{param.shape}, {grad.shape}, {exp_avg.shape}, {exp_avg_sq.shape}", file=open('debug.txt', 'a')) # update step step_t += 1 step = int(step_t.item()) # Perform Optimizer Step if use_expansion: expand_exp_avg = expand_exp_avgs[i] sqrt_minmax_exp_avg = sqrt_minmax_exp_avgs[i] expand_exp_avg_sq = expand_exp_avg_sqs[i] sqrt_minmax_exp_avg_sq = sqrt_minmax_exp_avg_sqs[i] qoptim_cuda.fp8_adamw_expand_step( param, grad, exp_avg, scale_exp_avg, expand_exp_avg, sqrt_minmax_exp_avg, exp_avg_sq, scale_exp_avg_sq, expand_exp_avg_sq, sqrt_minmax_exp_avg_sq, beta1, beta2, lr, weight_decay, eps, step, qgroup_size, expand_min, ) else: qoptim_cuda.fp8_adamw_step( param, grad, exp_avg, scale_exp_avg, exp_avg_sq, scale_exp_avg_sq, beta1, beta2, lr, weight_decay, eps, step, qgroup_size, )