# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

from ._ops import ops

from typing import List, Tuple, Union
from torch import Tensor
from torch.optim.optimizer import Optimizer, ParamsT


class AdamATan2(Optimizer):
    """AdamW-style optimizer whose update replaces the eps-guarded division
    with an atan2 formulation, dispatched to a fused CUDA kernel
    (``ops.adam_atan2_cuda_impl_``). Note there is no ``eps`` hyperparameter."""

    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def _init_group(
        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
    ):
        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamATan2 does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
                # This is because kernel launches are costly on CUDA and XLA.
                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            state_steps.append(state["step"])

    def step(self):
        """Perform a single optimization step."""
        self._cuda_graph_capture_health_check()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group["betas"]

            self._init_group(
                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
            )

            _adam_atan2(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
            )


def _adam_atan2(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
) -> None:
    if not params:
        return

    # We only support scalar lr.
    assert not isinstance(lr, Tensor)

    # Group the flat tensor lists per (device, dtype) so each group can be
    # updated with a single fused kernel launch.
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
    )
    for (device, _), (
        (
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_state_steps,
        ),
        _,
    ) in grouped_tensors.items():
        torch._foreach_add_(device_state_steps, 1)
        ops.adam_atan2_cuda_impl_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_state_steps,
            lr,
            beta1,
            beta2,
            weight_decay,
        )
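

# --- Usage sketch ---------------------------------------------------------
# A minimal, hedged example of driving AdamATan2 on a toy regression problem.
# Assumptions not established by the module itself: a CUDA device is available,
# the compiled extension behind `ops` is importable, and the model, shapes, and
# hyperparameters below are purely illustrative.
if __name__ == "__main__":
    device = "cuda"
    model = torch.nn.Linear(16, 4, device=device)
    optimizer = AdamATan2(
        model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2
    )

    inputs = torch.randn(8, 16, device=device)
    targets = torch.randn(8, 4, device=device)

    for step_idx in range(3):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()  # runs the fused adam-atan2 update on the CUDA params
        print(f"step {step_idx}: loss={loss.item():.4f}")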