iamwyldecat commited on
Commit
8535e80
·
1 Parent(s): cf531ba

chore: initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. README.md +33 -0
  3. build.toml +23 -0
  4. build/torch26-cxx11-cu118-x86_64-linux/optimizer/__init__.py +5 -0
  5. build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +9 -0
  6. build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so +3 -0
  7. build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so +3 -0
  8. build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +458 -0
  9. build/torch26-cxx11-cu124-x86_64-linux/optimizer/__init__.py +5 -0
  10. build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +9 -0
  11. build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so +3 -0
  12. build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so +3 -0
  13. build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +458 -0
  14. build/torch26-cxx11-cu126-x86_64-linux/optimizer/__init__.py +5 -0
  15. build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +9 -0
  16. build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so +3 -0
  17. build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so +3 -0
  18. build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +458 -0
  19. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/__init__.py +5 -0
  20. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +9 -0
  21. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614121529.abi3.so +3 -0
  22. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614123843.abi3.so +3 -0
  23. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so +3 -0
  24. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so +3 -0
  25. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +458 -0
  26. build/torch26-cxx98-cu118-x86_64-linux/optimizer/__init__.py +5 -0
  27. build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +9 -0
  28. build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so +3 -0
  29. build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so +3 -0
  30. build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +458 -0
  31. build/torch26-cxx98-cu124-x86_64-linux/optimizer/__init__.py +5 -0
  32. build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +9 -0
  33. build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so +3 -0
  34. build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so +3 -0
  35. build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +458 -0
  36. build/torch26-cxx98-cu126-x86_64-linux/optimizer/__init__.py +5 -0
  37. build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +9 -0
  38. build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so +3 -0
  39. build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so +3 -0
  40. build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +458 -0
  41. build/torch27-cxx11-cu118-x86_64-linux/optimizer/__init__.py +5 -0
  42. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +9 -0
  43. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so +3 -0
  44. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so +3 -0
  45. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +458 -0
  46. build/torch27-cxx11-cu126-x86_64-linux/optimizer/__init__.py +5 -0
  47. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +9 -0
  48. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so +3 -0
  49. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so +3 -0
  50. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +458 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.so filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
