diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c77e427cdd07868332dbb6004998da6078f66eac 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e24a75af8cfd719a94a499a644188b3164b2d1cb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,42 @@
+---
+tags:
+- kernel
+---
+
+# Optimizer
+
+Optimizer is a Python package that provides:
+- PyTorch implementations of recent optimizer algorithms
+- support for parallelism techniques for efficient large-scale training
+
+## Currently implemented
+- [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf)
+
+## Usage
+
+`Muon` takes the model and a predicate that selects which parameters it should
+handle; everything else is optimized by its internal AdamW. The selection rule
+below is only an example.
+
+```python
+import torch
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from kernels import get_kernel
+
+optimizer = get_kernel("motif-technologies/optimizer")
+
+model = None  # your model here
+fsdp_model = FSDP(model)
+
+def is_muon_func(param, name):
+    # example rule: use Muon for 2D weights that are not embeddings or the head
+    return param.ndim == 2 and "embed" not in name and "lm_head" not in name
+
+optim = optimizer.Muon(
+    fsdp_model,
+    is_muon_func,
+    lr=1e-3,
+    momentum=0.95,
+    adamw_wd=0.1,
+)
+```
diff --git a/build.toml b/build.toml
new file mode 100644
index 0000000000000000000000000000000000000000..b80854db0a67cdde4e5c3dcb8d95f18704812383
--- /dev/null
+++ b/build.toml
@@ -0,0 +1,23 @@
+[general]
+name = "optimizer"
+universal = false
+
+[torch]
+src = [
+    "torch-ext/torch_binding.cpp",
+    "torch-ext/torch_binding.h",
+]
+
+[kernel.activation]
+backend = "rocm"
+src = [
+    "optimizer/dummy.cu",
+]
+depends = [ "torch" ]
+
+[kernel.activation_cuda]
+backend = "cuda"
+src = [
+    "optimizer/dummy.cu",
+]
+depends = [ "torch" ]
diff --git a/build/torch26-cxx11-cu118-x86_64-linux/optimizer/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0
--- /dev/null
+++ b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/__init__.py
@@ -0,0 +1,5 @@
+from .muon import Muon
+
+__all__ = [
+    "Muon",
+]
diff --git a/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py
new file mode 100755
index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400
--- /dev/null
+++ b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _optimizer_b4b3752_dirty
+ops = torch.ops._optimizer_b4b3752_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..6f1afc3bc9c26549c621ee8396cfd9b6d632228e --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66ca698639fff584999fe65f8f10cc4436c197829e936be2741bf53db685caa0 +size 1787272 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..3ad8df1e7879102f18c8f3ecefdcd4a710867734 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8325d12959ef4f31b6c6340eca29176f5077abeaa10f3a6651db55ccf3c634f +size 1787272 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 
+
+    Arguments:
+        model: The model whose parameters will be optimized.
+        is_muon_func: A callable ``(param, name) -> bool`` that selects the
+            parameters to be optimized by Muon; parameters it rejects (e.g.
+            {0, 1}-D tensors, embeddings, or the lm_head) are optimized by the
+            internal AdamW instead.
+        lr: The learning rate. The updates will have spectral norm of `lr`.
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_wd: The weight decay for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        debug: If True, print the total Muon FLOP count when assigning work.
+    """
+
+    def __init__(
+        self,
+        model,
+        is_muon_func,
+        lr=1e-3,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_wd=0.1,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+        debug=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=adamw_wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        super().__init__(model.parameters(), defaults)
+        self.is_muon_func = is_muon_func
+        self.model = model
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "Muon optimizer requires distributed training to be initialized."
+            )
+
+        self.rank = dist.get_rank()
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+
+    def __setstate__(self, state):
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        super().__setstate__(state)
+        for name, p in self.model.named_parameters():
+            if self.is_muon_func(p, name):
+                # Use Muon for every parameter the predicate selects; these must be 2D
+                assert p.ndim == 2, p.ndim
+                self.state[p]["use_muon"] = True
+                self.state[p]["orig_shape"] = p.shape
+            else:
+                # Fall back to AdamW for every other parameter
+                self.state[p]["use_muon"] = False
+
+    def _calc_flops(self, G, steps):
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate based on the size of the parameter matrix,
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def init_state_and_assign_params(self, params, group):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+
+        ordered_params = sorted(
+            params, key=lambda p: param_to_flops[id(p)], reverse=True
+        )
+
+        round_robin = 0
+        mesh = None
+        for p in ordered_params:
+            if mesh is None:
+                mesh = p.device_mesh
+                if mesh.ndim != 1:
+                    raise NotImplementedError(
+                        "Muon currently only supports a 1D device mesh for distributed training."
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
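+
+        Example (sketch; ``compute_loss`` is a stand-in for your own loss
+        computation):
+
+            def closure():
+                optim.zero_grad()
+                loss = compute_loss(model)
+                loss.backward()
+                return loss
+
+            loss = optim.step(closure)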
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch26-cxx11-cu124-x86_64-linux/optimizer/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_b4b3752_dirty +ops = torch.ops._optimizer_b4b3752_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..3be3c1fe4294649f4aad6e9c2baed7dd62d26788 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e89cd7d514bfe92598684ae3cfc2d35ac2d021340846e09c0b6c880c3d55bfa0 +size 1820136 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..b011a8836fa8a0b3ad74ec14e29a6284a0742be2 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cbffc2cf8039069831a57afb8e2f64fa684f1a44bec79bb4b72dbb57d9ac607 +size 1824224 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 
+
+    Arguments:
+        model: The model whose parameters will be optimized.
+        is_muon_func: A callable ``(param, name) -> bool`` that selects the
+            parameters to be optimized by Muon; parameters it rejects (e.g.
+            {0, 1}-D tensors, embeddings, or the lm_head) are optimized by the
+            internal AdamW instead.
+        lr: The learning rate. The updates will have spectral norm of `lr`.
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_wd: The weight decay for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        debug: If True, print the total Muon FLOP count when assigning work.
+    """
+
+    def __init__(
+        self,
+        model,
+        is_muon_func,
+        lr=1e-3,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_wd=0.1,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+        debug=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=adamw_wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        super().__init__(model.parameters(), defaults)
+        self.is_muon_func = is_muon_func
+        self.model = model
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "Muon optimizer requires distributed training to be initialized."
+            )
+
+        self.rank = dist.get_rank()
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+
+    def __setstate__(self, state):
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        super().__setstate__(state)
+        for name, p in self.model.named_parameters():
+            if self.is_muon_func(p, name):
+                # Use Muon for every parameter the predicate selects; these must be 2D
+                assert p.ndim == 2, p.ndim
+                self.state[p]["use_muon"] = True
+                self.state[p]["orig_shape"] = p.shape
+            else:
+                # Fall back to AdamW for every other parameter
+                self.state[p]["use_muon"] = False
+
+    def _calc_flops(self, G, steps):
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate based on the size of the parameter matrix,
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def init_state_and_assign_params(self, params, group):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+
+        ordered_params = sorted(
+            params, key=lambda p: param_to_flops[id(p)], reverse=True
+        )
+
+        round_robin = 0
+        mesh = None
+        for p in ordered_params:
+            if mesh is None:
+                mesh = p.device_mesh
+                if mesh.ndim != 1:
+                    raise NotImplementedError(
+                        "Muon currently only supports a 1D device mesh for distributed training."
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
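+
+        Example (sketch; ``compute_loss`` is a stand-in for your own loss
+        computation):
+
+            def closure():
+                optim.zero_grad()
+                loss = compute_loss(model)
+                loss.backward()
+                return loss
+
+            loss = optim.step(closure)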
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch26-cxx11-cu126-x86_64-linux/optimizer/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_b4b3752_dirty +ops = torch.ops._optimizer_b4b3752_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..ec7e75e7b673420b7ff82464fd0f66d086797be8 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f5dce62d3038e879e688fffa9bbc70f3e82db20b2e7ae3ba09040e0319acb71 +size 1820136 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..27870794663e4e380d26a8c438668dc9b1501547 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58162f994df84868dbf62ae70e39d3c14e3390fc827f152eece83dfae7f51503 +size 1824224 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 
+
+    Arguments:
+        model: The model whose parameters will be optimized.
+        is_muon_func: A callable ``(param, name) -> bool`` that selects the
+            parameters to be optimized by Muon; parameters it rejects (e.g.
+            {0, 1}-D tensors, embeddings, or the lm_head) are optimized by the
+            internal AdamW instead.
+        lr: The learning rate. The updates will have spectral norm of `lr`.
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_wd: The weight decay for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        debug: If True, print the total Muon FLOP count when assigning work.
+    """
+
+    def __init__(
+        self,
+        model,
+        is_muon_func,
+        lr=1e-3,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_wd=0.1,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+        debug=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=adamw_wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        super().__init__(model.parameters(), defaults)
+        self.is_muon_func = is_muon_func
+        self.model = model
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "Muon optimizer requires distributed training to be initialized."
+            )
+
+        self.rank = dist.get_rank()
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+
+    def __setstate__(self, state):
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        super().__setstate__(state)
+        for name, p in self.model.named_parameters():
+            if self.is_muon_func(p, name):
+                # Use Muon for every parameter the predicate selects; these must be 2D
+                assert p.ndim == 2, p.ndim
+                self.state[p]["use_muon"] = True
+                self.state[p]["orig_shape"] = p.shape
+            else:
+                # Fall back to AdamW for every other parameter
+                self.state[p]["use_muon"] = False
+
+    def _calc_flops(self, G, steps):
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate based on the size of the parameter matrix,
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def init_state_and_assign_params(self, params, group):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+
+        ordered_params = sorted(
+            params, key=lambda p: param_to_flops[id(p)], reverse=True
+        )
+
+        round_robin = 0
+        mesh = None
+        for p in ordered_params:
+            if mesh is None:
+                mesh = p.device_mesh
+                if mesh.ndim != 1:
+                    raise NotImplementedError(
+                        "Muon currently only supports a 1D device mesh for distributed training."
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
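+
+        Example (sketch; ``compute_loss`` is a stand-in for your own loss
+        computation):
+
+            def closure():
+                optim.zero_grad()
+                loss = compute_loss(model)
+                loss.backward()
+                return loss
+
+            loss = optim.step(closure)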
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/__init__.py b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400 --- /dev/null +++ b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_b4b3752_dirty +ops = torch.ops._optimizer_b4b3752_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614121529.abi3.so b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614121529.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..52812e8a8729a06d9a416a541cdb9dacdbd18bde --- /dev/null +++ b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614121529.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d2f60369ba2bd0a0f84e053d857d37496137ff476dc21561f211b1fa39651990 +size 1749784 diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614123843.abi3.so b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614123843.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..f920a59fe9b333a3a502408b21ab45b5946283ba --- /dev/null +++ b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614123843.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4d790535f99b7b362a966e802a547654f31749f5f28a0207493870927f1d8d2 +size 1749784 diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..70bf31624ca0641277491c69bb148281e987b9ce --- /dev/null +++ b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b440dd9a60711a498010068e91d0ad013cd0b8ac732c16b5d1d17e5d4ec0f9b4 +size 1749784 diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..68ebbf08ff0cd761f4f6817ede6d806330daa380 --- /dev/null +++ b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f50ea9cab62a5bd06d886516d3917e4490e65aa9addd1cbb84fc81c6f9a9d5b1 +size 1749744 diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. 
+
+    Some warnings:
+    - We believe this optimizer is unlikely to work well for training with small batch size.
+    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
+
+    Arguments:
+        model: The model whose parameters will be optimized.
+        is_muon_func: A callable ``(param, name) -> bool`` that selects the
+            parameters to be optimized by Muon; parameters it rejects (e.g.
+            {0, 1}-D tensors, embeddings, or the lm_head) are optimized by the
+            internal AdamW instead.
+        lr: The learning rate. The updates will have spectral norm of `lr`.
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_wd: The weight decay for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        debug: If True, print the total Muon FLOP count when assigning work.
+    """
+
+    def __init__(
+        self,
+        model,
+        is_muon_func,
+        lr=1e-3,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_wd=0.1,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+        debug=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=adamw_wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        super().__init__(model.parameters(), defaults)
+        self.is_muon_func = is_muon_func
+        self.model = model
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "Muon optimizer requires distributed training to be initialized."
+            )
+
+        self.rank = dist.get_rank()
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+
+    def __setstate__(self, state):
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        super().__setstate__(state)
+        for name, p in self.model.named_parameters():
+            if self.is_muon_func(p, name):
+                # Use Muon for every parameter the predicate selects; these must be 2D
+                assert p.ndim == 2, p.ndim
+                self.state[p]["use_muon"] = True
+                self.state[p]["orig_shape"] = p.shape
+            else:
+                # Fall back to AdamW for every other parameter
+                self.state[p]["use_muon"] = False
+
+    def _calc_flops(self, G, steps):
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate based on the size of the parameter matrix,
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def init_state_and_assign_params(self, params, group):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+
+        ordered_params = sorted(
+            params, key=lambda p: param_to_flops[id(p)], reverse=True
+        )
+
+        round_robin = 0
+        mesh = None
+        for p in ordered_params:
+            if mesh is None:
+                mesh = p.device_mesh
+                if mesh.ndim != 1:
+                    raise NotImplementedError(
+                        "Muon currently only supports a 1D device mesh for distributed training."
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
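+
+        Example (a sketch; `compute_loss` and `batch` stand in for your own
+        training code):
+
+            def closure():
+                optim.zero_grad()
+                loss = compute_loss(model, batch)
+                loss.backward()
+                return loss
+
+            loss = optim.step(closure)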
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch26-cxx98-cu118-x86_64-linux/optimizer/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_b4b3752_dirty +ops = torch.ops._optimizer_b4b3752_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
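+
+    Example (with a hypothetical op name):
+        >>> add_op_namespace_prefix("some_op")
+        '_optimizer_b4b3752_dirty::some_op'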
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..97c729e3fbaacccbb4de3648ecec3bdd3e3df48c --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8f8e7d78ed9a095b882cf764fd9c80a0b0810fb961ba9e8545656fc4cb0b0d7 +size 1787200 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..171b656d3cca9f3e371b6235f8c8b53738289a14 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:002dab6441bcad54ab4e7c064b5806acfd45170eb33cfa059745ba6e0c349607 +size 1787192 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
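+
+    Concretely, each step applies the quintic map
+    X <- a*X + b*(X @ X.T) @ X + c*(X @ X.T) @ (X @ X.T) @ X
+    with (a, b, c) = (3.4445, -4.7750, 2.0315).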
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 
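+
+    For each 2D parameter the base learning rate is rescaled by
+    0.2 * sqrt(max(dim0, dim1)) (see `adjust_lr_for_muon`), so the effective
+    step size grows with the larger matrix dimension.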
+ + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + debug=False, + ): + defaults = dict( + lr=lr, + wd=adamw_wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + ) + + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model + + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." + ) + + self.rank = dist.get_rank() + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def init_state_and_assign_params(self, params, group): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." + + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) + + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) + + round_robin = 0 + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
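+
+        Parameters whose data is a `DTensor` take the distributed `parallel`
+        path; plain tensors fall back to the single-device `base` path.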
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch26-cxx98-cu124-x86_64-linux/optimizer/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_b4b3752_dirty +ops = torch.ops._optimizer_b4b3752_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e28a927785764651e3ce2e76f737478ce74b93ed --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab2379d932e40d10bee55f032bd16d2e4d9c1920bc5500628006f8a0eb8abd39 +size 1824192 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..baae25dbc1cd71cee383d366d50ef444612b9029 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f499350bb19eca6c3da1bb72e46023834b8411ce00730854273b588b2cd9206 +size 1824184 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
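+
+    If G has more rows than columns it is transposed first, so the Gram matrix
+    X @ X.T inside the loop is always the smaller of the two possible sizes;
+    the result is transposed back before returning.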
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 
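+
+    Which parameters Muon handles is decided by the `is_muon_func` predicate
+    passed to the constructor; everything it rejects is optimized by the
+    internal AdamW instead.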
+ + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + debug=False, + ): + defaults = dict( + lr=lr, + wd=adamw_wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + ) + + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model + + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." + ) + + self.rank = dist.get_rank() + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def init_state_and_assign_params(self, params, group): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." + + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) + + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) + + round_robin = 0 + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
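+
+        For `DTensor` parameters, gather/compute/scatter are pipelined in
+        chunks of (mesh size) parameters on dedicated communication and compute
+        CUDA streams, overlapping communication with the Newton-Schulz
+        iterations.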
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch26-cxx98-cu126-x86_64-linux/optimizer/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_b4b3752_dirty +ops = torch.ops._optimizer_b4b3752_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..12a767e52b5d38684926563a9c8969fc50229dd8 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c3282a321487a6faa532afe43bc1298731983c50e2a1acdff5480ff6e4df34e +size 1824192 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..8ad904ccfaa8794b609dd947fe9dabcca1040ca9 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5b49ed642e1c320da3932377033ad90031124f4ec24b2d1c95fd976ff28346c +size 1824184 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
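+
+    The input is first divided by its Frobenius norm (plus a small epsilon) so
+    its spectral norm is at most 1, and the iterations themselves run in
+    bfloat16.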
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 
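+
+    In the distributed path, parameters are ordered by their estimated
+    Newton-Schulz FLOPs and assigned to ranks round-robin, keeping the per-rank
+    orthogonalization work roughly balanced.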
+ + Arguments: + muon_params: The parameters to be optimized by Muon. + lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) + momentum: The momentum used by the internal SGD. (0.95 is a good default) + nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) + ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) + adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are + {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. + adamw_lr: The learning rate for the internal AdamW. + adamw_betas: The betas for the internal AdamW. + adamw_eps: The epsilon for the internal AdamW. + adamw_wd: The weight decay for the internal AdamW. + """ + + def __init__( + self, + model, + is_muon_func, + lr=1e-3, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_wd=0.1, + adamw_betas=(0.9, 0.95), + adamw_eps=1e-8, + debug=False, + ): + defaults = dict( + lr=lr, + wd=adamw_wd, + momentum=momentum, + nesterov=nesterov, + ns_steps=ns_steps, + adamw_betas=adamw_betas, + adamw_eps=adamw_eps, + ) + + super().__init__(model.parameters(), defaults) + self.is_muon_func = is_muon_func + self.model = model + + if not dist.is_initialized(): + raise RuntimeError( + "Muon optimizer requires distributed training to be initialized." + ) + + self.rank = dist.get_rank() + + self.comm_stream = torch.cuda.Stream() + self.compute_stream = torch.cuda.Stream() + self.debug = debug + + def __setstate__(self, state): + # Sort parameters into those for which we will use Muon, and those for which we will not + super().__setstate__(state) + for name, p in self.model.named_parameters(): + if self.is_muon_func(p, name): + # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer + assert p.ndim == 2, p.ndim + self.state[p]["use_muon"] = True + self.state[p]["orig_shape"] = p.shape + else: + # Do not use Muon for parameters in adamw_params + self.state[p]["use_muon"] = False + + def _calc_flops(self, G, steps): + assert len(G.shape) == 2 + M, N = G.shape + if M > N: + M, N = N, M + + return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) + + def adjust_lr_for_muon(self, lr, param_shape): + A, B = param_shape[:2] + # We adjust the learning rate and weight decay based on the size of the parameter matrix + # as describted in the paper + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + + def init_state_and_assign_params(self, params, group): + param_to_state = {} + param_to_flops = {} + + total_flops = 0 + for p in params: + g = p.grad + if g is None: + continue + assert g.ndim == 2, "Muon only supports 2D parameters." + + flops = self._calc_flops(g, group["ns_steps"]) + param_to_flops[id(p)] = flops + total_flops += flops + + if self.debug: + print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True) + + ordered_params = sorted( + params, key=lambda p: param_to_flops[id(p)], reverse=True + ) + + round_robin = 0 + mesh = None + for p in ordered_params: + if mesh is None: + mesh = p.device_mesh + if mesh.ndim != 1: + raise NotImplementedError( + "Muon requires a 1D mesh for distributed training yet." 
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
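+
+        Weight decay is applied multiplicatively (`p *= 1 - lr * wd`) before
+        the update is added, on both the Muon and AdamW paths.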
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_b4b3752_dirty +ops = torch.ops._optimizer_b4b3752_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..cf52a2d146c6a6a3f90e6171553aa70fa5a04359 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de82486a39ded94bfe7eeaa862459944a93e284fd0d919329979bb67db3c367f +size 1787376 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..c2749235508977b018762b027b746c1fec58c251 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ac9027c4a93801e9f19f1e9e94a9ed33b27e92c72797053c3de55e2a6fbb41d +size 1787368 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
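+
+    The number of iterations trades orthogonalization accuracy against compute;
+    in practice 5-6 steps suffice, and the result is cast back to G's original
+    dtype before being returned.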
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 
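+
+    The internal SGD accumulates gradients in a momentum buffer; with
+    `nesterov=True` the orthogonalized update is computed from
+    `g + momentum * buf` rather than from the buffer alone.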
+
+    Arguments:
+        model: The model to optimize. Parameters are taken from `model.parameters()`.
+        is_muon_func: A callable `(param, name) -> bool` that selects the parameters to be
+            optimized by Muon. Parameters it rejects (e.g. {0, 1}-D tensors, embeddings, or
+            the lm_head) are optimized by the internal AdamW instead.
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_wd: The weight decay for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        debug: If True, print the total FLOPs assigned to Muon at each step.
+    """
+
+    def __init__(
+        self,
+        model,
+        is_muon_func,
+        lr=1e-3,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_wd=0.1,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+        debug=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=adamw_wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        super().__init__(model.parameters(), defaults)
+        self.is_muon_func = is_muon_func
+        self.model = model
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "Muon optimizer requires distributed training to be initialized."
+            )
+
+        self.rank = dist.get_rank()
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+
+    def __setstate__(self, state):
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        super().__setstate__(state)
+        for name, p in self.model.named_parameters():
+            if self.is_muon_func(p, name):
+                # Use Muon for every parameter selected by is_muon_func; these must be 2D
+                # and should not be embedding or head layers
+                assert p.ndim == 2, p.ndim
+                self.state[p]["use_muon"] = True
+                self.state[p]["orig_shape"] = p.shape
+            else:
+                # Fall back to the internal AdamW for all remaining parameters
+                self.state[p]["use_muon"] = False
+
+    def _calc_flops(self, G, steps):
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate and weight decay based on the size of the parameter matrix
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def init_state_and_assign_params(self, params, group):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+
+        ordered_params = sorted(
+            params, key=lambda p: param_to_flops[id(p)], reverse=True
+        )
+
+        round_robin = 0
+        mesh = None
+        for p in ordered_params:
+            if mesh is None:
+                mesh = p.device_mesh
+                if mesh.ndim != 1:
+                    raise NotImplementedError(
+                        "Muon currently supports only 1D device meshes."
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_b4b3752_dirty +ops = torch.ops._optimizer_b4b3752_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..fe455a191950613d486ea79c18f82b0fbbad6f3a --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb02d3818a89c819a5a12d066ce56da0ebc4f3da491cb045ae380c5b9319e592 +size 1824256 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..da27633d09f7236707f64a4ddc048aea95bc0ee8 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b425a7fd854402508da5af17fa88f305753a09474686d6ec7afe540b3c5c082e +size 1824256 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 
+
+    Arguments:
+        model: The model to optimize. Parameters are taken from `model.parameters()`.
+        is_muon_func: A callable `(param, name) -> bool` that selects the parameters to be
+            optimized by Muon. Parameters it rejects (e.g. {0, 1}-D tensors, embeddings, or
+            the lm_head) are optimized by the internal AdamW instead.
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_wd: The weight decay for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        debug: If True, print the total FLOPs assigned to Muon at each step.
+    """
+
+    def __init__(
+        self,
+        model,
+        is_muon_func,
+        lr=1e-3,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_wd=0.1,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+        debug=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=adamw_wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        super().__init__(model.parameters(), defaults)
+        self.is_muon_func = is_muon_func
+        self.model = model
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "Muon optimizer requires distributed training to be initialized."
+            )
+
+        self.rank = dist.get_rank()
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+
+    def __setstate__(self, state):
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        super().__setstate__(state)
+        for name, p in self.model.named_parameters():
+            if self.is_muon_func(p, name):
+                # Use Muon for every parameter selected by is_muon_func; these must be 2D
+                # and should not be embedding or head layers
+                assert p.ndim == 2, p.ndim
+                self.state[p]["use_muon"] = True
+                self.state[p]["orig_shape"] = p.shape
+            else:
+                # Fall back to the internal AdamW for all remaining parameters
+                self.state[p]["use_muon"] = False
+
+    def _calc_flops(self, G, steps):
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate and weight decay based on the size of the parameter matrix
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def init_state_and_assign_params(self, params, group):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+
+        ordered_params = sorted(
+            params, key=lambda p: param_to_flops[id(p)], reverse=True
+        )
+
+        round_robin = 0
+        mesh = None
+        for p in ordered_params:
+            if mesh is None:
+                mesh = p.device_mesh
+                if mesh.ndim != 1:
+                    raise NotImplementedError(
+                        "Muon currently supports only 1D device meshes."
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_b4b3752_dirty +ops = torch.ops._optimizer_b4b3752_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..79c295dcc99970e4a8d8e08ad782238ac619ebd6 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f4bb02fd9fc62e272efc673aa7cb96f363e6c1d617515c93ae6708db3feaa8e +size 1883352 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..c0dda129458a8c19caef554baa11de9106a6d370 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1150b2d74d708ef2f7d8a24049c49b9ba6e2d8f5d5ce5ae88a611e4d555fe659 +size 1883352 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 
+
+    Arguments:
+        model: The model to optimize. Parameters are taken from `model.parameters()`.
+        is_muon_func: A callable `(param, name) -> bool` that selects the parameters to be
+            optimized by Muon. Parameters it rejects (e.g. {0, 1}-D tensors, embeddings, or
+            the lm_head) are optimized by the internal AdamW instead.
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_wd: The weight decay for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        debug: If True, print the total FLOPs assigned to Muon at each step.
+    """
+
+    def __init__(
+        self,
+        model,
+        is_muon_func,
+        lr=1e-3,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_wd=0.1,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+        debug=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=adamw_wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        super().__init__(model.parameters(), defaults)
+        self.is_muon_func = is_muon_func
+        self.model = model
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "Muon optimizer requires distributed training to be initialized."
+            )
+
+        self.rank = dist.get_rank()
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+
+    def __setstate__(self, state):
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        super().__setstate__(state)
+        for name, p in self.model.named_parameters():
+            if self.is_muon_func(p, name):
+                # Use Muon for every parameter selected by is_muon_func; these must be 2D
+                # and should not be embedding or head layers
+                assert p.ndim == 2, p.ndim
+                self.state[p]["use_muon"] = True
+                self.state[p]["orig_shape"] = p.shape
+            else:
+                # Fall back to the internal AdamW for all remaining parameters
+                self.state[p]["use_muon"] = False
+
+    def _calc_flops(self, G, steps):
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate and weight decay based on the size of the parameter matrix
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def init_state_and_assign_params(self, params, group):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+
+        ordered_params = sorted(
+            params, key=lambda p: param_to_flops[id(p)], reverse=True
+        )
+
+        round_robin = 0
+        mesh = None
+        for p in ordered_params:
+            if mesh is None:
+                mesh = p.device_mesh
+                if mesh.ndim != 1:
+                    raise NotImplementedError(
+                        "Muon currently supports only 1D device meshes."
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c86e7a83ac210b884df736d7099245396cfb404 Binary files /dev/null and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-310.pyc differ diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bebae2bd877757895050b1e2d832c9dab5a0a2d2 Binary files /dev/null and b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-310.pyc differ diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..1bdf816efdc028c0b2cee6954f1c3207ffb21400 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _optimizer_b4b3752_dirty +ops = torch.ops._optimizer_b4b3752_dirty + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_optimizer_b4b3752_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250614121529.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250614121529.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..52f817670f67373218bd36e98856a9192e1c59f9 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250614121529.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0796f06f2de4e26247141c21c1e8dafc3d268073a3eb2c8f2ef810cf588c2746 +size 1749688 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250614123843.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250614123843.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..c3262a8b5b00ca39b8a37f1a1ca51a955d90f1f5 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250614123843.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57144d037f5db2441940be53c25ff198c5b3ec11bc5edac809bb208434e8d53d +size 1749688 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..24c19f92b157c4cb10fd6723e5e13488cee571ab --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e362d290566b6187aedf433241e853cf60f311b69e49b35d9b8f70892fbb57f6 +size 1749688 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..01410e2d256e65f553c8b5522c011367253547a5 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3489677edaccff3afade9c51427a89ac1edf283d092cdd3bc39e06d75c231f1 +size 1749648 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py new file mode 100755 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. 
This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    Each iteration maps the singular values of X through the odd polynomial
+    p(s) = a*s + b*s^3 + c*s^5 (since X' = aX + b(XX^T)X + c(XX^T)^2 X), pushing them
+    toward 1 once the input has been normalized to have spectral norm at most 1.
+    """
+    assert len(G.shape) == 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G  # no manual typecast
+    if G.size(0) > G.size(1):
+        X = X.T
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm() + 1e-7)
+    X = X.bfloat16()
+    # Perform the NS iterations
+    for _ in range(steps):
+        A = X @ X.T
+        # B = (
+        #     b * A + c * A @ A
+        # )  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        B = torch.addmm(A, A, A, alpha=c, beta=b)
+        # X = a * X + B @ X
+        X = torch.addmm(X, B, X, alpha=1.0, beta=a)
+
+    if G.size(0) > G.size(1):
+        X = X.T
+    return X.to(G.dtype)
+
+
+@dataclass
+class _muon_state:
+    # TODO: use Optional
+    worker_rank: int | None = None
+    gathered_grad: torch.Tensor | None = None
+    computed_u: torch.Tensor | None = None
+    scattered_u: torch.Tensor | None = None
+    gather_event: torch.cuda.Event | None = None
+    compute_event: torch.cuda.Event | None = None
+
+
+def _gather(p, state, rank, comm_stream):
+    g = p.grad
+    mesh = g.device_mesh
+
+    if rank == state.worker_rank:
+        gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
+    else:
+        gather_list = None
+
+    with torch.cuda.stream(comm_stream):
+        torch.distributed.gather(
+            g.to_local(),
+            dst=state.worker_rank,
+            gather_list=gather_list,
+            group=mesh.get_group(),
+        )
+        if rank == state.worker_rank:
+            if state.gathered_grad is not None:
+                raise RuntimeError(
+                    "Gathered gradient already exists, which should not happen."
+                )
+            state.gathered_grad = torch.cat(gather_list, dim=0)
+            state.gather_event = torch.cuda.Event()
+            state.gather_event.record()
+        else:
+            state.gathered_grad = None
+            state.gather_event = None
+
+
+def _compute_u(state, steps, rank, compute_stream):
+    with torch.cuda.stream(compute_stream):
+        if rank == state.worker_rank:
+            if state.gather_event is None:
+                raise RuntimeError("Gather event must be set before compute.")
+            compute_stream.wait_event(state.gather_event)
+            u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
+            state.computed_u = u
+            state.compute_event = torch.cuda.Event()
+            state.compute_event.record()
+        else:
+            state.computed_u = None
+            state.compute_event = None
+
+
+def _scatter(p, state, rank, comm_stream):
+    u = state.computed_u
+    mesh = p.device_mesh
+
+    with torch.cuda.stream(comm_stream):
+        if rank == state.worker_rank:
+            if state.compute_event is None:
+                raise RuntimeError("Compute event must be set before scatter.")
+            comm_stream.wait_event(state.compute_event)
+            scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
+        else:
+            scatter_list = None
+
+        u = torch.empty_like(p.to_local())
+        torch.distributed.scatter(
+            u,
+            scatter_list=scatter_list,
+            src=state.worker_rank,
+            group=mesh.get_group(),
+        )
+        u = DTensor.from_local(
+            u,
+            placements=p.placements,
+            device_mesh=mesh,
+        )
+
+        state.scattered_u = u
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    Some warnings:
+    - We believe this optimizer is unlikely to work well for training with small batch size.
+    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
+
+    Arguments:
+        model: The model to optimize. Parameters are taken from `model.parameters()`.
+        is_muon_func: A callable `(param, name) -> bool` that selects the parameters to be
+            optimized by Muon. Parameters it rejects (e.g. {0, 1}-D tensors, embeddings, or
+            the lm_head) are optimized by the internal AdamW instead.
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_wd: The weight decay for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        debug: If True, print the total FLOPs assigned to Muon at each step.
+    """
+
+    def __init__(
+        self,
+        model,
+        is_muon_func,
+        lr=1e-3,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_wd=0.1,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+        debug=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=adamw_wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        super().__init__(model.parameters(), defaults)
+        self.is_muon_func = is_muon_func
+        self.model = model
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "Muon optimizer requires distributed training to be initialized."
+            )
+
+        self.rank = dist.get_rank()
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+
+    def __setstate__(self, state):
+        # Sort parameters into those for which we will use Muon, and those for which we will not
+        super().__setstate__(state)
+        for name, p in self.model.named_parameters():
+            if self.is_muon_func(p, name):
+                # Use Muon for every parameter selected by is_muon_func; these must be 2D
+                # and should not be embedding or head layers
+                assert p.ndim == 2, p.ndim
+                self.state[p]["use_muon"] = True
+                self.state[p]["orig_shape"] = p.shape
+            else:
+                # Fall back to the internal AdamW for all remaining parameters
+                self.state[p]["use_muon"] = False
+
+    def _calc_flops(self, G, steps):
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate and weight decay based on the size of the parameter matrix
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def init_state_and_assign_params(self, params, group):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+
+        ordered_params = sorted(
+            params, key=lambda p: param_to_flops[id(p)], reverse=True
+        )
+
+        round_robin = 0
+        mesh = None
+        for p in ordered_params:
+            if mesh is None:
+                mesh = p.device_mesh
+                if mesh.ndim != 1:
+                    raise NotImplementedError(
+                        "Muon currently supports only 1D device meshes."
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + ############################ + # Muon # + ############################ + + params = [p for p in group["params"] if self.state[p]["use_muon"]] + lr = group["lr"] + wd = group["wd"] + momentum = group["momentum"] + + if isinstance(params[0].data, DTensor): + self.parallel( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + else: + self.base( + params, + group, + lr=lr, + wd=wd, + momentum=momentum, + ) + + ############################ + # AdamW backup # + ############################ + + params = [p for p in group["params"] if not self.state[p]["use_muon"]] + lr = group["lr"] + beta1, beta2 = group["adamw_betas"] + eps = group["adamw_eps"] + weight_decay = group["wd"] + + for p in params: + g = p.grad + if g is None: + continue + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["moment1"] = torch.zeros_like(g) + state["moment2"] = torch.zeros_like(g) + state["step"] += 1 + step = state["step"] + buf1 = state["moment1"] + buf2 = state["moment2"] + buf1.lerp_(g, 1 - beta1) + buf2.lerp_(g.square(), 1 - beta2) + + g = buf1 / (eps + buf2.sqrt()) + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * weight_decay) + p.data.add_(g, alpha=-lr / scale) + + return loss diff --git a/docs/muon/balanced.png b/docs/muon/balanced.png new file mode 100644 index 0000000000000000000000000000000000000000..2076978a5a0149d598b419bfc45c508405dca0df --- /dev/null +++ b/docs/muon/balanced.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9933e2cd5490513593dd6cf1c5c4f18b7f33fd6e6b11c696784269c2bb78055b +size 98003 diff --git a/docs/muon/distributed_muon.png b/docs/muon/distributed_muon.png new file mode 100644 index 0000000000000000000000000000000000000000..26544c9e035afae48d1b32cd6ae729c600a47f33 --- /dev/null +++ b/docs/muon/distributed_muon.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31caea472991fd24a7934bf211b5adcbf154b5295bfe364bba5b603851c2cfae +size 407912 diff --git a/docs/muon/distributed_muon_execution.png b/docs/muon/distributed_muon_execution.png new file mode 100644 index 0000000000000000000000000000000000000000..824c728b78c73ca0d5b70a169ed2e5e50a59946c --- /dev/null +++ b/docs/muon/distributed_muon_execution.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:72ab4d8076f1e182900d71636dd22c32b20bface38890cef72a0c94c496d5f02 +size 57140 diff --git a/docs/muon/imbalance.png b/docs/muon/imbalance.png new file mode 100644 index 0000000000000000000000000000000000000000..d63f0a034912195910cfac8a49f0533ac99968b1 --- /dev/null +++ b/docs/muon/imbalance.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c71d5faed05d46b2269fefa3b6bea6791d7bf51744f47aa4bb8c311eda1b27ff +size 56528 diff --git a/docs/muon/main.tex b/docs/muon/main.tex new file mode 100644 index 0000000000000000000000000000000000000000..bd9a86c98f424fd742a0a7068e799ed37eaedc3b --- /dev/null +++ b/docs/muon/main.tex @@ -0,0 +1,142 @@ +\documentclass{article} +\usepackage{graphicx} +\usepackage{hyperref} +\usepackage{amsmath} +\usepackage{caption} +\usepackage{tgtermes} +\usepackage{float} +\usepackage[a4paper, margin=1in]{geometry} +\usepackage{booktabs} +\usepackage{algorithm} +\usepackage{algorithmicx} +\usepackage{algpseudocode} +\date{} + +\begin{document} + +{\LARGE \bfseries 
Parallelize Muon with FSDP2 \par}
+\vspace{1em} % spacing below the title
+
+\section*{Motivation}
+
+\begin{figure}[H]
+    \centering
+    \includegraphics[width=0.8\textwidth]{distributed_muon.png}
+    \caption*{Distributed Muon by Moonlight}
+\end{figure}
+
+While a distributed version of Muon is available, it has the drawback of redundant computation across GPUs.
+
+\begin{figure}[H]
+    \centering
+    \includegraphics[width=1.0\textwidth]{distributed_muon_execution.png}
+    \caption*{Execution timeline of Distributed Muon}
+\end{figure}
+
+\begin{itemize}
+    \item \texttt{C[i]} : Compute Newton-Schulz(G) for the $i$-th gradient
+    \item \texttt{AG[i]} : AllGather the $i$-th gradient
+    \item \texttt{G[i]} : Gather the $i$-th gradient
+    \item \texttt{SC[i]} : Scatter the $i$-th gradient
+\end{itemize}
+\clearpage
+\section*{Algorithm}
+
+\subsection*{Parallel Muon}
+
+\begin{algorithm}
+\caption{Parallel Muon}
+\textbf{Require:} DP partitioned gradients $\mathbf{g}$, DP partitioned momentum $\mathbf{m}$, DP partitioned parameters $\mathbf{p}$, momentum $\mu$, local rank $\mathbf{r}$
+\begin{algorithmic}[1]
+\State \texttt{// Apply momentum to $\mathbf{g}$ using the local partitioned momentum $\mathbf{m}$}
+\State $\mathbf{g'} \gets \text{update\_with\_momentum}(\mathbf{g}, \mathbf{m}, \mu)$
+\State \texttt{// Assign $\mathbf{g'}$ to a worker rank $\mathbf{R}$}
+\State $\mathbf{R} \gets \text{schedule}(\mathbf{g'}, \text{dp\_group})$
+\State \texttt{// Gather $\mathbf{g'}$ across DP into a full matrix $\mathbf{G}$ on rank $\mathbf{R}$}
+\State $\mathbf{G} \gets \text{gather}(\mathbf{g'}, \text{dp\_group}, \text{dst=}\mathbf{R})$
+\State \texttt{// Run Newton-Schulz only on rank $\mathbf{R}$}
+\If{$\mathbf{r} = \mathbf{R}$}
+    \State $\mathbf{u} \gets \text{Newton-Schulz}(\mathbf{G})$
+\Else
+    \State $\mathbf{u} \gets \text{None}$
+\EndIf
+
+\State \texttt{// Scatter the full matrix $\mathbf{u}$ across DP}
+\State $\mathbf{u'} \gets \text{scatter}(\mathbf{u}, \text{dp\_group}, \text{src=}\mathbf{R})$
+\State \texttt{// Apply the DP partitioned $\mathbf{u'}$ to $\mathbf{p}$}
+\State $\mathbf{p'} \gets \text{apply\_update}(\mathbf{p}, \mathbf{u'})$
+\State \textbf{return} $\mathbf{p'}$
+\end{algorithmic}
+\end{algorithm}
+
+We eliminate the redundant computation by assigning each parameter to a specific GPU.
+
+However, without proper scheduling, this optimization can lead to poor GPU utilization: although each parameter is now processed on a single rank, every other rank sits idle until the scatter communication completes.
+
+\begin{figure}[H]
+    \centering
+    \includegraphics[width=1.0\textwidth]{naive_execution.png}
+    \caption*{Execution timeline of Parallel Muon}
+\end{figure}
+
+\subsection*{Scheduling Sub-Operations}
+
+We can reschedule the sub-operations freely, for two reasons:
+\begin{itemize}
+    \item There are no dependencies between parameters.
+    \item GPUs can execute computation and communication concurrently.
+\end{itemize}
+
+\begin{figure}[H]
+    \centering
+    \includegraphics[width=1.0\textwidth]{pipelined.png}
+    \caption*{Execution timeline of re-scheduled Parallel Muon}
+\end{figure}
+
+We define the chunk size $C$ as the number of GPUs and schedule each sub-operation in batches of size $C$. This scheduling allows each GPU to continue computing even while waiting for collective communication to complete.
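+
+As a sketch, the resulting schedule can be written as follows. This mirrors the stream-based implementation in \texttt{optimizer/torch-ext/optimizer/muon.py}; the \texttt{gather}/\texttt{compute}/\texttt{scatter} names are illustrative.
+
+\begin{algorithm}
+\caption{Chunked scheduling of sub-operations (sketch)}
+\textbf{Require:} parameters $p_1, \dots, p_n$ ordered by descending FLOPs, chunk size $C$ = number of GPUs
+\begin{algorithmic}[1]
+\State $\text{gather}(p_1, \dots, p_C)$ \Comment{prefetch the first chunk on the communication stream}
+\For{$i = 1$ \textbf{to} $n$ \textbf{step} $C$}
+    \State $\text{compute}(p_i, \dots, p_{i+C-1})$ \Comment{Newton-Schulz on the compute stream}
+    \State $\text{gather}(p_{i+C}, \dots, p_{i+2C-1})$ \Comment{overlap: gather the next chunk; indices past $n$ are skipped}
+    \State $\text{scatter}(p_i, \dots, p_{i+C-1})$
+\EndFor
+\end{algorithmic}
+\end{algorithm}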
+
+\clearpage
+\subsection*{Load Balancing}
+
+If the parameters within a chunk have imbalanced computation loads, idle bubbles may occur.
+To mitigate this, we balance the chunks using per-parameter FLOP counts.
+
+\vspace{1em}
+\textbf{Imbalanced (Round Robin)}
+
+\begin{figure}[H]
+    \centering
+    \includegraphics[width=1.0\textwidth]{imbalance.png}
+\end{figure}
+
+\textbf{After Load Balancing}
+
+\begin{figure}[H]
+    \centering
+    \includegraphics[width=1.0\textwidth]{balanced.png}
+\end{figure}
+
+\section*{Implementation}
+
+The full implementation is available in \texttt{optimizer/torch-ext/optimizer/muon.py}.
+To enable concurrent computation and communication, we use separate compute and communication streams (\texttt{torch.cuda.Stream}) and synchronize the sub-operations with \texttt{torch.cuda.Event}.
+
+Thanks to \texttt{torch.DTensor} and \texttt{torch.distributed}, the implementation remains straightforward.
+
+\section*{Evaluation}
+We evaluated the performance using \href{https://huggingface.co/Motif-Technologies/Motif-2.6B}{Motif 2.6B}, achieving 151 TFLOPS per GPU during the optimizer step.
+
+\begin{table}[H]
+    \centering
+    \begin{tabular}{@{}lllll@{}}
+    \toprule
+    Model & TFLOPs for Muon & GPUs & Elapsed time & TFLOPS/GPU \\
+    \midrule
+    Motif 2.6B & 847.45 & 4xMI250 (8 devices) & 1.4 s & 151 \\
+    \bottomrule
+    \end{tabular}
+\end{table}
+Based on the breakdown, 7\% of the time is attributed to updating sharded gradients and parameters, 78\% to GEMM operations, and the remaining 15\% to non-overlapped communication overhead.
+
+\end{document}
\ No newline at end of file
diff --git a/docs/muon/naive_execution.png b/docs/muon/naive_execution.png
new file mode 100644
index 0000000000000000000000000000000000000000..e8f3c4ce721cda02eb95f569c58739d36008b525
--- /dev/null
+++ b/docs/muon/naive_execution.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eaacd3625f33cee9735ed0d96b95f98c696dfc771976be970a38c991e2ce84ab
+size 42729
diff --git a/docs/muon/parallel_muon.pdf b/docs/muon/parallel_muon.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..31397f765a5356ac426213ca512e5ddacbf3f524
--- /dev/null
+++ b/docs/muon/parallel_muon.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0a27499300118f9aa736ab8bde61796640a779afcfa0e885b55b29b833a4272
+size 654653
diff --git a/docs/muon/parallel_muon.png b/docs/muon/parallel_muon.png
new file mode 100644
index 0000000000000000000000000000000000000000..002032fee3e7115c210061d6feabd98dbd7eeff8
--- /dev/null
+++ b/docs/muon/parallel_muon.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89572b94f05e3cd0915e01e56f0c8ac8070a6ac53f3ac3791447daa48325f9b0
+size 126934
diff --git a/docs/muon/pipelined.png b/docs/muon/pipelined.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e3d51f98c8f2e501704298c6ec48dca08203884
--- /dev/null
+++ b/docs/muon/pipelined.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1f8043cc58e7d8d9da5694ad7bccd1b9fe0210349b9aa9a62652a97f75cf097
+size 64316
diff --git a/flake.lock b/flake.lock
new file mode 100644
index 0000000000000000000000000000000000000000..368754a84e467fe6ba68962628649fc9ab6121cc
--- /dev/null
+++ b/flake.lock
@@ -0,0 +1,167 @@
+{
+  "nodes": {
+    "flake-compat": {
+      "locked": {
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+        "owner":
"edolstra", + "repo": "flake-compat", + "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1733328505, + "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1748598786, + "owner": "huggingface", + "repo": "hf-nix", + "rev": "6ca679441494139fde1f2355691ddb5dc8170269", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1749822059, + "narHash": "sha256-zype8KSqESZUIQpsY6sbf4f9pPxM/Zwem+KuH5LeHFk=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "96abd968baa5fa16217413050fa7372d5db3baa5", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1747820358, + "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", + "owner": "danieldk", + "repo": "nixpkgs", + "rev": "d3c1681180717528068082103bf323147de6ab0b", + "type": "github" + }, + "original": { + "owner": "danieldk", + "ref": "cudatoolkit-12.9-kernel-builder", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 
0000000000000000000000000000000000000000..a4c72b81489bcc67dff0d5a2857d765a0a74c5b0 --- /dev/null +++ b/flake.nix @@ -0,0 +1,11 @@ +{ + description = "Flake for Torch kernel extension"; + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + outputs = { self, kernel-builder, }: + kernel-builder.lib.genFlakeOutputs { + path = ./.; + rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + }; +} diff --git a/optimizer/dummy.cu b/optimizer/dummy.cu new file mode 100644 index 0000000000000000000000000000000000000000..9a9780b635e46a946e7c836cd390c24e41da3385 --- /dev/null +++ b/optimizer/dummy.cu @@ -0,0 +1,6 @@ + +namespace { +__global__ void dummy() { + // This kernel does nothing but serves as a placeholder +} +} diff --git a/torch-ext/optimizer/__init__.py b/torch-ext/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..239c7a65f8293e7d0df28f05fce645af56d628c0 --- /dev/null +++ b/torch-ext/optimizer/__init__.py @@ -0,0 +1,5 @@ +from .muon import Muon + +__all__ = [ + "Muon", +] diff --git a/torch-ext/optimizer/muon.py b/torch-ext/optimizer/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ddc145fe8eef20d208d536c01185a8a3a3a907 --- /dev/null +++ b/torch-ext/optimizer/muon.py @@ -0,0 +1,458 @@ +import math +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from torch.distributed._tensor import DTensor + + +# TODO leave original url and consider LICENSE +# This code snippet is a modified version adapted from the following GitHub repository: +# https://github.com/KellerJordan/Muon/blob/master/muon.py +def _zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
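+
+    Concretely, each step applies the quintic map X <- a*X + b*(X X^T) X + c*(X X^T)^2 X
+    with (a, b, c) = (3.4445, -4.7750, 2.0315), as implemented below.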
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G # no manual typecast + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + X = X.bfloat16() + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + # B = ( + # b * A + c * A @ A + # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = torch.addmm(A, A, A, alpha=c, beta=b) + # X = a * X + B @ X + X = torch.addmm(X, B, X, alpha=1.0, beta=a) + + if G.size(0) > G.size(1): + X = X.T + return X.to(G.dtype) + + +@dataclass +class _muon_state: + # TODO: use Optional + worker_rank: int | None = None + gathered_grad: torch.Tensor | None = None + computed_u: torch.Tensor | None = None + scattered_u: torch.Tensor | None = None + gather_event: torch.cuda.Event | None = None + compute_event: torch.cuda.Event | None = None + + +def _gather(p, state, rank, comm_stream): + g = p.grad + mesh = g.device_mesh + + if rank == state.worker_rank: + gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())] + else: + gather_list = None + + with torch.cuda.stream(comm_stream): + torch.distributed.gather( + g.to_local(), + dst=state.worker_rank, + gather_list=gather_list, + group=mesh.get_group(), + ) + if rank == state.worker_rank: + # TODO: Consider ,,, + if state.gathered_grad is not None: + raise RuntimeError( + "Gather event already exists, which should not happen." + ) + state.gathered_grad = torch.cat(gather_list, dim=0) + state.gather_event = torch.cuda.Event() + state.gather_event.record() + else: + state.gathered_grad = None + state.gather_event = None + + +def _compute_u(state, steps, rank, compute_stream): + with torch.cuda.stream(compute_stream): + if rank == state.worker_rank: + if state.gather_event is None: + raise RuntimeError("Gather event must be set before compute.") + compute_stream.wait_event(state.gather_event) + u = _zeropower_via_newtonschulz5(state.gathered_grad, steps) + state.computed_u = u + state.compute_event = torch.cuda.Event() + state.compute_event.record() + else: + state.computed_u = None + state.compute_event = None + + +def _scatter(p, state, rank, comm_stream): + u = state.computed_u + mesh = p.device_mesh + + with torch.cuda.stream(comm_stream): + if rank == state.worker_rank: + if state.compute_event is None: + raise RuntimeError("Compute event must be set before scatter.") + comm_stream.wait_event(state.compute_event) + scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0)) + else: + scatter_list = None + + u = torch.empty_like(p.to_local()) + torch.distributed.scatter( + u, + scatter_list=scatter_list, + src=state.worker_rank, + group=mesh.get_group(), + ) + u = DTensor.from_local( + u, + placements=p.placements, + device_mesh=mesh, + ) + + state.scattered_u = u + + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Some warnings: + - We believe this optimizer is unlikely to work well for training with small batch size. + - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 
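+
+    A minimal usage sketch (illustrative only; it assumes torch.distributed is already
+    initialized and that the 2D weight matrices should be optimized by Muon):
+
+        optim = Muon(model, is_muon_func=lambda p, name: p.ndim == 2)
+        ...
+        optim.step()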
+
+    Arguments:
+        model: The model to optimize. Parameters for which `is_muon_func` returns True
+            are optimized with Muon; all other parameters fall back to AdamW.
+        is_muon_func: A callable `(param, name) -> bool` that selects the parameters to
+            be optimized by Muon (typically the 2D weight matrices, excluding anything
+            that looks like an embedding or lm_head).
+        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
+        momentum: The momentum used by the internal SGD. (0.95 is a good default)
+        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
+        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
+        adamw_wd: The weight decay for the internal AdamW.
+        adamw_betas: The betas for the internal AdamW.
+        adamw_eps: The epsilon for the internal AdamW.
+        debug: If True, print the total Muon FLOPs per step.
+    """
+
+    def __init__(
+        self,
+        model,
+        is_muon_func,
+        lr=1e-3,
+        momentum=0.95,
+        nesterov=True,
+        ns_steps=5,
+        adamw_wd=0.1,
+        adamw_betas=(0.9, 0.95),
+        adamw_eps=1e-8,
+        debug=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            wd=adamw_wd,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            adamw_betas=adamw_betas,
+            adamw_eps=adamw_eps,
+        )
+
+        super().__init__(model.parameters(), defaults)
+        self.is_muon_func = is_muon_func
+        self.model = model
+
+        if not dist.is_initialized():
+            raise RuntimeError(
+                "Muon optimizer requires distributed training to be initialized."
+            )
+
+        self.rank = dist.get_rank()
+
+        self.comm_stream = torch.cuda.Stream()
+        self.compute_stream = torch.cuda.Stream()
+        self.debug = debug
+
+    def __setstate__(self, state):
+        # Sort parameters into those optimized with Muon and those that fall back to AdamW
+        super().__setstate__(state)
+        for name, p in self.model.named_parameters():
+            if self.is_muon_func(p, name):
+                # Use Muon for every 2D parameter selected by is_muon_func
+                assert p.ndim == 2, p.ndim
+                self.state[p]["use_muon"] = True
+                self.state[p]["orig_shape"] = p.shape
+            else:
+                # Fall back to AdamW for everything else
+                self.state[p]["use_muon"] = False
+
+    def _calc_flops(self, G, steps):
+        # Per Newton-Schulz step: X @ X.T and B @ X together cost 4*M^2*N FLOPs,
+        # A @ A costs 2*M^3, and the remaining elementwise terms cost 2*M*N + 3*M^2.
+        assert len(G.shape) == 2
+        M, N = G.shape
+        if M > N:
+            M, N = N, M
+
+        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
+
+    def adjust_lr_for_muon(self, lr, param_shape):
+        A, B = param_shape[:2]
+        # We adjust the learning rate based on the size of the parameter matrix,
+        # as described in the paper
+        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+        adjusted_lr = lr * adjusted_ratio
+        return adjusted_lr
+
+    def init_state_and_assign_params(self, params, group):
+        param_to_state = {}
+        param_to_flops = {}
+
+        total_flops = 0
+        for p in params:
+            g = p.grad
+            if g is None:
+                continue
+            assert g.ndim == 2, "Muon only supports 2D parameters."
+
+            flops = self._calc_flops(g, group["ns_steps"])
+            param_to_flops[id(p)] = flops
+            total_flops += flops
+
+        if self.debug:
+            print(f"Total FLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
+
+        # Assign parameters to ranks round-robin, largest FLOPs first, to balance load
+        ordered_params = sorted(
+            params, key=lambda p: param_to_flops[id(p)], reverse=True
+        )
+
+        round_robin = 0
+        mesh = None
+        for p in ordered_params:
+            if mesh is None:
+                mesh = p.device_mesh
+                if mesh.ndim != 1:
+                    raise NotImplementedError(
+                        "Muon currently supports only 1D device meshes."
+ ) + elif mesh != p.device_mesh: + raise ValueError("All parameters must be on the same mesh.") + + param_to_state[id(p)] = _muon_state() + param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item() + + round_robin = (round_robin + 1) % mesh.mesh.numel() + + return param_to_state, ordered_params + + def base(self, params, group, lr, wd, momentum): + # generate weight updates in distributed fashion + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + assert g is not None + + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + + u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) + + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # apply weight decay + p.data.mul_(1 - lr * wd) + + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def _update_g(self, p, g, group, momentum): + # calc update + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if group["nesterov"]: + g = g.add(buf, alpha=momentum) + else: + g = buf + return g + + def _update_p(self, p, u, lr, wd): + # scale update + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + # apply weight decay + p.data.mul_(1 - lr * wd) + # apply update + p.data.add_(u, alpha=-adjusted_lr) + + def parallel(self, params, group, lr, wd, momentum): + """ + Perform a parallel optimization step using Muon. + """ + + for p in params: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + + # Update g in the local rank + g = self._update_g( + p, + g, + group, + momentum=momentum, + ) + p.grad = g + + param_to_state, ordered_params = self.init_state_and_assign_params( + params, group + ) + + def enqueue_gathers(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _gather(p, state, self.rank, self.comm_stream) + + def enqueue_computes(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _compute_u(state, group["ns_steps"], self.rank, self.compute_stream) + + def enqueue_scatters(start_idx, chunk_size): + for p in ordered_params[start_idx : start_idx + chunk_size]: + state = param_to_state[id(p)] + _scatter(p, state, self.rank, self.comm_stream) + + chunk_size = params[0].device_mesh.mesh.numel() + + # Wait grad update + self.comm_stream.wait_stream(torch.cuda.current_stream()) + + enqueue_gathers(0, chunk_size) + for i in range(0, len(params) + chunk_size - 1, chunk_size): + enqueue_computes(i, chunk_size) + enqueue_gathers(i + chunk_size, chunk_size) + enqueue_scatters(i, chunk_size) + + torch.cuda.current_stream().wait_stream(self.comm_stream) + + for p in params: + g = p.grad + if g is None: + continue + + # Update p with sharded u + state = param_to_state[id(p)] + self._update_p( + p, + state.scattered_u, + lr=lr, + wd=wd, + ) + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. 
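+
+        Returns:
+            The loss returned by the closure if one was supplied, otherwise None.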
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            ############################
+            #           Muon           #
+            ############################
+
+            params = [p for p in group["params"] if self.state[p]["use_muon"]]
+            lr = group["lr"]
+            wd = group["wd"]
+            momentum = group["momentum"]
+
+            if isinstance(params[0].data, DTensor):
+                self.parallel(
+                    params,
+                    group,
+                    lr=lr,
+                    wd=wd,
+                    momentum=momentum,
+                )
+            else:
+                self.base(
+                    params,
+                    group,
+                    lr=lr,
+                    wd=wd,
+                    momentum=momentum,
+                )
+
+            ############################
+            #       AdamW backup       #
+            ############################
+
+            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
+            lr = group["lr"]
+            beta1, beta2 = group["adamw_betas"]
+            eps = group["adamw_eps"]
+            weight_decay = group["wd"]
+
+            for p in params:
+                g = p.grad
+                if g is None:
+                    continue
+                state = self.state[p]
+                if "step" not in state:
+                    state["step"] = 0
+                    state["moment1"] = torch.zeros_like(g)
+                    state["moment2"] = torch.zeros_like(g)
+                state["step"] += 1
+                step = state["step"]
+                buf1 = state["moment1"]
+                buf2 = state["moment2"]
+                buf1.lerp_(g, 1 - beta1)
+                buf2.lerp_(g.square(), 1 - beta2)
+
+                g = buf1 / (eps + buf2.sqrt())
+
+                bias_correction1 = 1 - beta1**step
+                bias_correction2 = 1 - beta2**step
+                scale = bias_correction1 / bias_correction2**0.5
+                p.data.mul_(1 - lr * weight_decay)
+                p.data.add_(g, alpha=-lr / scale)
+
+        return loss
diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..02811c4d30ff4d3777fc9b13c33d93f5ea2a3d4a
--- /dev/null
+++ b/torch-ext/torch_binding.cpp
@@ -0,0 +1,11 @@
+#include <torch/library.h>
+
+#include "registration.h"
+#include "torch_binding.h"
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  // Placeholder op; the optimizer itself is implemented in Python.
+  ops.def("dummy() -> ()");
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h
new file mode 100644
index 0000000000000000000000000000000000000000..9974e804609a7056679df4961195898c27697d68
--- /dev/null
+++ b/torch-ext/torch_binding.h
@@ -0,0 +1,5 @@
+#pragma once
+
+#include <torch/torch.h>
+
+void dummy();