38
+ *.pdf filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - kernel
4
+ ---
5
+
6
+ # Optimizer
7
+
8
+ Optimizer is a python package that provides:
9
+ - PyTorch implementation of recent optimizer algorithms
10
+ - with support for parallelism techniques for efficient large-scale training.
11
+
12
+ ### Currently implemented
13
+ - [Parallel Muon with FSDP2](./docs/muon/parallel_muon.pdf)
14
+
15
+ ## Usage
16
+
17
+ ```python
18
+ import torch
19
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
20
+ from kernels import get_kernel
21
+
22
+ optimizer = get_kernel("motif-technologies/optimizer")
23
+
24
+ model = None # your model here
25
+ fsdp_model = FSDP(model)
26
+
27
+ optim = optimizer.Muon(
28
+ fsdp_model.parameters(),
29
+ lr=0.01,
30
+ momentum=0.9,
31
+ weight_decay=1e-4,
32
+ )
33
+ ```
build.toml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "optimizer"
3
+ universal = false
4
+
5
+ [torch]
6
+ src = [
7
+ "torch-ext/torch_binding.cpp",
8
+ "torch-ext/torch_binding.h",
9
+ ]
10
+
11
+ [kernel.activation]
12
+ backend = "rocm"
13
+ src = [
14
+ "optimizer/dummy.cu",
15
+ ]
16
+ depends = [ "torch" ]
17
+
18
+ [kernel.activation_cuda]
19
+ backend = "cuda"
20
+ src = [
21
+ "optimizer/dummy.cu",
22
+ ]
23
+ depends = [ "torch" ]
build/torch26-cxx11-cu118-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _optimizer_b4b3752_dirty
3
+ ops = torch.ops._optimizer_b4b3752_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_b4b3752_dirty::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66ca698639fff584999fe65f8f10cc4436c197829e936be2741bf53db685caa0
3
+ size 1787272
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8325d12959ef4f31b6c6340eca29176f5077abeaa10f3a6651db55ccf3c634f
3
+ size 1787272
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # TODO leave original url and consider LICENSE
10
+ # This code snippet is a modified version adapted from the following GitHub repository:
11
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ scattered_u: torch.Tensor | None = None
52
+ gather_event: torch.cuda.Event | None = None
53
+ compute_event: torch.cuda.Event | None = None
54
+
55
+
56
+ def _gather(p, state, rank, comm_stream):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ # TODO: Consider ,,,
74
+ if state.gathered_grad is not None:
75
+ raise RuntimeError(
76
+ "Gather event already exists, which should not happen."
77
+ )
78
+ state.gathered_grad = torch.cat(gather_list, dim=0)
79
+ state.gather_event = torch.cuda.Event()
80
+ state.gather_event.record()
81
+ else:
82
+ state.gathered_grad = None
83
+ state.gather_event = None
84
+
85
+
86
+ def _compute_u(state, steps, rank, compute_stream):
87
+ with torch.cuda.stream(compute_stream):
88
+ if rank == state.worker_rank:
89
+ if state.gather_event is None:
90
+ raise RuntimeError("Gather event must be set before compute.")
91
+ compute_stream.wait_event(state.gather_event)
92
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
93
+ state.computed_u = u
94
+ state.compute_event = torch.cuda.Event()
95
+ state.compute_event.record()
96
+ else:
97
+ state.computed_u = None
98
+ state.compute_event = None
99
+
100
+
101
+ def _scatter(p, state, rank, comm_stream):
102
+ u = state.computed_u
103
+ mesh = p.device_mesh
104
+
105
+ with torch.cuda.stream(comm_stream):
106
+ if rank == state.worker_rank:
107
+ if state.compute_event is None:
108
+ raise RuntimeError("Compute event must be set before scatter.")
109
+ comm_stream.wait_event(state.compute_event)
110
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
111
+ else:
112
+ scatter_list = None
113
+
114
+ u = torch.empty_like(p.to_local())
115
+ torch.distributed.scatter(
116
+ u,
117
+ scatter_list=scatter_list,
118
+ src=state.worker_rank,
119
+ group=mesh.get_group(),
120
+ )
121
+ u = DTensor.from_local(
122
+ u,
123
+ placements=p.placements,
124
+ device_mesh=mesh,
125
+ )
126
+
127
+ state.scattered_u = u
128
+
129
+
130
+ class Muon(torch.optim.Optimizer):
131
+ """
132
+ Muon - MomentUm Orthogonalized by Newton-schulz
133
+
134
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
135
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
136
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
137
+ the advantage that it can be stably run in bfloat16 on the GPU.
138
+
139
+ Some warnings:
140
+ - We believe this optimizer is unlikely to work well for training with small batch size.
141
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
142
+
143
+ Arguments:
144
+ muon_params: The parameters to be optimized by Muon.
145
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
146
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
147
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
148
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
149
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
150
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
151
+ adamw_lr: The learning rate for the internal AdamW.
152
+ adamw_betas: The betas for the internal AdamW.
153
+ adamw_eps: The epsilon for the internal AdamW.
154
+ adamw_wd: The weight decay for the internal AdamW.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ model,
160
+ is_muon_func,
161
+ lr=1e-3,
162
+ momentum=0.95,
163
+ nesterov=True,
164
+ ns_steps=5,
165
+ adamw_wd=0.1,
166
+ adamw_betas=(0.9, 0.95),
167
+ adamw_eps=1e-8,
168
+ debug=False,
169
+ ):
170
+ defaults = dict(
171
+ lr=lr,
172
+ wd=adamw_wd,
173
+ momentum=momentum,
174
+ nesterov=nesterov,
175
+ ns_steps=ns_steps,
176
+ adamw_betas=adamw_betas,
177
+ adamw_eps=adamw_eps,
178
+ )
179
+
180
+ super().__init__(model.parameters(), defaults)
181
+ self.is_muon_func = is_muon_func
182
+ self.model = model
183
+
184
+ if not dist.is_initialized():
185
+ raise RuntimeError(
186
+ "Muon optimizer requires distributed training to be initialized."
187
+ )
188
+
189
+ self.rank = dist.get_rank()
190
+
191
+ self.comm_stream = torch.cuda.Stream()
192
+ self.compute_stream = torch.cuda.Stream()
193
+ self.debug = debug
194
+
195
+ def __setstate__(self, state):
196
+ # Sort parameters into those for which we will use Muon, and those for which we will not
197
+ super().__setstate__(state)
198
+ for name, p in self.model.named_parameters():
199
+ if self.is_muon_func(p, name):
200
+ # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
201
+ assert p.ndim == 2, p.ndim
202
+ self.state[p]["use_muon"] = True
203
+ self.state[p]["orig_shape"] = p.shape
204
+ else:
205
+ # Do not use Muon for parameters in adamw_params
206
+ self.state[p]["use_muon"] = False
207
+
208
+ def _calc_flops(self, G, steps):
209
+ assert len(G.shape) == 2
210
+ M, N = G.shape
211
+ if M > N:
212
+ M, N = N, M
213
+
214
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
215
+
216
+ def adjust_lr_for_muon(self, lr, param_shape):
217
+ A, B = param_shape[:2]
218
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
219
+ # as describted in the paper
220
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
221
+ adjusted_lr = lr * adjusted_ratio
222
+ return adjusted_lr
223
+
224
+ def init_state_and_assign_params(self, params, group):
225
+ param_to_state = {}
226
+ param_to_flops = {}
227
+
228
+ total_flops = 0
229
+ for p in params:
230
+ g = p.grad
231
+ if g is None:
232
+ continue
233
+ assert g.ndim == 2, "Muon only supports 2D parameters."
234
+
235
+ flops = self._calc_flops(g, group["ns_steps"])
236
+ param_to_flops[id(p)] = flops
237
+ total_flops += flops
238
+
239
+ if self.debug:
240
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
241
+
242
+ ordered_params = sorted(
243
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
244
+ )
245
+
246
+ round_robin = 0
247
+ mesh = None
248
+ for p in ordered_params:
249
+ if mesh is None:
250
+ mesh = p.device_mesh
251
+ if mesh.ndim != 1:
252
+ raise NotImplementedError(
253
+ "Muon requires a 1D mesh for distributed training yet."
254
+ )
255
+ elif mesh != p.device_mesh:
256
+ raise ValueError("All parameters must be on the same mesh.")
257
+
258
+ param_to_state[id(p)] = _muon_state()
259
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
260
+
261
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
262
+
263
+ return param_to_state, ordered_params
264
+
265
+ def base(self, params, group, lr, wd, momentum):
266
+ # generate weight updates in distributed fashion
267
+ for p in params:
268
+ g = p.grad
269
+ if g is None:
270
+ continue
271
+ if g.ndim > 2:
272
+ g = g.view(g.size(0), -1)
273
+ assert g is not None
274
+
275
+ # calc update
276
+ state = self.state[p]
277
+ if "momentum_buffer" not in state:
278
+ state["momentum_buffer"] = torch.zeros_like(g)
279
+ buf = state["momentum_buffer"]
280
+ buf.mul_(momentum).add_(g)
281
+ if group["nesterov"]:
282
+ g = g.add(buf, alpha=momentum)
283
+ else:
284
+ g = buf
285
+
286
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
287
+
288
+ # scale update
289
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
290
+
291
+ # apply weight decay
292
+ p.data.mul_(1 - lr * wd)
293
+
294
+ # apply update
295
+ p.data.add_(u, alpha=-adjusted_lr)
296
+
297
+ def _update_g(self, p, g, group, momentum):
298
+ # calc update
299
+ state = self.state[p]
300
+ if "momentum_buffer" not in state:
301
+ state["momentum_buffer"] = torch.zeros_like(g)
302
+ buf = state["momentum_buffer"]
303
+ buf.mul_(momentum).add_(g)
304
+ if group["nesterov"]:
305
+ g = g.add(buf, alpha=momentum)
306
+ else:
307
+ g = buf
308
+ return g
309
+
310
+ def _update_p(self, p, u, lr, wd):
311
+ # scale update
312
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
313
+ # apply weight decay
314
+ p.data.mul_(1 - lr * wd)
315
+ # apply update
316
+ p.data.add_(u, alpha=-adjusted_lr)
317
+
318
+ def parallel(self, params, group, lr, wd, momentum):
319
+ """
320
+ Perform a parallel optimization step using Muon.
321
+ """
322
+
323
+ for p in params:
324
+ g = p.grad
325
+ if g is None:
326
+ continue
327
+ if g.ndim > 2:
328
+ g = g.view(g.size(0), -1)
329
+
330
+ # Update g in the local rank
331
+ g = self._update_g(
332
+ p,
333
+ g,
334
+ group,
335
+ momentum=momentum,
336
+ )
337
+ p.grad = g
338
+
339
+ param_to_state, ordered_params = self.init_state_and_assign_params(
340
+ params, group
341
+ )
342
+
343
+ def enqueue_gathers(start_idx, chunk_size):
344
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
345
+ state = param_to_state[id(p)]
346
+ _gather(p, state, self.rank, self.comm_stream)
347
+
348
+ def enqueue_computes(start_idx, chunk_size):
349
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
350
+ state = param_to_state[id(p)]
351
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
352
+
353
+ def enqueue_scatters(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _scatter(p, state, self.rank, self.comm_stream)
357
+
358
+ chunk_size = params[0].device_mesh.mesh.numel()
359
+
360
+ # Wait grad update
361
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
362
+
363
+ enqueue_gathers(0, chunk_size)
364
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
365
+ enqueue_computes(i, chunk_size)
366
+ enqueue_gathers(i + chunk_size, chunk_size)
367
+ enqueue_scatters(i, chunk_size)
368
+
369
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
370
+
371
+ for p in params:
372
+ g = p.grad
373
+ if g is None:
374
+ continue
375
+
376
+ # Update p with sharded u
377
+ state = param_to_state[id(p)]
378
+ self._update_p(
379
+ p,
380
+ state.scattered_u,
381
+ lr=lr,
382
+ wd=wd,
383
+ )
384
+
385
+ def step(self, closure=None):
386
+ """Perform a single optimization step.
387
+
388
+ Args:
389
+ closure (Callable, optional): A closure that reevaluates the model
390
+ and returns the loss.
391
+ """
392
+ loss = None
393
+ if closure is not None:
394
+ with torch.enable_grad():
395
+ loss = closure()
396
+
397
+ for group in self.param_groups:
398
+ ############################
399
+ # Muon #
400
+ ############################
401
+
402
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
403
+ lr = group["lr"]
404
+ wd = group["wd"]
405
+ momentum = group["momentum"]
406
+
407
+ if isinstance(params[0].data, DTensor):
408
+ self.parallel(
409
+ params,
410
+ group,
411
+ lr=lr,
412
+ wd=wd,
413
+ momentum=momentum,
414
+ )
415
+ else:
416
+ self.base(
417
+ params,
418
+ group,
419
+ lr=lr,
420
+ wd=wd,
421
+ momentum=momentum,
422
+ )
423
+
424
+ ############################
425
+ # AdamW backup #
426
+ ############################
427
+
428
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
429
+ lr = group["lr"]
430
+ beta1, beta2 = group["adamw_betas"]
431
+ eps = group["adamw_eps"]
432
+ weight_decay = group["wd"]
433
+
434
+ for p in params:
435
+ g = p.grad
436
+ if g is None:
437
+ continue
438
+ state = self.state[p]
439
+ if "step" not in state:
440
+ state["step"] = 0
441
+ state["moment1"] = torch.zeros_like(g)
442
+ state["moment2"] = torch.zeros_like(g)
443
+ state["step"] += 1
444
+ step = state["step"]
445
+ buf1 = state["moment1"]
446
+ buf2 = state["moment2"]
447
+ buf1.lerp_(g, 1 - beta1)
448
+ buf2.lerp_(g.square(), 1 - beta2)
449
+
450
+ g = buf1 / (eps + buf2.sqrt())
451
+
452
+ bias_correction1 = 1 - beta1**step
453
+ bias_correction2 = 1 - beta2**step
454
+ scale = bias_correction1 / bias_correction2**0.5
455
+ p.data.mul_(1 - lr * weight_decay)
456
+ p.data.add_(g, alpha=-lr / scale)
457
+
458
+ return loss
build/torch26-cxx11-cu124-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _optimizer_b4b3752_dirty
3
+ ops = torch.ops._optimizer_b4b3752_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_b4b3752_dirty::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e89cd7d514bfe92598684ae3cfc2d35ac2d021340846e09c0b6c880c3d55bfa0
3
+ size 1820136
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9cbffc2cf8039069831a57afb8e2f64fa684f1a44bec79bb4b72dbb57d9ac607
3
+ size 1824224
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # TODO leave original url and consider LICENSE
10
+ # This code snippet is a modified version adapted from the following GitHub repository:
11
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ scattered_u: torch.Tensor | None = None
52
+ gather_event: torch.cuda.Event | None = None
53
+ compute_event: torch.cuda.Event | None = None
54
+
55
+
56
+ def _gather(p, state, rank, comm_stream):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ # TODO: Consider ,,,
74
+ if state.gathered_grad is not None:
75
+ raise RuntimeError(
76
+ "Gather event already exists, which should not happen."
77
+ )
78
+ state.gathered_grad = torch.cat(gather_list, dim=0)
79
+ state.gather_event = torch.cuda.Event()
80
+ state.gather_event.record()
81
+ else:
82
+ state.gathered_grad = None
83
+ state.gather_event = None
84
+
85
+
86
+ def _compute_u(state, steps, rank, compute_stream):
87
+ with torch.cuda.stream(compute_stream):
88
+ if rank == state.worker_rank:
89
+ if state.gather_event is None:
90
+ raise RuntimeError("Gather event must be set before compute.")
91
+ compute_stream.wait_event(state.gather_event)
92
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
93
+ state.computed_u = u
94
+ state.compute_event = torch.cuda.Event()
95
+ state.compute_event.record()
96
+ else:
97
+ state.computed_u = None
98
+ state.compute_event = None
99
+
100
+
101
+ def _scatter(p, state, rank, comm_stream):
102
+ u = state.computed_u
103
+ mesh = p.device_mesh
104
+
105
+ with torch.cuda.stream(comm_stream):
106
+ if rank == state.worker_rank:
107
+ if state.compute_event is None:
108
+ raise RuntimeError("Compute event must be set before scatter.")
109
+ comm_stream.wait_event(state.compute_event)
110
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
111
+ else:
112
+ scatter_list = None
113
+
114
+ u = torch.empty_like(p.to_local())
115
+ torch.distributed.scatter(
116
+ u,
117
+ scatter_list=scatter_list,
118
+ src=state.worker_rank,
119
+ group=mesh.get_group(),
120
+ )
121
+ u = DTensor.from_local(
122
+ u,
123
+ placements=p.placements,
124
+ device_mesh=mesh,
125
+ )
126
+
127
+ state.scattered_u = u
128
+
129
+
130
+ class Muon(torch.optim.Optimizer):
131
+ """
132
+ Muon - MomentUm Orthogonalized by Newton-schulz
133
+
134
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
135
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
136
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
137
+ the advantage that it can be stably run in bfloat16 on the GPU.
138
+
139
+ Some warnings:
140
+ - We believe this optimizer is unlikely to work well for training with small batch size.
141
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
142
+
143
+ Arguments:
144
+ muon_params: The parameters to be optimized by Muon.
145
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
146
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
147
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
148
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
149
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
150
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
151
+ adamw_lr: The learning rate for the internal AdamW.
152
+ adamw_betas: The betas for the internal AdamW.
153
+ adamw_eps: The epsilon for the internal AdamW.
154
+ adamw_wd: The weight decay for the internal AdamW.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ model,
160
+ is_muon_func,
161
+ lr=1e-3,
162
+ momentum=0.95,
163
+ nesterov=True,
164
+ ns_steps=5,
165
+ adamw_wd=0.1,
166
+ adamw_betas=(0.9, 0.95),
167
+ adamw_eps=1e-8,
168
+ debug=False,
169
+ ):
170
+ defaults = dict(
171
+ lr=lr,
172
+ wd=adamw_wd,
173
+ momentum=momentum,
174
+ nesterov=nesterov,
175
+ ns_steps=ns_steps,
176
+ adamw_betas=adamw_betas,
177
+ adamw_eps=adamw_eps,
178
+ )
179
+
180
+ super().__init__(model.parameters(), defaults)
181
+ self.is_muon_func = is_muon_func
182
+ self.model = model
183
+
184
+ if not dist.is_initialized():
185
+ raise RuntimeError(
186
+ "Muon optimizer requires distributed training to be initialized."
187
+ )
188
+
189
+ self.rank = dist.get_rank()
190
+
191
+ self.comm_stream = torch.cuda.Stream()
192
+ self.compute_stream = torch.cuda.Stream()
193
+ self.debug = debug
194
+
195
+ def __setstate__(self, state):
196
+ # Sort parameters into those for which we will use Muon, and those for which we will not
197
+ super().__setstate__(state)
198
+ for name, p in self.model.named_parameters():
199
+ if self.is_muon_func(p, name):
200
+ # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
201
+ assert p.ndim == 2, p.ndim
202
+ self.state[p]["use_muon"] = True
203
+ self.state[p]["orig_shape"] = p.shape
204
+ else:
205
+ # Do not use Muon for parameters in adamw_params
206
+ self.state[p]["use_muon"] = False
207
+
208
+ def _calc_flops(self, G, steps):
209
+ assert len(G.shape) == 2
210
+ M, N = G.shape
211
+ if M > N:
212
+ M, N = N, M
213
+
214
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
215
+
216
+ def adjust_lr_for_muon(self, lr, param_shape):
217
+ A, B = param_shape[:2]
218
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
219
+ # as describted in the paper
220
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
221
+ adjusted_lr = lr * adjusted_ratio
222
+ return adjusted_lr
223
+
224
+ def init_state_and_assign_params(self, params, group):
225
+ param_to_state = {}
226
+ param_to_flops = {}
227
+
228
+ total_flops = 0
229
+ for p in params:
230
+ g = p.grad
231
+ if g is None:
232
+ continue
233
+ assert g.ndim == 2, "Muon only supports 2D parameters."
234
+
235
+ flops = self._calc_flops(g, group["ns_steps"])
236
+ param_to_flops[id(p)] = flops
237
+ total_flops += flops
238
+
239
+ if self.debug:
240
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
241
+
242
+ ordered_params = sorted(
243
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
244
+ )
245
+
246
+ round_robin = 0
247
+ mesh = None
248
+ for p in ordered_params:
249
+ if mesh is None:
250
+ mesh = p.device_mesh
251
+ if mesh.ndim != 1:
252
+ raise NotImplementedError(
253
+ "Muon requires a 1D mesh for distributed training yet."
254
+ )
255
+ elif mesh != p.device_mesh:
256
+ raise ValueError("All parameters must be on the same mesh.")
257
+
258
+ param_to_state[id(p)] = _muon_state()
259
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
260
+
261
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
262
+
263
+ return param_to_state, ordered_params
264
+
265
+ def base(self, params, group, lr, wd, momentum):
266
+ # generate weight updates in distributed fashion
267
+ for p in params:
268
+ g = p.grad
269
+ if g is None:
270
+ continue
271
+ if g.ndim > 2:
272
+ g = g.view(g.size(0), -1)
273
+ assert g is not None
274
+
275
+ # calc update
276
+ state = self.state[p]
277
+ if "momentum_buffer" not in state:
278
+ state["momentum_buffer"] = torch.zeros_like(g)
279
+ buf = state["momentum_buffer"]
280
+ buf.mul_(momentum).add_(g)
281
+ if group["nesterov"]:
282
+ g = g.add(buf, alpha=momentum)
283
+ else:
284
+ g = buf
285
+
286
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
287
+
288
+ # scale update
289
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
290
+
291
+ # apply weight decay
292
+ p.data.mul_(1 - lr * wd)
293
+
294
+ # apply update
295
+ p.data.add_(u, alpha=-adjusted_lr)
296
+
297
+ def _update_g(self, p, g, group, momentum):
298
+ # calc update
299
+ state = self.state[p]
300
+ if "momentum_buffer" not in state:
301
+ state["momentum_buffer"] = torch.zeros_like(g)
302
+ buf = state["momentum_buffer"]
303
+ buf.mul_(momentum).add_(g)
304
+ if group["nesterov"]:
305
+ g = g.add(buf, alpha=momentum)
306
+ else:
307
+ g = buf
308
+ return g
309
+
310
+ def _update_p(self, p, u, lr, wd):
311
+ # scale update
312
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
313
+ # apply weight decay
314
+ p.data.mul_(1 - lr * wd)
315
+ # apply update
316
+ p.data.add_(u, alpha=-adjusted_lr)
317
+
318
+ def parallel(self, params, group, lr, wd, momentum):
319
+ """
320
+ Perform a parallel optimization step using Muon.
321
+ """
322
+
323
+ for p in params:
324
+ g = p.grad
325
+ if g is None:
326
+ continue
327
+ if g.ndim > 2:
328
+ g = g.view(g.size(0), -1)
329
+
330
+ # Update g in the local rank
331
+ g = self._update_g(
332
+ p,
333
+ g,
334
+ group,
335
+ momentum=momentum,
336
+ )
337
+ p.grad = g
338
+
339
+ param_to_state, ordered_params = self.init_state_and_assign_params(
340
+ params, group
341
+ )
342
+
343
+ def enqueue_gathers(start_idx, chunk_size):
344
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
345
+ state = param_to_state[id(p)]
346
+ _gather(p, state, self.rank, self.comm_stream)
347
+
348
+ def enqueue_computes(start_idx, chunk_size):
349
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
350
+ state = param_to_state[id(p)]
351
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
352
+
353
+ def enqueue_scatters(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _scatter(p, state, self.rank, self.comm_stream)
357
+
358
+ chunk_size = params[0].device_mesh.mesh.numel()
359
+
360
+ # Wait grad update
361
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
362
+
363
+ enqueue_gathers(0, chunk_size)
364
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
365
+ enqueue_computes(i, chunk_size)
366
+ enqueue_gathers(i + chunk_size, chunk_size)
367
+ enqueue_scatters(i, chunk_size)
368
+
369
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
370
+
371
+ for p in params:
372
+ g = p.grad
373
+ if g is None:
374
+ continue
375
+
376
+ # Update p with sharded u
377
+ state = param_to_state[id(p)]
378
+ self._update_p(
379
+ p,
380
+ state.scattered_u,
381
+ lr=lr,
382
+ wd=wd,
383
+ )
384
+
385
+ def step(self, closure=None):
386
+ """Perform a single optimization step.
387
+
388
+ Args:
389
+ closure (Callable, optional): A closure that reevaluates the model
390
+ and returns the loss.
391
+ """
392
+ loss = None
393
+ if closure is not None:
394
+ with torch.enable_grad():
395
+ loss = closure()
396
+
397
+ for group in self.param_groups:
398
+ ############################
399
+ # Muon #
400
+ ############################
401
+
402
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
403
+ lr = group["lr"]
404
+ wd = group["wd"]
405
+ momentum = group["momentum"]
406
+
407
+ if isinstance(params[0].data, DTensor):
408
+ self.parallel(
409
+ params,
410
+ group,
411
+ lr=lr,
412
+ wd=wd,
413
+ momentum=momentum,
414
+ )
415
+ else:
416
+ self.base(
417
+ params,
418
+ group,
419
+ lr=lr,
420
+ wd=wd,
421
+ momentum=momentum,
422
+ )
423
+
424
+ ############################
425
+ # AdamW backup #
426
+ ############################
427
+
428
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
429
+ lr = group["lr"]
430
+ beta1, beta2 = group["adamw_betas"]
431
+ eps = group["adamw_eps"]
432
+ weight_decay = group["wd"]
433
+
434
+ for p in params:
435
+ g = p.grad
436
+ if g is None:
437
+ continue
438
+ state = self.state[p]
439
+ if "step" not in state:
440
+ state["step"] = 0
441
+ state["moment1"] = torch.zeros_like(g)
442
+ state["moment2"] = torch.zeros_like(g)
443
+ state["step"] += 1
444
+ step = state["step"]
445
+ buf1 = state["moment1"]
446
+ buf2 = state["moment2"]
447
+ buf1.lerp_(g, 1 - beta1)
448
+ buf2.lerp_(g.square(), 1 - beta2)
449
+
450
+ g = buf1 / (eps + buf2.sqrt())
451
+
452
+ bias_correction1 = 1 - beta1**step
453
+ bias_correction2 = 1 - beta2**step
454
+ scale = bias_correction1 / bias_correction2**0.5
455
+ p.data.mul_(1 - lr * weight_decay)
456
+ p.data.add_(g, alpha=-lr / scale)
457
+
458
+ return loss
build/torch26-cxx11-cu126-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _optimizer_b4b3752_dirty
3
+ ops = torch.ops._optimizer_b4b3752_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_b4b3752_dirty::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f5dce62d3038e879e688fffa9bbc70f3e82db20b2e7ae3ba09040e0319acb71
3
+ size 1820136
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58162f994df84868dbf62ae70e39d3c14e3390fc827f152eece83dfae7f51503
3
+ size 1824224
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # TODO leave original url and consider LICENSE
10
+ # This code snippet is a modified version adapted from the following GitHub repository:
11
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ scattered_u: torch.Tensor | None = None
52
+ gather_event: torch.cuda.Event | None = None
53
+ compute_event: torch.cuda.Event | None = None
54
+
55
+
56
+ def _gather(p, state, rank, comm_stream):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ # TODO: Consider ,,,
74
+ if state.gathered_grad is not None:
75
+ raise RuntimeError(
76
+ "Gather event already exists, which should not happen."
77
+ )
78
+ state.gathered_grad = torch.cat(gather_list, dim=0)
79
+ state.gather_event = torch.cuda.Event()
80
+ state.gather_event.record()
81
+ else:
82
+ state.gathered_grad = None
83
+ state.gather_event = None
84
+
85
+
86
+ def _compute_u(state, steps, rank, compute_stream):
87
+ with torch.cuda.stream(compute_stream):
88
+ if rank == state.worker_rank:
89
+ if state.gather_event is None:
90
+ raise RuntimeError("Gather event must be set before compute.")
91
+ compute_stream.wait_event(state.gather_event)
92
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
93
+ state.computed_u = u
94
+ state.compute_event = torch.cuda.Event()
95
+ state.compute_event.record()
96
+ else:
97
+ state.computed_u = None
98
+ state.compute_event = None
99
+
100
+
101
+ def _scatter(p, state, rank, comm_stream):
102
+ u = state.computed_u
103
+ mesh = p.device_mesh
104
+
105
+ with torch.cuda.stream(comm_stream):
106
+ if rank == state.worker_rank:
107
+ if state.compute_event is None:
108
+ raise RuntimeError("Compute event must be set before scatter.")
109
+ comm_stream.wait_event(state.compute_event)
110
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
111
+ else:
112
+ scatter_list = None
113
+
114
+ u = torch.empty_like(p.to_local())
115
+ torch.distributed.scatter(
116
+ u,
117
+ scatter_list=scatter_list,
118
+ src=state.worker_rank,
119
+ group=mesh.get_group(),
120
+ )
121
+ u = DTensor.from_local(
122
+ u,
123
+ placements=p.placements,
124
+ device_mesh=mesh,
125
+ )
126
+
127
+ state.scattered_u = u
128
+
129
+
130
+ class Muon(torch.optim.Optimizer):
131
+ """
132
+ Muon - MomentUm Orthogonalized by Newton-schulz
133
+
134
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
135
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
136
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
137
+ the advantage that it can be stably run in bfloat16 on the GPU.
138
+
139
+ Some warnings:
140
+ - We believe this optimizer is unlikely to work well for training with small batch size.
141
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
142
+
143
+ Arguments:
144
+ muon_params: The parameters to be optimized by Muon.
145
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
146
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
147
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
148
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
149
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
150
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
151
+ adamw_lr: The learning rate for the internal AdamW.
152
+ adamw_betas: The betas for the internal AdamW.
153
+ adamw_eps: The epsilon for the internal AdamW.
154
+ adamw_wd: The weight decay for the internal AdamW.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ model,
160
+ is_muon_func,
161
+ lr=1e-3,
162
+ momentum=0.95,
163
+ nesterov=True,
164
+ ns_steps=5,
165
+ adamw_wd=0.1,
166
+ adamw_betas=(0.9, 0.95),
167
+ adamw_eps=1e-8,
168
+ debug=False,
169
+ ):
170
+ defaults = dict(
171
+ lr=lr,
172
+ wd=adamw_wd,
173
+ momentum=momentum,
174
+ nesterov=nesterov,
175
+ ns_steps=ns_steps,
176
+ adamw_betas=adamw_betas,
177
+ adamw_eps=adamw_eps,
178
+ )
179
+
180
+ super().__init__(model.parameters(), defaults)
181
+ self.is_muon_func = is_muon_func
182
+ self.model = model
183
+
184
+ if not dist.is_initialized():
185
+ raise RuntimeError(
186
+ "Muon optimizer requires distributed training to be initialized."
187
+ )
188
+
189
+ self.rank = dist.get_rank()
190
+
191
+ self.comm_stream = torch.cuda.Stream()
192
+ self.compute_stream = torch.cuda.Stream()
193
+ self.debug = debug
194
+
195
+ def __setstate__(self, state):
196
+ # Sort parameters into those for which we will use Muon, and those for which we will not
197
+ super().__setstate__(state)
198
+ for name, p in self.model.named_parameters():
199
+ if self.is_muon_func(p, name):
200
+ # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
201
+ assert p.ndim == 2, p.ndim
202
+ self.state[p]["use_muon"] = True
203
+ self.state[p]["orig_shape"] = p.shape
204
+ else:
205
+ # Do not use Muon for parameters in adamw_params
206
+ self.state[p]["use_muon"] = False
207
+
208
+ def _calc_flops(self, G, steps):
209
+ assert len(G.shape) == 2
210
+ M, N = G.shape
211
+ if M > N:
212
+ M, N = N, M
213
+
214
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
215
+
216
+ def adjust_lr_for_muon(self, lr, param_shape):
217
+ A, B = param_shape[:2]
218
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
219
+ # as describted in the paper
220
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
221
+ adjusted_lr = lr * adjusted_ratio
222
+ return adjusted_lr
223
+
224
+ def init_state_and_assign_params(self, params, group):
225
+ param_to_state = {}
226
+ param_to_flops = {}
227
+
228
+ total_flops = 0
229
+ for p in params:
230
+ g = p.grad
231
+ if g is None:
232
+ continue
233
+ assert g.ndim == 2, "Muon only supports 2D parameters."
234
+
235
+ flops = self._calc_flops(g, group["ns_steps"])
236
+ param_to_flops[id(p)] = flops
237
+ total_flops += flops
238
+
239
+ if self.debug:
240
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
241
+
242
+ ordered_params = sorted(
243
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
244
+ )
245
+
246
+ round_robin = 0
247
+ mesh = None
248
+ for p in ordered_params:
249
+ if mesh is None:
250
+ mesh = p.device_mesh
251
+ if mesh.ndim != 1:
252
+ raise NotImplementedError(
253
+ "Muon requires a 1D mesh for distributed training yet."
254
+ )
255
+ elif mesh != p.device_mesh:
256
+ raise ValueError("All parameters must be on the same mesh.")
257
+
258
+ param_to_state[id(p)] = _muon_state()
259
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
260
+
261
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
262
+
263
+ return param_to_state, ordered_params
264
+
265
+ def base(self, params, group, lr, wd, momentum):
266
+ # generate weight updates in distributed fashion
267
+ for p in params:
268
+ g = p.grad
269
+ if g is None:
270
+ continue
271
+ if g.ndim > 2:
272
+ g = g.view(g.size(0), -1)
273
+ assert g is not None
274
+
275
+ # calc update
276
+ state = self.state[p]
277
+ if "momentum_buffer" not in state:
278
+ state["momentum_buffer"] = torch.zeros_like(g)
279
+ buf = state["momentum_buffer"]
280
+ buf.mul_(momentum).add_(g)
281
+ if group["nesterov"]:
282
+ g = g.add(buf, alpha=momentum)
283
+ else:
284
+ g = buf
285
+
286
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
287
+
288
+ # scale update
289
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
290
+
291
+ # apply weight decay
292
+ p.data.mul_(1 - lr * wd)
293
+
294
+ # apply update
295
+ p.data.add_(u, alpha=-adjusted_lr)
296
+
297
+ def _update_g(self, p, g, group, momentum):
298
+ # calc update
299
+ state = self.state[p]
300
+ if "momentum_buffer" not in state:
301
+ state["momentum_buffer"] = torch.zeros_like(g)
302
+ buf = state["momentum_buffer"]
303
+ buf.mul_(momentum).add_(g)
304
+ if group["nesterov"]:
305
+ g = g.add(buf, alpha=momentum)
306
+ else:
307
+ g = buf
308
+ return g
309
+
310
+ def _update_p(self, p, u, lr, wd):
311
+ # scale update
312
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
313
+ # apply weight decay
314
+ p.data.mul_(1 - lr * wd)
315
+ # apply update
316
+ p.data.add_(u, alpha=-adjusted_lr)
317
+
318
+ def parallel(self, params, group, lr, wd, momentum):
319
+ """
320
+ Perform a parallel optimization step using Muon.
321
+ """
322
+
323
+ for p in params:
324
+ g = p.grad
325
+ if g is None:
326
+ continue
327
+ if g.ndim > 2:
328
+ g = g.view(g.size(0), -1)
329
+
330
+ # Update g in the local rank
331
+ g = self._update_g(
332
+ p,
333
+ g,
334
+ group,
335
+ momentum=momentum,
336
+ )
337
+ p.grad = g
338
+
339
+ param_to_state, ordered_params = self.init_state_and_assign_params(
340
+ params, group
341
+ )
342
+
343
+ def enqueue_gathers(start_idx, chunk_size):
344
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
345
+ state = param_to_state[id(p)]
346
+ _gather(p, state, self.rank, self.comm_stream)
347
+
348
+ def enqueue_computes(start_idx, chunk_size):
349
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
350
+ state = param_to_state[id(p)]
351
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
352
+
353
+ def enqueue_scatters(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _scatter(p, state, self.rank, self.comm_stream)
357
+
358
+ chunk_size = params[0].device_mesh.mesh.numel()
359
+
360
+ # Wait grad update
361
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
362
+
363
+ enqueue_gathers(0, chunk_size)
364
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
365
+ enqueue_computes(i, chunk_size)
366
+ enqueue_gathers(i + chunk_size, chunk_size)
367
+ enqueue_scatters(i, chunk_size)
368
+
369
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
370
+
371
+ for p in params:
372
+ g = p.grad
373
+ if g is None:
374
+ continue
375
+
376
+ # Update p with sharded u
377
+ state = param_to_state[id(p)]
378
+ self._update_p(
379
+ p,
380
+ state.scattered_u,
381
+ lr=lr,
382
+ wd=wd,
383
+ )
384
+
385
+ def step(self, closure=None):
386
+ """Perform a single optimization step.
387
+
388
+ Args:
389
+ closure (Callable, optional): A closure that reevaluates the model
390
+ and returns the loss.
391
+ """
392
+ loss = None
393
+ if closure is not None:
394
+ with torch.enable_grad():
395
+ loss = closure()
396
+
397
+ for group in self.param_groups:
398
+ ############################
399
+ # Muon #
400
+ ############################
401
+
402
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
403
+ lr = group["lr"]
404
+ wd = group["wd"]
405
+ momentum = group["momentum"]
406
+
407
+ if isinstance(params[0].data, DTensor):
408
+ self.parallel(
409
+ params,
410
+ group,
411
+ lr=lr,
412
+ wd=wd,
413
+ momentum=momentum,
414
+ )
415
+ else:
416
+ self.base(
417
+ params,
418
+ group,
419
+ lr=lr,
420
+ wd=wd,
421
+ momentum=momentum,
422
+ )
423
+
424
+ ############################
425
+ # AdamW backup #
426
+ ############################
427
+
428
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
429
+ lr = group["lr"]
430
+ beta1, beta2 = group["adamw_betas"]
431
+ eps = group["adamw_eps"]
432
+ weight_decay = group["wd"]
433
+
434
+ for p in params:
435
+ g = p.grad
436
+ if g is None:
437
+ continue
438
+ state = self.state[p]
439
+ if "step" not in state:
440
+ state["step"] = 0
441
+ state["moment1"] = torch.zeros_like(g)
442
+ state["moment2"] = torch.zeros_like(g)
443
+ state["step"] += 1
444
+ step = state["step"]
445
+ buf1 = state["moment1"]
446
+ buf2 = state["moment2"]
447
+ buf1.lerp_(g, 1 - beta1)
448
+ buf2.lerp_(g.square(), 1 - beta2)
449
+
450
+ g = buf1 / (eps + buf2.sqrt())
451
+
452
+ bias_correction1 = 1 - beta1**step
453
+ bias_correction2 = 1 - beta2**step
454
+ scale = bias_correction1 / bias_correction2**0.5
455
+ p.data.mul_(1 - lr * weight_decay)
456
+ p.data.add_(g, alpha=-lr / scale)
457
+
458
+ return loss
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _optimizer_b4b3752_dirty
3
+ ops = torch.ops._optimizer_b4b3752_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_b4b3752_dirty::{op_name}"
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614121529.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2f60369ba2bd0a0f84e053d857d37496137ff476dc21561f211b1fa39651990
3
+ size 1749784
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614123843.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4d790535f99b7b362a966e802a547654f31749f5f28a0207493870927f1d8d2
3
+ size 1749784
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b440dd9a60711a498010068e91d0ad013cd0b8ac732c16b5d1d17e5d4ec0f9b4
3
+ size 1749784
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f50ea9cab62a5bd06d886516d3917e4490e65aa9addd1cbb84fc81c6f9a9d5b1
3
+ size 1749744
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # TODO leave original url and consider LICENSE
10
+ # This code snippet is a modified version adapted from the following GitHub repository:
11
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ scattered_u: torch.Tensor | None = None
52
+ gather_event: torch.cuda.Event | None = None
53
+ compute_event: torch.cuda.Event | None = None
54
+
55
+
56
+ def _gather(p, state, rank, comm_stream):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ # TODO: Consider ,,,
74
+ if state.gathered_grad is not None:
75
+ raise RuntimeError(
76
+ "Gather event already exists, which should not happen."
77
+ )
78
+ state.gathered_grad = torch.cat(gather_list, dim=0)
79
+ state.gather_event = torch.cuda.Event()
80
+ state.gather_event.record()
81
+ else:
82
+ state.gathered_grad = None
83
+ state.gather_event = None
84
+
85
+
86
+ def _compute_u(state, steps, rank, compute_stream):
87
+ with torch.cuda.stream(compute_stream):
88
+ if rank == state.worker_rank:
89
+ if state.gather_event is None:
90
+ raise RuntimeError("Gather event must be set before compute.")
91
+ compute_stream.wait_event(state.gather_event)
92
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
93
+ state.computed_u = u
94
+ state.compute_event = torch.cuda.Event()
95
+ state.compute_event.record()
96
+ else:
97
+ state.computed_u = None
98
+ state.compute_event = None
99
+
100
+
101
+ def _scatter(p, state, rank, comm_stream):
102
+ u = state.computed_u
103
+ mesh = p.device_mesh
104
+
105
+ with torch.cuda.stream(comm_stream):
106
+ if rank == state.worker_rank:
107
+ if state.compute_event is None:
108
+ raise RuntimeError("Compute event must be set before scatter.")
109
+ comm_stream.wait_event(state.compute_event)
110
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
111
+ else:
112
+ scatter_list = None
113
+
114
+ u = torch.empty_like(p.to_local())
115
+ torch.distributed.scatter(
116
+ u,
117
+ scatter_list=scatter_list,
118
+ src=state.worker_rank,
119
+ group=mesh.get_group(),
120
+ )
121
+ u = DTensor.from_local(
122
+ u,
123
+ placements=p.placements,
124
+ device_mesh=mesh,
125
+ )
126
+
127
+ state.scattered_u = u
128
+
129
+
130
+ class Muon(torch.optim.Optimizer):
131
+ """
132
+ Muon - MomentUm Orthogonalized by Newton-schulz
133
+
134
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
135
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
136
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
137
+ the advantage that it can be stably run in bfloat16 on the GPU.
138
+
139
+ Some warnings:
140
+ - We believe this optimizer is unlikely to work well for training with small batch size.
141
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
142
+
143
+ Arguments:
144
+ muon_params: The parameters to be optimized by Muon.
145
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
146
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
147
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
148
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
149
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
150
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
151
+ adamw_lr: The learning rate for the internal AdamW.
152
+ adamw_betas: The betas for the internal AdamW.
153
+ adamw_eps: The epsilon for the internal AdamW.
154
+ adamw_wd: The weight decay for the internal AdamW.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ model,
160
+ is_muon_func,
161
+ lr=1e-3,
162
+ momentum=0.95,
163
+ nesterov=True,
164
+ ns_steps=5,
165
+ adamw_wd=0.1,
166
+ adamw_betas=(0.9, 0.95),
167
+ adamw_eps=1e-8,
168
+ debug=False,
169
+ ):
170
+ defaults = dict(
171
+ lr=lr,
172
+ wd=adamw_wd,
173
+ momentum=momentum,
174
+ nesterov=nesterov,
175
+ ns_steps=ns_steps,
176
+ adamw_betas=adamw_betas,
177
+ adamw_eps=adamw_eps,
178
+ )
179
+
180
+ super().__init__(model.parameters(), defaults)
181
+ self.is_muon_func = is_muon_func
182
+ self.model = model
183
+
184
+ if not dist.is_initialized():
185
+ raise RuntimeError(
186
+ "Muon optimizer requires distributed training to be initialized."
187
+ )
188
+
189
+ self.rank = dist.get_rank()
190
+
191
+ self.comm_stream = torch.cuda.Stream()
192
+ self.compute_stream = torch.cuda.Stream()
193
+ self.debug = debug
194
+
195
+ def __setstate__(self, state):
196
+ # Sort parameters into those for which we will use Muon, and those for which we will not
197
+ super().__setstate__(state)
198
+ for name, p in self.model.named_parameters():
199
+ if self.is_muon_func(p, name):
200
+ # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
201
+ assert p.ndim == 2, p.ndim
202
+ self.state[p]["use_muon"] = True
203
+ self.state[p]["orig_shape"] = p.shape
204
+ else:
205
+ # Do not use Muon for parameters in adamw_params
206
+ self.state[p]["use_muon"] = False
207
+
208
+ def _calc_flops(self, G, steps):
209
+ assert len(G.shape) == 2
210
+ M, N = G.shape
211
+ if M > N:
212
+ M, N = N, M
213
+
214
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
215
+
216
+ def adjust_lr_for_muon(self, lr, param_shape):
217
+ A, B = param_shape[:2]
218
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
219
+ # as describted in the paper
220
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
221
+ adjusted_lr = lr * adjusted_ratio
222
+ return adjusted_lr
223
+
224
+ def init_state_and_assign_params(self, params, group):
225
+ param_to_state = {}
226
+ param_to_flops = {}
227
+
228
+ total_flops = 0
229
+ for p in params:
230
+ g = p.grad
231
+ if g is None:
232
+ continue
233
+ assert g.ndim == 2, "Muon only supports 2D parameters."
234
+
235
+ flops = self._calc_flops(g, group["ns_steps"])
236
+ param_to_flops[id(p)] = flops
237
+ total_flops += flops
238
+
239
+ if self.debug:
240
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
241
+
242
+ ordered_params = sorted(
243
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
244
+ )
245
+
246
+ round_robin = 0
247
+ mesh = None
248
+ for p in ordered_params:
249
+ if mesh is None:
250
+ mesh = p.device_mesh
251
+ if mesh.ndim != 1:
252
+ raise NotImplementedError(
253
+ "Muon requires a 1D mesh for distributed training yet."
254
+ )
255
+ elif mesh != p.device_mesh:
256
+ raise ValueError("All parameters must be on the same mesh.")
257
+
258
+ param_to_state[id(p)] = _muon_state()
259
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
260
+
261
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
262
+
263
+ return param_to_state, ordered_params
264
+
265
+ def base(self, params, group, lr, wd, momentum):
266
+ # generate weight updates in distributed fashion
267
+ for p in params:
268
+ g = p.grad
269
+ if g is None:
270
+ continue
271
+ if g.ndim > 2:
272
+ g = g.view(g.size(0), -1)
273
+ assert g is not None
274
+
275
+ # calc update
276
+ state = self.state[p]
277
+ if "momentum_buffer" not in state:
278
+ state["momentum_buffer"] = torch.zeros_like(g)
279
+ buf = state["momentum_buffer"]
280
+ buf.mul_(momentum).add_(g)
281
+ if group["nesterov"]:
282
+ g = g.add(buf, alpha=momentum)
283
+ else:
284
+ g = buf
285
+
286
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
287
+
288
+ # scale update
289
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
290
+
291
+ # apply weight decay
292
+ p.data.mul_(1 - lr * wd)
293
+
294
+ # apply update
295
+ p.data.add_(u, alpha=-adjusted_lr)
296
+
297
+ def _update_g(self, p, g, group, momentum):
298
+ # calc update
299
+ state = self.state[p]
300
+ if "momentum_buffer" not in state:
301
+ state["momentum_buffer"] = torch.zeros_like(g)
302
+ buf = state["momentum_buffer"]
303
+ buf.mul_(momentum).add_(g)
304
+ if group["nesterov"]:
305
+ g = g.add(buf, alpha=momentum)
306
+ else:
307
+ g = buf
308
+ return g
309
+
310
+ def _update_p(self, p, u, lr, wd):
311
+ # scale update
312
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
313
+ # apply weight decay
314
+ p.data.mul_(1 - lr * wd)
315
+ # apply update
316
+ p.data.add_(u, alpha=-adjusted_lr)
317
+
318
+ def parallel(self, params, group, lr, wd, momentum):
319
+ """
320
+ Perform a parallel optimization step using Muon.
321
+ """
322
+
323
+ for p in params:
324
+ g = p.grad
325
+ if g is None:
326
+ continue
327
+ if g.ndim > 2:
328
+ g = g.view(g.size(0), -1)
329
+
330
+ # Update g in the local rank
331
+ g = self._update_g(
332
+ p,
333
+ g,
334
+ group,
335
+ momentum=momentum,
336
+ )
337
+ p.grad = g
338
+
339
+ param_to_state, ordered_params = self.init_state_and_assign_params(
340
+ params, group
341
+ )
342
+
343
+ def enqueue_gathers(start_idx, chunk_size):
344
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
345
+ state = param_to_state[id(p)]
346
+ _gather(p, state, self.rank, self.comm_stream)
347
+
348
+ def enqueue_computes(start_idx, chunk_size):
349
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
350
+ state = param_to_state[id(p)]
351
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
352
+
353
+ def enqueue_scatters(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _scatter(p, state, self.rank, self.comm_stream)
357
+
358
+ chunk_size = params[0].device_mesh.mesh.numel()
359
+
360
+ # Wait grad update
361
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
362
+
363
+ enqueue_gathers(0, chunk_size)
364
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
365
+ enqueue_computes(i, chunk_size)
366
+ enqueue_gathers(i + chunk_size, chunk_size)
367
+ enqueue_scatters(i, chunk_size)
368
+
369
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
370
+
371
+ for p in params:
372
+ g = p.grad
373
+ if g is None:
374
+ continue
375
+
376
+ # Update p with sharded u
377
+ state = param_to_state[id(p)]
378
+ self._update_p(
379
+ p,
380
+ state.scattered_u,
381
+ lr=lr,
382
+ wd=wd,
383
+ )
384
+
385
+ def step(self, closure=None):
386
+ """Perform a single optimization step.
387
+
388
+ Args:
389
+ closure (Callable, optional): A closure that reevaluates the model
390
+ and returns the loss.
391
+ """
392
+ loss = None
393
+ if closure is not None:
394
+ with torch.enable_grad():
395
+ loss = closure()
396
+
397
+ for group in self.param_groups:
398
+ ############################
399
+ # Muon #
400
+ ############################
401
+
402
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
403
+ lr = group["lr"]
404
+ wd = group["wd"]
405
+ momentum = group["momentum"]
406
+
407
+ if isinstance(params[0].data, DTensor):
408
+ self.parallel(
409
+ params,
410
+ group,
411
+ lr=lr,
412
+ wd=wd,
413
+ momentum=momentum,
414
+ )
415
+ else:
416
+ self.base(
417
+ params,
418
+ group,
419
+ lr=lr,
420
+ wd=wd,
421
+ momentum=momentum,
422
+ )
423
+
424
+ ############################
425
+ # AdamW backup #
426
+ ############################
427
+
428
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
429
+ lr = group["lr"]
430
+ beta1, beta2 = group["adamw_betas"]
431
+ eps = group["adamw_eps"]
432
+ weight_decay = group["wd"]
433
+
434
+ for p in params:
435
+ g = p.grad
436
+ if g is None:
437
+ continue
438
+ state = self.state[p]
439
+ if "step" not in state:
440
+ state["step"] = 0
441
+ state["moment1"] = torch.zeros_like(g)
442
+ state["moment2"] = torch.zeros_like(g)
443
+ state["step"] += 1
444
+ step = state["step"]
445
+ buf1 = state["moment1"]
446
+ buf2 = state["moment2"]
447
+ buf1.lerp_(g, 1 - beta1)
448
+ buf2.lerp_(g.square(), 1 - beta2)
449
+
450
+ g = buf1 / (eps + buf2.sqrt())
451
+
452
+ bias_correction1 = 1 - beta1**step
453
+ bias_correction2 = 1 - beta2**step
454
+ scale = bias_correction1 / bias_correction2**0.5
455
+ p.data.mul_(1 - lr * weight_decay)
456
+ p.data.add_(g, alpha=-lr / scale)
457
+
458
+ return loss
build/torch26-cxx98-cu118-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _optimizer_b4b3752_dirty
3
+ ops = torch.ops._optimizer_b4b3752_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_b4b3752_dirty::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8f8e7d78ed9a095b882cf764fd9c80a0b0810fb961ba9e8545656fc4cb0b0d7
3
+ size 1787200
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:002dab6441bcad54ab4e7c064b5806acfd45170eb33cfa059745ba6e0c349607
3
+ size 1787192
build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # TODO leave original url and consider LICENSE
10
+ # This code snippet is a modified version adapted from the following GitHub repository:
11
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ scattered_u: torch.Tensor | None = None
52
+ gather_event: torch.cuda.Event | None = None
53
+ compute_event: torch.cuda.Event | None = None
54
+
55
+
56
+ def _gather(p, state, rank, comm_stream):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ # TODO: Consider ,,,
74
+ if state.gathered_grad is not None:
75
+ raise RuntimeError(
76
+ "Gather event already exists, which should not happen."
77
+ )
78
+ state.gathered_grad = torch.cat(gather_list, dim=0)
79
+ state.gather_event = torch.cuda.Event()
80
+ state.gather_event.record()
81
+ else:
82
+ state.gathered_grad = None
83
+ state.gather_event = None
84
+
85
+
86
+ def _compute_u(state, steps, rank, compute_stream):
87
+ with torch.cuda.stream(compute_stream):
88
+ if rank == state.worker_rank:
89
+ if state.gather_event is None:
90
+ raise RuntimeError("Gather event must be set before compute.")
91
+ compute_stream.wait_event(state.gather_event)
92
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
93
+ state.computed_u = u
94
+ state.compute_event = torch.cuda.Event()
95
+ state.compute_event.record()
96
+ else:
97
+ state.computed_u = None
98
+ state.compute_event = None
99
+
100
+
101
+ def _scatter(p, state, rank, comm_stream):
102
+ u = state.computed_u
103
+ mesh = p.device_mesh
104
+
105
+ with torch.cuda.stream(comm_stream):
106
+ if rank == state.worker_rank:
107
+ if state.compute_event is None:
108
+ raise RuntimeError("Compute event must be set before scatter.")
109
+ comm_stream.wait_event(state.compute_event)
110
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
111
+ else:
112
+ scatter_list = None
113
+
114
+ u = torch.empty_like(p.to_local())
115
+ torch.distributed.scatter(
116
+ u,
117
+ scatter_list=scatter_list,
118
+ src=state.worker_rank,
119
+ group=mesh.get_group(),
120
+ )
121
+ u = DTensor.from_local(
122
+ u,
123
+ placements=p.placements,
124
+ device_mesh=mesh,
125
+ )
126
+
127
+ state.scattered_u = u
128
+
129
+
130
+ class Muon(torch.optim.Optimizer):
131
+ """
132
+ Muon - MomentUm Orthogonalized by Newton-schulz
133
+
134
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
135
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
136
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
137
+ the advantage that it can be stably run in bfloat16 on the GPU.
138
+
139
+ Some warnings:
140
+ - We believe this optimizer is unlikely to work well for training with small batch size.
141
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
142
+
143
+ Arguments:
144
+ muon_params: The parameters to be optimized by Muon.
145
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
146
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
147
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
148
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
149
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
150
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
151
+ adamw_lr: The learning rate for the internal AdamW.
152
+ adamw_betas: The betas for the internal AdamW.
153
+ adamw_eps: The epsilon for the internal AdamW.
154
+ adamw_wd: The weight decay for the internal AdamW.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ model,
160
+ is_muon_func,
161
+ lr=1e-3,
162
+ momentum=0.95,
163
+ nesterov=True,
164
+ ns_steps=5,
165
+ adamw_wd=0.1,
166
+ adamw_betas=(0.9, 0.95),
167
+ adamw_eps=1e-8,
168
+ debug=False,
169
+ ):
170
+ defaults = dict(
171
+ lr=lr,
172
+ wd=adamw_wd,
173
+ momentum=momentum,
174
+ nesterov=nesterov,
175
+ ns_steps=ns_steps,
176
+ adamw_betas=adamw_betas,
177
+ adamw_eps=adamw_eps,
178
+ )
179
+
180
+ super().__init__(model.parameters(), defaults)
181
+ self.is_muon_func = is_muon_func
182
+ self.model = model
183
+
184
+ if not dist.is_initialized():
185
+ raise RuntimeError(
186
+ "Muon optimizer requires distributed training to be initialized."
187
+ )
188
+
189
+ self.rank = dist.get_rank()
190
+
191
+ self.comm_stream = torch.cuda.Stream()
192
+ self.compute_stream = torch.cuda.Stream()
193
+ self.debug = debug
194
+
195
+ def __setstate__(self, state):
196
+ # Sort parameters into those for which we will use Muon, and those for which we will not
197
+ super().__setstate__(state)
198
+ for name, p in self.model.named_parameters():
199
+ if self.is_muon_func(p, name):
200
+ # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
201
+ assert p.ndim == 2, p.ndim
202
+ self.state[p]["use_muon"] = True
203
+ self.state[p]["orig_shape"] = p.shape
204
+ else:
205
+ # Do not use Muon for parameters in adamw_params
206
+ self.state[p]["use_muon"] = False
207
+
208
+ def _calc_flops(self, G, steps):
209
+ assert len(G.shape) == 2
210
+ M, N = G.shape
211
+ if M > N:
212
+ M, N = N, M
213
+
214
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
215
+
216
+ def adjust_lr_for_muon(self, lr, param_shape):
217
+ A, B = param_shape[:2]
218
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
219
+ # as describted in the paper
220
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
221
+ adjusted_lr = lr * adjusted_ratio
222
+ return adjusted_lr
223
+
224
+ def init_state_and_assign_params(self, params, group):
225
+ param_to_state = {}
226
+ param_to_flops = {}
227
+
228
+ total_flops = 0
229
+ for p in params:
230
+ g = p.grad
231
+ if g is None:
232
+ continue
233
+ assert g.ndim == 2, "Muon only supports 2D parameters."
234
+
235
+ flops = self._calc_flops(g, group["ns_steps"])
236
+ param_to_flops[id(p)] = flops
237
+ total_flops += flops
238
+
239
+ if self.debug:
240
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
241
+
242
+ ordered_params = sorted(
243
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
244
+ )
245
+
246
+ round_robin = 0
247
+ mesh = None
248
+ for p in ordered_params:
249
+ if mesh is None:
250
+ mesh = p.device_mesh
251
+ if mesh.ndim != 1:
252
+ raise NotImplementedError(
253
+ "Muon requires a 1D mesh for distributed training yet."
254
+ )
255
+ elif mesh != p.device_mesh:
256
+ raise ValueError("All parameters must be on the same mesh.")
257
+
258
+ param_to_state[id(p)] = _muon_state()
259
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
260
+
261
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
262
+
263
+ return param_to_state, ordered_params
264
+
265
+ def base(self, params, group, lr, wd, momentum):
266
+ # generate weight updates in distributed fashion
267
+ for p in params:
268
+ g = p.grad
269
+ if g is None:
270
+ continue
271
+ if g.ndim > 2:
272
+ g = g.view(g.size(0), -1)
273
+ assert g is not None
274
+
275
+ # calc update
276
+ state = self.state[p]
277
+ if "momentum_buffer" not in state:
278
+ state["momentum_buffer"] = torch.zeros_like(g)
279
+ buf = state["momentum_buffer"]
280
+ buf.mul_(momentum).add_(g)
281
+ if group["nesterov"]:
282
+ g = g.add(buf, alpha=momentum)
283
+ else:
284
+ g = buf
285
+
286
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
287
+
288
+ # scale update
289
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
290
+
291
+ # apply weight decay
292
+ p.data.mul_(1 - lr * wd)
293
+
294
+ # apply update
295
+ p.data.add_(u, alpha=-adjusted_lr)
296
+
297
+ def _update_g(self, p, g, group, momentum):
298
+ # calc update
299
+ state = self.state[p]
300
+ if "momentum_buffer" not in state:
301
+ state["momentum_buffer"] = torch.zeros_like(g)
302
+ buf = state["momentum_buffer"]
303
+ buf.mul_(momentum).add_(g)
304
+ if group["nesterov"]:
305
+ g = g.add(buf, alpha=momentum)
306
+ else:
307
+ g = buf
308
+ return g
309
+
310
+ def _update_p(self, p, u, lr, wd):
311
+ # scale update
312
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
313
+ # apply weight decay
314
+ p.data.mul_(1 - lr * wd)
315
+ # apply update
316
+ p.data.add_(u, alpha=-adjusted_lr)
317
+
318
+ def parallel(self, params, group, lr, wd, momentum):
319
+ """
320
+ Perform a parallel optimization step using Muon.
321
+ """
322
+
323
+ for p in params:
324
+ g = p.grad
325
+ if g is None:
326
+ continue
327
+ if g.ndim > 2:
328
+ g = g.view(g.size(0), -1)
329
+
330
+ # Update g in the local rank
331
+ g = self._update_g(
332
+ p,
333
+ g,
334
+ group,
335
+ momentum=momentum,
336
+ )
337
+ p.grad = g
338
+
339
+ param_to_state, ordered_params = self.init_state_and_assign_params(
340
+ params, group
341
+ )
342
+
343
+ def enqueue_gathers(start_idx, chunk_size):
344
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
345
+ state = param_to_state[id(p)]
346
+ _gather(p, state, self.rank, self.comm_stream)
347
+
348
+ def enqueue_computes(start_idx, chunk_size):
349
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
350
+ state = param_to_state[id(p)]
351
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
352
+
353
+ def enqueue_scatters(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _scatter(p, state, self.rank, self.comm_stream)
357
+
358
+ chunk_size = params[0].device_mesh.mesh.numel()
359
+
360
+ # Wait grad update
361
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
362
+
363
+ enqueue_gathers(0, chunk_size)
364
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
365
+ enqueue_computes(i, chunk_size)
366
+ enqueue_gathers(i + chunk_size, chunk_size)
367
+ enqueue_scatters(i, chunk_size)
368
+
369
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
370
+
371
+ for p in params:
372
+ g = p.grad
373
+ if g is None:
374
+ continue
375
+
376
+ # Update p with sharded u
377
+ state = param_to_state[id(p)]
378
+ self._update_p(
379
+ p,
380
+ state.scattered_u,
381
+ lr=lr,
382
+ wd=wd,
383
+ )
384
+
385
+ def step(self, closure=None):
386
+ """Perform a single optimization step.
387
+
388
+ Args:
389
+ closure (Callable, optional): A closure that reevaluates the model
390
+ and returns the loss.
391
+ """
392
+ loss = None
393
+ if closure is not None:
394
+ with torch.enable_grad():
395
+ loss = closure()
396
+
397
+ for group in self.param_groups:
398
+ ############################
399
+ # Muon #
400
+ ############################
401
+
402
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
403
+ lr = group["lr"]
404
+ wd = group["wd"]
405
+ momentum = group["momentum"]
406
+
407
+ if isinstance(params[0].data, DTensor):
408
+ self.parallel(
409
+ params,
410
+ group,
411
+ lr=lr,
412
+ wd=wd,
413
+ momentum=momentum,
414
+ )
415
+ else:
416
+ self.base(
417
+ params,
418
+ group,
419
+ lr=lr,
420
+ wd=wd,
421
+ momentum=momentum,
422
+ )
423
+
424
+ ############################
425
+ # AdamW backup #
426
+ ############################
427
+
428
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
429
+ lr = group["lr"]
430
+ beta1, beta2 = group["adamw_betas"]
431
+ eps = group["adamw_eps"]
432
+ weight_decay = group["wd"]
433
+
434
+ for p in params:
435
+ g = p.grad
436
+ if g is None:
437
+ continue
438
+ state = self.state[p]
439
+ if "step" not in state:
440
+ state["step"] = 0
441
+ state["moment1"] = torch.zeros_like(g)
442
+ state["moment2"] = torch.zeros_like(g)
443
+ state["step"] += 1
444
+ step = state["step"]
445
+ buf1 = state["moment1"]
446
+ buf2 = state["moment2"]
447
+ buf1.lerp_(g, 1 - beta1)
448
+ buf2.lerp_(g.square(), 1 - beta2)
449
+
450
+ g = buf1 / (eps + buf2.sqrt())
451
+
452
+ bias_correction1 = 1 - beta1**step
453
+ bias_correction2 = 1 - beta2**step
454
+ scale = bias_correction1 / bias_correction2**0.5
455
+ p.data.mul_(1 - lr * weight_decay)
456
+ p.data.add_(g, alpha=-lr / scale)
457
+
458
+ return loss
build/torch26-cxx98-cu124-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _optimizer_b4b3752_dirty
3
+ ops = torch.ops._optimizer_b4b3752_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_b4b3752_dirty::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab2379d932e40d10bee55f032bd16d2e4d9c1920bc5500628006f8a0eb8abd39
3
+ size 1824192
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f499350bb19eca6c3da1bb72e46023834b8411ce00730854273b588b2cd9206
3
+ size 1824184
build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # TODO leave original url and consider LICENSE
10
+ # This code snippet is a modified version adapted from the following GitHub repository:
11
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ scattered_u: torch.Tensor | None = None
52
+ gather_event: torch.cuda.Event | None = None
53
+ compute_event: torch.cuda.Event | None = None
54
+
55
+
56
+ def _gather(p, state, rank, comm_stream):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ # TODO: Consider ,,,
74
+ if state.gathered_grad is not None:
75
+ raise RuntimeError(
76
+ "Gather event already exists, which should not happen."
77
+ )
78
+ state.gathered_grad = torch.cat(gather_list, dim=0)
79
+ state.gather_event = torch.cuda.Event()
80
+ state.gather_event.record()
81
+ else:
82
+ state.gathered_grad = None
83
+ state.gather_event = None
84
+
85
+
86
+ def _compute_u(state, steps, rank, compute_stream):
87
+ with torch.cuda.stream(compute_stream):
88
+ if rank == state.worker_rank:
89
+ if state.gather_event is None:
90
+ raise RuntimeError("Gather event must be set before compute.")
91
+ compute_stream.wait_event(state.gather_event)
92
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
93
+ state.computed_u = u
94
+ state.compute_event = torch.cuda.Event()
95
+ state.compute_event.record()
96
+ else:
97
+ state.computed_u = None
98
+ state.compute_event = None
99
+
100
+
101
+ def _scatter(p, state, rank, comm_stream):
102
+ u = state.computed_u
103
+ mesh = p.device_mesh
104
+
105
+ with torch.cuda.stream(comm_stream):
106
+ if rank == state.worker_rank:
107
+ if state.compute_event is None:
108
+ raise RuntimeError("Compute event must be set before scatter.")
109
+ comm_stream.wait_event(state.compute_event)
110
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
111
+ else:
112
+ scatter_list = None
113
+
114
+ u = torch.empty_like(p.to_local())
115
+ torch.distributed.scatter(
116
+ u,
117
+ scatter_list=scatter_list,
118
+ src=state.worker_rank,
119
+ group=mesh.get_group(),
120
+ )
121
+ u = DTensor.from_local(
122
+ u,
123
+ placements=p.placements,
124
+ device_mesh=mesh,
125
+ )
126
+
127
+ state.scattered_u = u
128
+
129
+
130
+ class Muon(torch.optim.Optimizer):
131
+ """
132
+ Muon - MomentUm Orthogonalized by Newton-schulz
133
+
134
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
135
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
136
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
137
+ the advantage that it can be stably run in bfloat16 on the GPU.
138
+
139
+ Some warnings:
140
+ - We believe this optimizer is unlikely to work well for training with small batch size.
141
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
142
+
143
+ Arguments:
144
+ muon_params: The parameters to be optimized by Muon.
145
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
146
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
147
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
148
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
149
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
150
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
151
+ adamw_lr: The learning rate for the internal AdamW.
152
+ adamw_betas: The betas for the internal AdamW.
153
+ adamw_eps: The epsilon for the internal AdamW.
154
+ adamw_wd: The weight decay for the internal AdamW.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ model,
160
+ is_muon_func,
161
+ lr=1e-3,
162
+ momentum=0.95,
163
+ nesterov=True,
164
+ ns_steps=5,
165
+ adamw_wd=0.1,
166
+ adamw_betas=(0.9, 0.95),
167
+ adamw_eps=1e-8,
168
+ debug=False,
169
+ ):
170
+ defaults = dict(
171
+ lr=lr,
172
+ wd=adamw_wd,
173
+ momentum=momentum,
174
+ nesterov=nesterov,
175
+ ns_steps=ns_steps,
176
+ adamw_betas=adamw_betas,
177
+ adamw_eps=adamw_eps,
178
+ )
179
+
180
+ super().__init__(model.parameters(), defaults)
181
+ self.is_muon_func = is_muon_func
182
+ self.model = model
183
+
184
+ if not dist.is_initialized():
185
+ raise RuntimeError(
186
+ "Muon optimizer requires distributed training to be initialized."
187
+ )
188
+
189
+ self.rank = dist.get_rank()
190
+
191
+ self.comm_stream = torch.cuda.Stream()
192
+ self.compute_stream = torch.cuda.Stream()
193
+ self.debug = debug
194
+
195
+ def __setstate__(self, state):
196
+ # Sort parameters into those for which we will use Muon, and those for which we will not
197
+ super().__setstate__(state)
198
+ for name, p in self.model.named_parameters():
199
+ if self.is_muon_func(p, name):
200
+ # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
201
+ assert p.ndim == 2, p.ndim
202
+ self.state[p]["use_muon"] = True
203
+ self.state[p]["orig_shape"] = p.shape
204
+ else:
205
+ # Do not use Muon for parameters in adamw_params
206
+ self.state[p]["use_muon"] = False
207
+
208
+ def _calc_flops(self, G, steps):
209
+ assert len(G.shape) == 2
210
+ M, N = G.shape
211
+ if M > N:
212
+ M, N = N, M
213
+
214
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
215
+
216
+ def adjust_lr_for_muon(self, lr, param_shape):
217
+ A, B = param_shape[:2]
218
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
219
+ # as describted in the paper
220
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
221
+ adjusted_lr = lr * adjusted_ratio
222
+ return adjusted_lr
223
+
224
+ def init_state_and_assign_params(self, params, group):
225
+ param_to_state = {}
226
+ param_to_flops = {}
227
+
228
+ total_flops = 0
229
+ for p in params:
230
+ g = p.grad
231
+ if g is None:
232
+ continue
233
+ assert g.ndim == 2, "Muon only supports 2D parameters."
234
+
235
+ flops = self._calc_flops(g, group["ns_steps"])
236
+ param_to_flops[id(p)] = flops
237
+ total_flops += flops
238
+
239
+ if self.debug:
240
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
241
+
242
+ ordered_params = sorted(
243
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
244
+ )
245
+
246
+ round_robin = 0
247
+ mesh = None
248
+ for p in ordered_params:
249
+ if mesh is None:
250
+ mesh = p.device_mesh
251
+ if mesh.ndim != 1:
252
+ raise NotImplementedError(
253
+ "Muon requires a 1D mesh for distributed training yet."
254
+ )
255
+ elif mesh != p.device_mesh:
256
+ raise ValueError("All parameters must be on the same mesh.")
257
+
258
+ param_to_state[id(p)] = _muon_state()
259
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
260
+
261
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
262
+
263
+ return param_to_state, ordered_params
264
+
265
+ def base(self, params, group, lr, wd, momentum):
266
+ # generate weight updates in distributed fashion
267
+ for p in params:
268
+ g = p.grad
269
+ if g is None:
270
+ continue
271
+ if g.ndim > 2:
272
+ g = g.view(g.size(0), -1)
273
+ assert g is not None
274
+
275
+ # calc update
276
+ state = self.state[p]
277
+ if "momentum_buffer" not in state:
278
+ state["momentum_buffer"] = torch.zeros_like(g)
279
+ buf = state["momentum_buffer"]
280
+ buf.mul_(momentum).add_(g)
281
+ if group["nesterov"]:
282
+ g = g.add(buf, alpha=momentum)
283
+ else:
284
+ g = buf
285
+
286
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
287
+
288
+ # scale update
289
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
290
+
291
+ # apply weight decay
292
+ p.data.mul_(1 - lr * wd)
293
+
294
+ # apply update
295
+ p.data.add_(u, alpha=-adjusted_lr)
296
+
297
+ def _update_g(self, p, g, group, momentum):
298
+ # calc update
299
+ state = self.state[p]
300
+ if "momentum_buffer" not in state:
301
+ state["momentum_buffer"] = torch.zeros_like(g)
302
+ buf = state["momentum_buffer"]
303
+ buf.mul_(momentum).add_(g)
304
+ if group["nesterov"]:
305
+ g = g.add(buf, alpha=momentum)
306
+ else:
307
+ g = buf
308
+ return g
309
+
310
+ def _update_p(self, p, u, lr, wd):
311
+ # scale update
312
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
313
+ # apply weight decay
314
+ p.data.mul_(1 - lr * wd)
315
+ # apply update
316
+ p.data.add_(u, alpha=-adjusted_lr)
317
+
318
+ def parallel(self, params, group, lr, wd, momentum):
319
+ """
320
+ Perform a parallel optimization step using Muon.
321
+ """
322
+
323
+ for p in params:
324
+ g = p.grad
325
+ if g is None:
326
+ continue
327
+ if g.ndim > 2:
328
+ g = g.view(g.size(0), -1)
329
+
330
+ # Update g in the local rank
331
+ g = self._update_g(
332
+ p,
333
+ g,
334
+ group,
335
+ momentum=momentum,
336
+ )
337
+ p.grad = g
338
+
339
+ param_to_state, ordered_params = self.init_state_and_assign_params(
340
+ params, group
341
+ )
342
+
343
+ def enqueue_gathers(start_idx, chunk_size):
344
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
345
+ state = param_to_state[id(p)]
346
+ _gather(p, state, self.rank, self.comm_stream)
347
+
348
+ def enqueue_computes(start_idx, chunk_size):
349
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
350
+ state = param_to_state[id(p)]
351
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
352
+
353
+ def enqueue_scatters(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _scatter(p, state, self.rank, self.comm_stream)
357
+
358
+ chunk_size = params[0].device_mesh.mesh.numel()
359
+
360
+ # Wait grad update
361
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
362
+
363
+ enqueue_gathers(0, chunk_size)
364
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
365
+ enqueue_computes(i, chunk_size)
366
+ enqueue_gathers(i + chunk_size, chunk_size)
367
+ enqueue_scatters(i, chunk_size)
368
+
369
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
370
+
371
+ for p in params:
372
+ g = p.grad
373
+ if g is None:
374
+ continue
375
+
376
+ # Update p with sharded u
377
+ state = param_to_state[id(p)]
378
+ self._update_p(
379
+ p,
380
+ state.scattered_u,
381
+ lr=lr,
382
+ wd=wd,
383
+ )
384
+
385
+ def step(self, closure=None):
386
+ """Perform a single optimization step.
387
+
388
+ Args:
389
+ closure (Callable, optional): A closure that reevaluates the model
390
+ and returns the loss.
391
+ """
392
+ loss = None
393
+ if closure is not None:
394
+ with torch.enable_grad():
395
+ loss = closure()
396
+
397
+ for group in self.param_groups:
398
+ ############################
399
+ # Muon #
400
+ ############################
401
+
402
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
403
+ lr = group["lr"]
404
+ wd = group["wd"]
405
+ momentum = group["momentum"]
406
+
407
+ if isinstance(params[0].data, DTensor):
408
+ self.parallel(
409
+ params,
410
+ group,
411
+ lr=lr,
412
+ wd=wd,
413
+ momentum=momentum,
414
+ )
415
+ else:
416
+ self.base(
417
+ params,
418
+ group,
419
+ lr=lr,
420
+ wd=wd,
421
+ momentum=momentum,
422
+ )
423
+
424
+ ############################
425
+ # AdamW backup #
426
+ ############################
427
+
428
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
429
+ lr = group["lr"]
430
+ beta1, beta2 = group["adamw_betas"]
431
+ eps = group["adamw_eps"]
432
+ weight_decay = group["wd"]
433
+
434
+ for p in params:
435
+ g = p.grad
436
+ if g is None:
437
+ continue
438
+ state = self.state[p]
439
+ if "step" not in state:
440
+ state["step"] = 0
441
+ state["moment1"] = torch.zeros_like(g)
442
+ state["moment2"] = torch.zeros_like(g)
443
+ state["step"] += 1
444
+ step = state["step"]
445
+ buf1 = state["moment1"]
446
+ buf2 = state["moment2"]
447
+ buf1.lerp_(g, 1 - beta1)
448
+ buf2.lerp_(g.square(), 1 - beta2)
449
+
450
+ g = buf1 / (eps + buf2.sqrt())
451
+
452
+ bias_correction1 = 1 - beta1**step
453
+ bias_correction2 = 1 - beta2**step
454
+ scale = bias_correction1 / bias_correction2**0.5
455
+ p.data.mul_(1 - lr * weight_decay)
456
+ p.data.add_(g, alpha=-lr / scale)
457
+
458
+ return loss
build/torch26-cxx98-cu126-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _optimizer_b4b3752_dirty
3
+ ops = torch.ops._optimizer_b4b3752_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_b4b3752_dirty::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c3282a321487a6faa532afe43bc1298731983c50e2a1acdff5480ff6e4df34e
3
+ size 1824192
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5b49ed642e1c320da3932377033ad90031124f4ec24b2d1c95fd976ff28346c
3
+ size 1824184
build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # TODO leave original url and consider LICENSE
10
+ # This code snippet is a modified version adapted from the following GitHub repository:
11
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ scattered_u: torch.Tensor | None = None
52
+ gather_event: torch.cuda.Event | None = None
53
+ compute_event: torch.cuda.Event | None = None
54
+
55
+
56
+ def _gather(p, state, rank, comm_stream):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ # TODO: Consider ,,,
74
+ if state.gathered_grad is not None:
75
+ raise RuntimeError(
76
+ "Gather event already exists, which should not happen."
77
+ )
78
+ state.gathered_grad = torch.cat(gather_list, dim=0)
79
+ state.gather_event = torch.cuda.Event()
80
+ state.gather_event.record()
81
+ else:
82
+ state.gathered_grad = None
83
+ state.gather_event = None
84
+
85
+
86
+ def _compute_u(state, steps, rank, compute_stream):
87
+ with torch.cuda.stream(compute_stream):
88
+ if rank == state.worker_rank:
89
+ if state.gather_event is None:
90
+ raise RuntimeError("Gather event must be set before compute.")
91
+ compute_stream.wait_event(state.gather_event)
92
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
93
+ state.computed_u = u
94
+ state.compute_event = torch.cuda.Event()
95
+ state.compute_event.record()
96
+ else:
97
+ state.computed_u = None
98
+ state.compute_event = None
99
+
100
+
101
+ def _scatter(p, state, rank, comm_stream):
102
+ u = state.computed_u
103
+ mesh = p.device_mesh
104
+
105
+ with torch.cuda.stream(comm_stream):
106
+ if rank == state.worker_rank:
107
+ if state.compute_event is None:
108
+ raise RuntimeError("Compute event must be set before scatter.")
109
+ comm_stream.wait_event(state.compute_event)
110
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
111
+ else:
112
+ scatter_list = None
113
+
114
+ u = torch.empty_like(p.to_local())
115
+ torch.distributed.scatter(
116
+ u,
117
+ scatter_list=scatter_list,
118
+ src=state.worker_rank,
119
+ group=mesh.get_group(),
120
+ )
121
+ u = DTensor.from_local(
122
+ u,
123
+ placements=p.placements,
124
+ device_mesh=mesh,
125
+ )
126
+
127
+ state.scattered_u = u
128
+
129
+
130
+ class Muon(torch.optim.Optimizer):
131
+ """
132
+ Muon - MomentUm Orthogonalized by Newton-schulz
133
+
134
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
135
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
136
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
137
+ the advantage that it can be stably run in bfloat16 on the GPU.
138
+
139
+ Some warnings:
140
+ - We believe this optimizer is unlikely to work well for training with small batch size.
141
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
142
+
143
+ Arguments:
144
+ muon_params: The parameters to be optimized by Muon.
145
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
146
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
147
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
148
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
149
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
150
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
151
+ adamw_lr: The learning rate for the internal AdamW.
152
+ adamw_betas: The betas for the internal AdamW.
153
+ adamw_eps: The epsilon for the internal AdamW.
154
+ adamw_wd: The weight decay for the internal AdamW.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ model,
160
+ is_muon_func,
161
+ lr=1e-3,
162
+ momentum=0.95,
163
+ nesterov=True,
164
+ ns_steps=5,
165
+ adamw_wd=0.1,
166
+ adamw_betas=(0.9, 0.95),
167
+ adamw_eps=1e-8,
168
+ debug=False,
169
+ ):
170
+ defaults = dict(
171
+ lr=lr,
172
+ wd=adamw_wd,
173
+ momentum=momentum,
174
+ nesterov=nesterov,
175
+ ns_steps=ns_steps,
176
+ adamw_betas=adamw_betas,
177
+ adamw_eps=adamw_eps,
178
+ )
179
+
180
+ super().__init__(model.parameters(), defaults)
181
+ self.is_muon_func = is_muon_func
182
+ self.model = model
183
+
184
+ if not dist.is_initialized():
185
+ raise RuntimeError(
186
+ "Muon optimizer requires distributed training to be initialized."
187
+ )
188
+
189
+ self.rank = dist.get_rank()
190
+
191
+ self.comm_stream = torch.cuda.Stream()
192
+ self.compute_stream = torch.cuda.Stream()
193
+ self.debug = debug
194
+
195
+ def __setstate__(self, state):
196
+ # Sort parameters into those for which we will use Muon, and those for which we will not
197
+ super().__setstate__(state)
198
+ for name, p in self.model.named_parameters():
199
+ if self.is_muon_func(p, name):
200
+ # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
201
+ assert p.ndim == 2, p.ndim
202
+ self.state[p]["use_muon"] = True
203
+ self.state[p]["orig_shape"] = p.shape
204
+ else:
205
+ # Do not use Muon for parameters in adamw_params
206
+ self.state[p]["use_muon"] = False
207
+
208
+ def _calc_flops(self, G, steps):
209
+ assert len(G.shape) == 2
210
+ M, N = G.shape
211
+ if M > N:
212
+ M, N = N, M
213
+
214
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
215
+
216
+ def adjust_lr_for_muon(self, lr, param_shape):
217
+ A, B = param_shape[:2]
218
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
219
+ # as describted in the paper
220
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
221
+ adjusted_lr = lr * adjusted_ratio
222
+ return adjusted_lr
223
+
224
+ def init_state_and_assign_params(self, params, group):
225
+ param_to_state = {}
226
+ param_to_flops = {}
227
+
228
+ total_flops = 0
229
+ for p in params:
230
+ g = p.grad
231
+ if g is None:
232
+ continue
233
+ assert g.ndim == 2, "Muon only supports 2D parameters."
234
+
235
+ flops = self._calc_flops(g, group["ns_steps"])
236
+ param_to_flops[id(p)] = flops
237
+ total_flops += flops
238
+
239
+ if self.debug:
240
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
241
+
242
+ ordered_params = sorted(
243
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
244
+ )
245
+
246
+ round_robin = 0
247
+ mesh = None
248
+ for p in ordered_params:
249
+ if mesh is None:
250
+ mesh = p.device_mesh
251
+ if mesh.ndim != 1:
252
+ raise NotImplementedError(
253
+ "Muon requires a 1D mesh for distributed training yet."
254
+ )
255
+ elif mesh != p.device_mesh:
256
+ raise ValueError("All parameters must be on the same mesh.")
257
+
258
+ param_to_state[id(p)] = _muon_state()
259
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
260
+
261
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
262
+
263
+ return param_to_state, ordered_params
264
+
265
+ def base(self, params, group, lr, wd, momentum):
266
+ # generate weight updates in distributed fashion
267
+ for p in params:
268
+ g = p.grad
269
+ if g is None:
270
+ continue
271
+ if g.ndim > 2:
272
+ g = g.view(g.size(0), -1)
273
+ assert g is not None
274
+
275
+ # calc update
276
+ state = self.state[p]
277
+ if "momentum_buffer" not in state:
278
+ state["momentum_buffer"] = torch.zeros_like(g)
279
+ buf = state["momentum_buffer"]
280
+ buf.mul_(momentum).add_(g)
281
+ if group["nesterov"]:
282
+ g = g.add(buf, alpha=momentum)
283
+ else:
284
+ g = buf
285
+
286
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
287
+
288
+ # scale update
289
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
290
+
291
+ # apply weight decay
292
+ p.data.mul_(1 - lr * wd)
293
+
294
+ # apply update
295
+ p.data.add_(u, alpha=-adjusted_lr)
296
+
297
+ def _update_g(self, p, g, group, momentum):
298
+ # calc update
299
+ state = self.state[p]
300
+ if "momentum_buffer" not in state:
301
+ state["momentum_buffer"] = torch.zeros_like(g)
302
+ buf = state["momentum_buffer"]
303
+ buf.mul_(momentum).add_(g)
304
+ if group["nesterov"]:
305
+ g = g.add(buf, alpha=momentum)
306
+ else:
307
+ g = buf
308
+ return g
309
+
310
+ def _update_p(self, p, u, lr, wd):
311
+ # scale update
312
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
313
+ # apply weight decay
314
+ p.data.mul_(1 - lr * wd)
315
+ # apply update
316
+ p.data.add_(u, alpha=-adjusted_lr)
317
+
318
+ def parallel(self, params, group, lr, wd, momentum):
319
+ """
320
+ Perform a parallel optimization step using Muon.
321
+ """
322
+
323
+ for p in params:
324
+ g = p.grad
325
+ if g is None:
326
+ continue
327
+ if g.ndim > 2:
328
+ g = g.view(g.size(0), -1)
329
+
330
+ # Update g in the local rank
331
+ g = self._update_g(
332
+ p,
333
+ g,
334
+ group,
335
+ momentum=momentum,
336
+ )
337
+ p.grad = g
338
+
339
+ param_to_state, ordered_params = self.init_state_and_assign_params(
340
+ params, group
341
+ )
342
+
343
+ def enqueue_gathers(start_idx, chunk_size):
344
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
345
+ state = param_to_state[id(p)]
346
+ _gather(p, state, self.rank, self.comm_stream)
347
+
348
+ def enqueue_computes(start_idx, chunk_size):
349
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
350
+ state = param_to_state[id(p)]
351
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
352
+
353
+ def enqueue_scatters(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _scatter(p, state, self.rank, self.comm_stream)
357
+
358
+ chunk_size = params[0].device_mesh.mesh.numel()
359
+
360
+ # Wait grad update
361
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
362
+
363
+ enqueue_gathers(0, chunk_size)
364
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
365
+ enqueue_computes(i, chunk_size)
366
+ enqueue_gathers(i + chunk_size, chunk_size)
367
+ enqueue_scatters(i, chunk_size)
368
+
369
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
370
+
371
+ for p in params:
372
+ g = p.grad
373
+ if g is None:
374
+ continue
375
+
376
+ # Update p with sharded u
377
+ state = param_to_state[id(p)]
378
+ self._update_p(
379
+ p,
380
+ state.scattered_u,
381
+ lr=lr,
382
+ wd=wd,
383
+ )
384
+
385
+ def step(self, closure=None):
386
+ """Perform a single optimization step.
387
+
388
+ Args:
389
+ closure (Callable, optional): A closure that reevaluates the model
390
+ and returns the loss.
391
+ """
392
+ loss = None
393
+ if closure is not None:
394
+ with torch.enable_grad():
395
+ loss = closure()
396
+
397
+ for group in self.param_groups:
398
+ ############################
399
+ # Muon #
400
+ ############################
401
+
402
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
403
+ lr = group["lr"]
404
+ wd = group["wd"]
405
+ momentum = group["momentum"]
406
+
407
+ if isinstance(params[0].data, DTensor):
408
+ self.parallel(
409
+ params,
410
+ group,
411
+ lr=lr,
412
+ wd=wd,
413
+ momentum=momentum,
414
+ )
415
+ else:
416
+ self.base(
417
+ params,
418
+ group,
419
+ lr=lr,
420
+ wd=wd,
421
+ momentum=momentum,
422
+ )
423
+
424
+ ############################
425
+ # AdamW backup #
426
+ ############################
427
+
428
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
429
+ lr = group["lr"]
430
+ beta1, beta2 = group["adamw_betas"]
431
+ eps = group["adamw_eps"]
432
+ weight_decay = group["wd"]
433
+
434
+ for p in params:
435
+ g = p.grad
436
+ if g is None:
437
+ continue
438
+ state = self.state[p]
439
+ if "step" not in state:
440
+ state["step"] = 0
441
+ state["moment1"] = torch.zeros_like(g)
442
+ state["moment2"] = torch.zeros_like(g)
443
+ state["step"] += 1
444
+ step = state["step"]
445
+ buf1 = state["moment1"]
446
+ buf2 = state["moment2"]
447
+ buf1.lerp_(g, 1 - beta1)
448
+ buf2.lerp_(g.square(), 1 - beta2)
449
+
450
+ g = buf1 / (eps + buf2.sqrt())
451
+
452
+ bias_correction1 = 1 - beta1**step
453
+ bias_correction2 = 1 - beta2**step
454
+ scale = bias_correction1 / bias_correction2**0.5
455
+ p.data.mul_(1 - lr * weight_decay)
456
+ p.data.add_(g, alpha=-lr / scale)
457
+
458
+ return loss
build/torch27-cxx11-cu118-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _optimizer_b4b3752_dirty
3
+ ops = torch.ops._optimizer_b4b3752_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_b4b3752_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de82486a39ded94bfe7eeaa862459944a93e284fd0d919329979bb67db3c367f
3
+ size 1787376
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ac9027c4a93801e9f19f1e9e94a9ed33b27e92c72797053c3de55e2a6fbb41d
3
+ size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # TODO leave original url and consider LICENSE
10
+ # This code snippet is a modified version adapted from the following GitHub repository:
11
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ scattered_u: torch.Tensor | None = None
52
+ gather_event: torch.cuda.Event | None = None
53
+ compute_event: torch.cuda.Event | None = None
54
+
55
+
56
+ def _gather(p, state, rank, comm_stream):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ # TODO: Consider ,,,
74
+ if state.gathered_grad is not None:
75
+ raise RuntimeError(
76
+ "Gather event already exists, which should not happen."
77
+ )
78
+ state.gathered_grad = torch.cat(gather_list, dim=0)
79
+ state.gather_event = torch.cuda.Event()
80
+ state.gather_event.record()
81
+ else:
82
+ state.gathered_grad = None
83
+ state.gather_event = None
84
+
85
+
86
+ def _compute_u(state, steps, rank, compute_stream):
87
+ with torch.cuda.stream(compute_stream):
88
+ if rank == state.worker_rank:
89
+ if state.gather_event is None:
90
+ raise RuntimeError("Gather event must be set before compute.")
91
+ compute_stream.wait_event(state.gather_event)
92
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
93
+ state.computed_u = u
94
+ state.compute_event = torch.cuda.Event()
95
+ state.compute_event.record()
96
+ else:
97
+ state.computed_u = None
98
+ state.compute_event = None
99
+
100
+
101
+ def _scatter(p, state, rank, comm_stream):
102
+ u = state.computed_u
103
+ mesh = p.device_mesh
104
+
105
+ with torch.cuda.stream(comm_stream):
106
+ if rank == state.worker_rank:
107
+ if state.compute_event is None:
108
+ raise RuntimeError("Compute event must be set before scatter.")
109
+ comm_stream.wait_event(state.compute_event)
110
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
111
+ else:
112
+ scatter_list = None
113
+
114
+ u = torch.empty_like(p.to_local())
115
+ torch.distributed.scatter(
116
+ u,
117
+ scatter_list=scatter_list,
118
+ src=state.worker_rank,
119
+ group=mesh.get_group(),
120
+ )
121
+ u = DTensor.from_local(
122
+ u,
123
+ placements=p.placements,
124
+ device_mesh=mesh,
125
+ )
126
+
127
+ state.scattered_u = u
128
+
129
+
130
+ class Muon(torch.optim.Optimizer):
131
+ """
132
+ Muon - MomentUm Orthogonalized by Newton-schulz
133
+
134
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
135
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
136
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
137
+ the advantage that it can be stably run in bfloat16 on the GPU.
138
+
139
+ Some warnings:
140
+ - We believe this optimizer is unlikely to work well for training with small batch size.
141
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
142
+
143
+ Arguments:
144
+ muon_params: The parameters to be optimized by Muon.
145
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
146
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
147
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
148
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
149
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
150
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
151
+ adamw_lr: The learning rate for the internal AdamW.
152
+ adamw_betas: The betas for the internal AdamW.
153
+ adamw_eps: The epsilon for the internal AdamW.
154
+ adamw_wd: The weight decay for the internal AdamW.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ model,
160
+ is_muon_func,
161
+ lr=1e-3,
162
+ momentum=0.95,
163
+ nesterov=True,
164
+ ns_steps=5,
165
+ adamw_wd=0.1,
166
+ adamw_betas=(0.9, 0.95),
167
+ adamw_eps=1e-8,
168
+ debug=False,
169
+ ):
170
+ defaults = dict(
171
+ lr=lr,
172
+ wd=adamw_wd,
173
+ momentum=momentum,
174
+ nesterov=nesterov,
175
+ ns_steps=ns_steps,
176
+ adamw_betas=adamw_betas,
177
+ adamw_eps=adamw_eps,
178
+ )
179
+
180
+ super().__init__(model.parameters(), defaults)
181
+ self.is_muon_func = is_muon_func
182
+ self.model = model
183
+
184
+ if not dist.is_initialized():
185
+ raise RuntimeError(
186
+ "Muon optimizer requires distributed training to be initialized."
187
+ )
188
+
189
+ self.rank = dist.get_rank()
190
+
191
+ self.comm_stream = torch.cuda.Stream()
192
+ self.compute_stream = torch.cuda.Stream()
193
+ self.debug = debug
194
+
195
+ def __setstate__(self, state):
196
+ # Sort parameters into those for which we will use Muon, and those for which we will not
197
+ super().__setstate__(state)
198
+ for name, p in self.model.named_parameters():
199
+ if self.is_muon_func(p, name):
200
+ # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
201
+ assert p.ndim == 2, p.ndim
202
+ self.state[p]["use_muon"] = True
203
+ self.state[p]["orig_shape"] = p.shape
204
+ else:
205
+ # Do not use Muon for parameters in adamw_params
206
+ self.state[p]["use_muon"] = False
207
+
208
+ def _calc_flops(self, G, steps):
209
+ assert len(G.shape) == 2
210
+ M, N = G.shape
211
+ if M > N:
212
+ M, N = N, M
213
+
214
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
215
+
216
+ def adjust_lr_for_muon(self, lr, param_shape):
217
+ A, B = param_shape[:2]
218
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
219
+ # as describted in the paper
220
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
221
+ adjusted_lr = lr * adjusted_ratio
222
+ return adjusted_lr
223
+
224
+ def init_state_and_assign_params(self, params, group):
225
+ param_to_state = {}
226
+ param_to_flops = {}
227
+
228
+ total_flops = 0
229
+ for p in params:
230
+ g = p.grad
231
+ if g is None:
232
+ continue
233
+ assert g.ndim == 2, "Muon only supports 2D parameters."
234
+
235
+ flops = self._calc_flops(g, group["ns_steps"])
236
+ param_to_flops[id(p)] = flops
237
+ total_flops += flops
238
+
239
+ if self.debug:
240
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
241
+
242
+ ordered_params = sorted(
243
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
244
+ )
245
+
246
+ round_robin = 0
247
+ mesh = None
248
+ for p in ordered_params:
249
+ if mesh is None:
250
+ mesh = p.device_mesh
251
+ if mesh.ndim != 1:
252
+ raise NotImplementedError(
253
+ "Muon requires a 1D mesh for distributed training yet."
254
+ )
255
+ elif mesh != p.device_mesh:
256
+ raise ValueError("All parameters must be on the same mesh.")
257
+
258
+ param_to_state[id(p)] = _muon_state()
259
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
260
+
261
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
262
+
263
+ return param_to_state, ordered_params
264
+
265
+ def base(self, params, group, lr, wd, momentum):
266
+ # generate weight updates in distributed fashion
267
+ for p in params:
268
+ g = p.grad
269
+ if g is None:
270
+ continue
271
+ if g.ndim > 2:
272
+ g = g.view(g.size(0), -1)
273
+ assert g is not None
274
+
275
+ # calc update
276
+ state = self.state[p]
277
+ if "momentum_buffer" not in state:
278
+ state["momentum_buffer"] = torch.zeros_like(g)
279
+ buf = state["momentum_buffer"]
280
+ buf.mul_(momentum).add_(g)
281
+ if group["nesterov"]:
282
+ g = g.add(buf, alpha=momentum)
283
+ else:
284
+ g = buf
285
+
286
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
287
+
288
+ # scale update
289
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
290
+
291
+ # apply weight decay
292
+ p.data.mul_(1 - lr * wd)
293
+
294
+ # apply update
295
+ p.data.add_(u, alpha=-adjusted_lr)
296
+
297
+ def _update_g(self, p, g, group, momentum):
298
+ # calc update
299
+ state = self.state[p]
300
+ if "momentum_buffer" not in state:
301
+ state["momentum_buffer"] = torch.zeros_like(g)
302
+ buf = state["momentum_buffer"]
303
+ buf.mul_(momentum).add_(g)
304
+ if group["nesterov"]:
305
+ g = g.add(buf, alpha=momentum)
306
+ else:
307
+ g = buf
308
+ return g
309
+
310
+ def _update_p(self, p, u, lr, wd):
311
+ # scale update
312
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
313
+ # apply weight decay
314
+ p.data.mul_(1 - lr * wd)
315
+ # apply update
316
+ p.data.add_(u, alpha=-adjusted_lr)
317
+
318
+ def parallel(self, params, group, lr, wd, momentum):
319
+ """
320
+ Perform a parallel optimization step using Muon.
321
+ """
322
+
323
+ for p in params:
324
+ g = p.grad
325
+ if g is None:
326
+ continue
327
+ if g.ndim > 2:
328
+ g = g.view(g.size(0), -1)
329
+
330
+ # Update g in the local rank
331
+ g = self._update_g(
332
+ p,
333
+ g,
334
+ group,
335
+ momentum=momentum,
336
+ )
337
+ p.grad = g
338
+
339
+ param_to_state, ordered_params = self.init_state_and_assign_params(
340
+ params, group
341
+ )
342
+
343
+ def enqueue_gathers(start_idx, chunk_size):
344
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
345
+ state = param_to_state[id(p)]
346
+ _gather(p, state, self.rank, self.comm_stream)
347
+
348
+ def enqueue_computes(start_idx, chunk_size):
349
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
350
+ state = param_to_state[id(p)]
351
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
352
+
353
+ def enqueue_scatters(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _scatter(p, state, self.rank, self.comm_stream)
357
+
358
+ chunk_size = params[0].device_mesh.mesh.numel()
359
+
360
+ # Wait grad update
361
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
362
+
363
+ enqueue_gathers(0, chunk_size)
364
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
365
+ enqueue_computes(i, chunk_size)
366
+ enqueue_gathers(i + chunk_size, chunk_size)
367
+ enqueue_scatters(i, chunk_size)
368
+
369
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
370
+
371
+ for p in params:
372
+ g = p.grad
373
+ if g is None:
374
+ continue
375
+
376
+ # Update p with sharded u
377
+ state = param_to_state[id(p)]
378
+ self._update_p(
379
+ p,
380
+ state.scattered_u,
381
+ lr=lr,
382
+ wd=wd,
383
+ )
384
+
385
+ def step(self, closure=None):
386
+ """Perform a single optimization step.
387
+
388
+ Args:
389
+ closure (Callable, optional): A closure that reevaluates the model
390
+ and returns the loss.
391
+ """
392
+ loss = None
393
+ if closure is not None:
394
+ with torch.enable_grad():
395
+ loss = closure()
396
+
397
+ for group in self.param_groups:
398
+ ############################
399
+ # Muon #
400
+ ############################
401
+
402
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
403
+ lr = group["lr"]
404
+ wd = group["wd"]
405
+ momentum = group["momentum"]
406
+
407
+ if isinstance(params[0].data, DTensor):
408
+ self.parallel(
409
+ params,
410
+ group,
411
+ lr=lr,
412
+ wd=wd,
413
+ momentum=momentum,
414
+ )
415
+ else:
416
+ self.base(
417
+ params,
418
+ group,
419
+ lr=lr,
420
+ wd=wd,
421
+ momentum=momentum,
422
+ )
423
+
424
+ ############################
425
+ # AdamW backup #
426
+ ############################
427
+
428
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
429
+ lr = group["lr"]
430
+ beta1, beta2 = group["adamw_betas"]
431
+ eps = group["adamw_eps"]
432
+ weight_decay = group["wd"]
433
+
434
+ for p in params:
435
+ g = p.grad
436
+ if g is None:
437
+ continue
438
+ state = self.state[p]
439
+ if "step" not in state:
440
+ state["step"] = 0
441
+ state["moment1"] = torch.zeros_like(g)
442
+ state["moment2"] = torch.zeros_like(g)
443
+ state["step"] += 1
444
+ step = state["step"]
445
+ buf1 = state["moment1"]
446
+ buf2 = state["moment2"]
447
+ buf1.lerp_(g, 1 - beta1)
448
+ buf2.lerp_(g.square(), 1 - beta2)
449
+
450
+ g = buf1 / (eps + buf2.sqrt())
451
+
452
+ bias_correction1 = 1 - beta1**step
453
+ bias_correction2 = 1 - beta2**step
454
+ scale = bias_correction1 / bias_correction2**0.5
455
+ p.data.mul_(1 - lr * weight_decay)
456
+ p.data.add_(g, alpha=-lr / scale)
457
+
458
+ return loss
build/torch27-cxx11-cu126-x86_64-linux/optimizer/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .muon import Muon
2
+
3
+ __all__ = [
4
+ "Muon",
5
+ ]
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _optimizer_b4b3752_dirty
3
+ ops = torch.ops._optimizer_b4b3752_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_optimizer_b4b3752_dirty::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_20250614125054.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb02d3818a89c819a5a12d066ce56da0ebc4f3da491cb045ae380c5b9319e592
3
+ size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_optimizer_b4b3752_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b425a7fd854402508da5af17fa88f305753a09474686d6ec7afe540b3c5c082e
3
+ size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor
7
+
8
+
9
+ # TODO leave original url and consider LICENSE
10
+ # This code snippet is a modified version adapted from the following GitHub repository:
11
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
12
+ def _zeropower_via_newtonschulz5(G, steps):
13
+ """
14
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
15
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
16
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
17
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
18
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
19
+ where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
20
+ performance at all relative to UV^T, where USV^T = G is the SVD.
21
+ """
22
+ assert len(G.shape) == 2
23
+ a, b, c = (3.4445, -4.7750, 2.0315)
24
+ X = G # no manual typecast
25
+ if G.size(0) > G.size(1):
26
+ X = X.T
27
+ # Ensure spectral norm is at most 1
28
+ X = X / (X.norm() + 1e-7)
29
+ X = X.bfloat16()
30
+ # Perform the NS iterations
31
+ for _ in range(steps):
32
+ A = X @ X.T
33
+ # B = (
34
+ # b * A + c * A @ A
35
+ # ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
36
+ B = torch.addmm(A, A, A, alpha=c, beta=b)
37
+ # X = a * X + B @ X
38
+ X = torch.addmm(X, B, X, alpha=1.0, beta=a)
39
+
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ return X.to(G.dtype)
43
+
44
+
45
+ @dataclass
46
+ class _muon_state:
47
+ # TODO: use Optional
48
+ worker_rank: int | None = None
49
+ gathered_grad: torch.Tensor | None = None
50
+ computed_u: torch.Tensor | None = None
51
+ scattered_u: torch.Tensor | None = None
52
+ gather_event: torch.cuda.Event | None = None
53
+ compute_event: torch.cuda.Event | None = None
54
+
55
+
56
+ def _gather(p, state, rank, comm_stream):
57
+ g = p.grad
58
+ mesh = g.device_mesh
59
+
60
+ if rank == state.worker_rank:
61
+ gather_list = [torch.empty_like(g.to_local()) for _ in range(mesh.mesh.numel())]
62
+ else:
63
+ gather_list = None
64
+
65
+ with torch.cuda.stream(comm_stream):
66
+ torch.distributed.gather(
67
+ g.to_local(),
68
+ dst=state.worker_rank,
69
+ gather_list=gather_list,
70
+ group=mesh.get_group(),
71
+ )
72
+ if rank == state.worker_rank:
73
+ # TODO: Consider ,,,
74
+ if state.gathered_grad is not None:
75
+ raise RuntimeError(
76
+ "Gather event already exists, which should not happen."
77
+ )
78
+ state.gathered_grad = torch.cat(gather_list, dim=0)
79
+ state.gather_event = torch.cuda.Event()
80
+ state.gather_event.record()
81
+ else:
82
+ state.gathered_grad = None
83
+ state.gather_event = None
84
+
85
+
86
+ def _compute_u(state, steps, rank, compute_stream):
87
+ with torch.cuda.stream(compute_stream):
88
+ if rank == state.worker_rank:
89
+ if state.gather_event is None:
90
+ raise RuntimeError("Gather event must be set before compute.")
91
+ compute_stream.wait_event(state.gather_event)
92
+ u = _zeropower_via_newtonschulz5(state.gathered_grad, steps)
93
+ state.computed_u = u
94
+ state.compute_event = torch.cuda.Event()
95
+ state.compute_event.record()
96
+ else:
97
+ state.computed_u = None
98
+ state.compute_event = None
99
+
100
+
101
+ def _scatter(p, state, rank, comm_stream):
102
+ u = state.computed_u
103
+ mesh = p.device_mesh
104
+
105
+ with torch.cuda.stream(comm_stream):
106
+ if rank == state.worker_rank:
107
+ if state.compute_event is None:
108
+ raise RuntimeError("Compute event must be set before scatter.")
109
+ comm_stream.wait_event(state.compute_event)
110
+ scatter_list = list(torch.split(u, p.size(0) // mesh.mesh.numel(), dim=0))
111
+ else:
112
+ scatter_list = None
113
+
114
+ u = torch.empty_like(p.to_local())
115
+ torch.distributed.scatter(
116
+ u,
117
+ scatter_list=scatter_list,
118
+ src=state.worker_rank,
119
+ group=mesh.get_group(),
120
+ )
121
+ u = DTensor.from_local(
122
+ u,
123
+ placements=p.placements,
124
+ device_mesh=mesh,
125
+ )
126
+
127
+ state.scattered_u = u
128
+
129
+
130
+ class Muon(torch.optim.Optimizer):
131
+ """
132
+ Muon - MomentUm Orthogonalized by Newton-schulz
133
+
134
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
135
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
136
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
137
+ the advantage that it can be stably run in bfloat16 on the GPU.
138
+
139
+ Some warnings:
140
+ - We believe this optimizer is unlikely to work well for training with small batch size.
141
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
142
+
143
+ Arguments:
144
+ muon_params: The parameters to be optimized by Muon.
145
+ lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
146
+ momentum: The momentum used by the internal SGD. (0.95 is a good default)
147
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
148
+ ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
149
+ adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are
150
+ {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well.
151
+ adamw_lr: The learning rate for the internal AdamW.
152
+ adamw_betas: The betas for the internal AdamW.
153
+ adamw_eps: The epsilon for the internal AdamW.
154
+ adamw_wd: The weight decay for the internal AdamW.
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ model,
160
+ is_muon_func,
161
+ lr=1e-3,
162
+ momentum=0.95,
163
+ nesterov=True,
164
+ ns_steps=5,
165
+ adamw_wd=0.1,
166
+ adamw_betas=(0.9, 0.95),
167
+ adamw_eps=1e-8,
168
+ debug=False,
169
+ ):
170
+ defaults = dict(
171
+ lr=lr,
172
+ wd=adamw_wd,
173
+ momentum=momentum,
174
+ nesterov=nesterov,
175
+ ns_steps=ns_steps,
176
+ adamw_betas=adamw_betas,
177
+ adamw_eps=adamw_eps,
178
+ )
179
+
180
+ super().__init__(model.parameters(), defaults)
181
+ self.is_muon_func = is_muon_func
182
+ self.model = model
183
+
184
+ if not dist.is_initialized():
185
+ raise RuntimeError(
186
+ "Muon optimizer requires distributed training to be initialized."
187
+ )
188
+
189
+ self.rank = dist.get_rank()
190
+
191
+ self.comm_stream = torch.cuda.Stream()
192
+ self.compute_stream = torch.cuda.Stream()
193
+ self.debug = debug
194
+
195
+ def __setstate__(self, state):
196
+ # Sort parameters into those for which we will use Muon, and those for which we will not
197
+ super().__setstate__(state)
198
+ for name, p in self.model.named_parameters():
199
+ if self.is_muon_func(p, name):
200
+ # Use Muon for every parameter in muon_params which is >= 2D and doesn't look like an embedding or head layer
201
+ assert p.ndim == 2, p.ndim
202
+ self.state[p]["use_muon"] = True
203
+ self.state[p]["orig_shape"] = p.shape
204
+ else:
205
+ # Do not use Muon for parameters in adamw_params
206
+ self.state[p]["use_muon"] = False
207
+
208
+ def _calc_flops(self, G, steps):
209
+ assert len(G.shape) == 2
210
+ M, N = G.shape
211
+ if M > N:
212
+ M, N = N, M
213
+
214
+ return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)
215
+
216
+ def adjust_lr_for_muon(self, lr, param_shape):
217
+ A, B = param_shape[:2]
218
+ # We adjust the learning rate and weight decay based on the size of the parameter matrix
219
+ # as describted in the paper
220
+ adjusted_ratio = 0.2 * math.sqrt(max(A, B))
221
+ adjusted_lr = lr * adjusted_ratio
222
+ return adjusted_lr
223
+
224
+ def init_state_and_assign_params(self, params, group):
225
+ param_to_state = {}
226
+ param_to_flops = {}
227
+
228
+ total_flops = 0
229
+ for p in params:
230
+ g = p.grad
231
+ if g is None:
232
+ continue
233
+ assert g.ndim == 2, "Muon only supports 2D parameters."
234
+
235
+ flops = self._calc_flops(g, group["ns_steps"])
236
+ param_to_flops[id(p)] = flops
237
+ total_flops += flops
238
+
239
+ if self.debug:
240
+ print(f"Total TFLOPs for Muon: {total_flops / 1e12:.2f} TFLOPs", flush=True)
241
+
242
+ ordered_params = sorted(
243
+ params, key=lambda p: param_to_flops[id(p)], reverse=True
244
+ )
245
+
246
+ round_robin = 0
247
+ mesh = None
248
+ for p in ordered_params:
249
+ if mesh is None:
250
+ mesh = p.device_mesh
251
+ if mesh.ndim != 1:
252
+ raise NotImplementedError(
253
+ "Muon requires a 1D mesh for distributed training yet."
254
+ )
255
+ elif mesh != p.device_mesh:
256
+ raise ValueError("All parameters must be on the same mesh.")
257
+
258
+ param_to_state[id(p)] = _muon_state()
259
+ param_to_state[id(p)].worker_rank = mesh.mesh[round_robin].item()
260
+
261
+ round_robin = (round_robin + 1) % mesh.mesh.numel()
262
+
263
+ return param_to_state, ordered_params
264
+
265
+ def base(self, params, group, lr, wd, momentum):
266
+ # generate weight updates in distributed fashion
267
+ for p in params:
268
+ g = p.grad
269
+ if g is None:
270
+ continue
271
+ if g.ndim > 2:
272
+ g = g.view(g.size(0), -1)
273
+ assert g is not None
274
+
275
+ # calc update
276
+ state = self.state[p]
277
+ if "momentum_buffer" not in state:
278
+ state["momentum_buffer"] = torch.zeros_like(g)
279
+ buf = state["momentum_buffer"]
280
+ buf.mul_(momentum).add_(g)
281
+ if group["nesterov"]:
282
+ g = g.add(buf, alpha=momentum)
283
+ else:
284
+ g = buf
285
+
286
+ u = _zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
287
+
288
+ # scale update
289
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
290
+
291
+ # apply weight decay
292
+ p.data.mul_(1 - lr * wd)
293
+
294
+ # apply update
295
+ p.data.add_(u, alpha=-adjusted_lr)
296
+
297
+ def _update_g(self, p, g, group, momentum):
298
+ # calc update
299
+ state = self.state[p]
300
+ if "momentum_buffer" not in state:
301
+ state["momentum_buffer"] = torch.zeros_like(g)
302
+ buf = state["momentum_buffer"]
303
+ buf.mul_(momentum).add_(g)
304
+ if group["nesterov"]:
305
+ g = g.add(buf, alpha=momentum)
306
+ else:
307
+ g = buf
308
+ return g
309
+
310
+ def _update_p(self, p, u, lr, wd):
311
+ # scale update
312
+ adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
313
+ # apply weight decay
314
+ p.data.mul_(1 - lr * wd)
315
+ # apply update
316
+ p.data.add_(u, alpha=-adjusted_lr)
317
+
318
+ def parallel(self, params, group, lr, wd, momentum):
319
+ """
320
+ Perform a parallel optimization step using Muon.
321
+ """
322
+
323
+ for p in params:
324
+ g = p.grad
325
+ if g is None:
326
+ continue
327
+ if g.ndim > 2:
328
+ g = g.view(g.size(0), -1)
329
+
330
+ # Update g in the local rank
331
+ g = self._update_g(
332
+ p,
333
+ g,
334
+ group,
335
+ momentum=momentum,
336
+ )
337
+ p.grad = g
338
+
339
+ param_to_state, ordered_params = self.init_state_and_assign_params(
340
+ params, group
341
+ )
342
+
343
+ def enqueue_gathers(start_idx, chunk_size):
344
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
345
+ state = param_to_state[id(p)]
346
+ _gather(p, state, self.rank, self.comm_stream)
347
+
348
+ def enqueue_computes(start_idx, chunk_size):
349
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
350
+ state = param_to_state[id(p)]
351
+ _compute_u(state, group["ns_steps"], self.rank, self.compute_stream)
352
+
353
+ def enqueue_scatters(start_idx, chunk_size):
354
+ for p in ordered_params[start_idx : start_idx + chunk_size]:
355
+ state = param_to_state[id(p)]
356
+ _scatter(p, state, self.rank, self.comm_stream)
357
+
358
+ chunk_size = params[0].device_mesh.mesh.numel()
359
+
360
+ # Wait grad update
361
+ self.comm_stream.wait_stream(torch.cuda.current_stream())
362
+
363
+ enqueue_gathers(0, chunk_size)
364
+ for i in range(0, len(params) + chunk_size - 1, chunk_size):
365
+ enqueue_computes(i, chunk_size)
366
+ enqueue_gathers(i + chunk_size, chunk_size)
367
+ enqueue_scatters(i, chunk_size)
368
+
369
+ torch.cuda.current_stream().wait_stream(self.comm_stream)
370
+
371
+ for p in params:
372
+ g = p.grad
373
+ if g is None:
374
+ continue
375
+
376
+ # Update p with sharded u
377
+ state = param_to_state[id(p)]
378
+ self._update_p(
379
+ p,
380
+ state.scattered_u,
381
+ lr=lr,
382
+ wd=wd,
383
+ )
384
+
385
+ def step(self, closure=None):
386
+ """Perform a single optimization step.
387
+
388
+ Args:
389
+ closure (Callable, optional): A closure that reevaluates the model
390
+ and returns the loss.
391
+ """
392
+ loss = None
393
+ if closure is not None:
394
+ with torch.enable_grad():
395
+ loss = closure()
396
+
397
+ for group in self.param_groups:
398
+ ############################
399
+ # Muon #
400
+ ############################
401
+
402
+ params = [p for p in group["params"] if self.state[p]["use_muon"]]
403
+ lr = group["lr"]
404
+ wd = group["wd"]
405
+ momentum = group["momentum"]
406
+
407
+ if isinstance(params[0].data, DTensor):
408
+ self.parallel(
409
+ params,
410
+ group,
411
+ lr=lr,
412
+ wd=wd,
413
+ momentum=momentum,
414
+ )
415
+ else:
416
+ self.base(
417
+ params,
418
+ group,
419
+ lr=lr,
420
+ wd=wd,
421
+ momentum=momentum,
422
+ )
423
+
424
+ ############################
425
+ # AdamW backup #
426
+ ############################
427
+
428
+ params = [p for p in group["params"] if not self.state[p]["use_muon"]]
429
+ lr = group["lr"]
430
+ beta1, beta2 = group["adamw_betas"]
431
+ eps = group["adamw_eps"]
432
+ weight_decay = group["wd"]
433
+
434
+ for p in params:
435
+ g = p.grad
436
+ if g is None:
437
+ continue
438
+ state = self.state[p]
439
+ if "step" not in state:
440
+ state["step"] = 0
441
+ state["moment1"] = torch.zeros_like(g)
442
+ state["moment2"] = torch.zeros_like(g)
443
+ state["step"] += 1
444
+ step = state["step"]
445
+ buf1 = state["moment1"]
446
+ buf2 = state["moment2"]
447
+ buf1.lerp_(g, 1 - beta1)
448
+ buf2.lerp_(g.square(), 1 - beta2)
449
+
450
+ g = buf1 / (eps + buf2.sqrt())
451
+
452
+ bias_correction1 = 1 - beta1**step
453
+ bias_correction2 = 1 - beta2**step
454
+ scale = bias_correction1 / bias_correction2**0.5
455
+ p.data.mul_(1 - lr * weight_decay)
456
+ p.data.add_(g, alpha=-lr / scale)
457
+
458
+ return loss