iamwyldecat committed on
Commit 02ac540 · 1 Parent(s): 64757cb

refactor(muon): change argument adam_wd to weight_decay and handle params' type

Files changed (36)
  1. build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  2. build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  3. build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +52 -21
  4. build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
  5. build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  6. build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +52 -21
  7. build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  8. build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  9. build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +52 -21
  10. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
  11. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  12. build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +52 -21
  13. build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  14. build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  15. build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +52 -21
  16. build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
  17. build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  18. build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +52 -21
  19. build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  20. build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  21. build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +52 -21
  22. build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
  23. build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  24. build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +52 -21
  25. build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
  26. build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  27. build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +52 -21
  28. build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
  29. build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  30. build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +52 -21
  31. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-312.pyc +0 -0
  32. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-312.pyc +0 -0
  33. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
  34. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
  35. build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +52 -21
  36. torch-ext/optimizer/muon.py +52 -21
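
The user-facing effect of this commit is the Muon constructor: the decay argument is renamed from adamw_wd to weight_decay (stored under the weight_decay key in the param-group defaults and applied by both the Muon and the internal AdamW update paths), and is_muon_func now defaults to the new default_is_muon helper. A minimal usage sketch under assumptions: the built package is importable as optimizer and the model here is a stand-in; this is not a verified end-to-end example.

import torch.nn as nn
from optimizer import Muon  # assumed import path for the built extension package

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 10))

# Before this commit: Muon(model, is_muon_func, ..., adamw_wd=0.1, ...)
# After this commit: is_muon_func is optional and the decay is `weight_decay`.
opt = Muon(
    model,
    lr=1e-3,
    momentum=0.95,
    weight_decay=0.1,        # formerly adamw_wd=0.1
    adamw_betas=(0.9, 0.95),
    adamw_eps=1e-8,
)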
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_036642a_dirty
-ops = torch.ops._optimizer_036642a_dirty
+from . import _optimizer_64757cb_dirty
+ops = torch.ops._optimizer_64757cb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_036642a_dirty::{op_name}"
+    return f"_optimizer_64757cb_dirty::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9c77e5647b6056bfaee25050cca7948c40859db0a88fa4fcf40b67a85c947d8c
+oid sha256:f1f5df341112d93e43c0801e285abd66e79bfbe399d228f8be09ff26ece7421b
 size 1787272
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import DTensor, Replicate
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
 
 
 @torch.no_grad()
-def _scatter(p, state, lr, wd, rank, comm_stream):
+def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
         placements=p.placements,
         device_mesh=mesh,
     )
-    p.data.mul_(1 - lr * wd)
+    p.data.mul_(1 - lr * weight_decay)
    p.data.add_(u, alpha=-lr)
 
 
+def default_is_muon(x, name):
+    return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
+
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
         adamw_lr: The learning rate for the internal AdamW.
         adamw_betas: The betas for the internal AdamW.
         adamw_eps: The epsilon for the internal AdamW.
-        adamw_wd: The weight decay for the internal AdamW.
+        adamw_weight_decay: The weight decay for the internal AdamW.
     """
 
     def __init__(
         self,
         model,
-        is_muon_func,
+        is_muon_func=default_is_muon,
         lr=1e-3,
         momentum=0.95,
         nesterov=True,
         ns_steps=5,
-        adamw_wd=0.1,
+        weight_decay=0.1,
         adamw_betas=(0.9, 0.95),
         adamw_eps=1e-8,
         none_grad=True,
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
     ):
         defaults = dict(
             lr=lr,
-            wd=adamw_wd,
+            weight_decay=weight_decay,
             momentum=momentum,
             nesterov=nesterov,
             ns_steps=ns_steps,
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
 
         return param_to_state, ordered_params
 
-    def base(self, params, group, lr, wd, momentum):
+    def base(self, params, group, lr, weight_decay, momentum):
         # generate weight updates in distributed fashion
         for p in params:
             g = p.grad
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
             adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
             # apply weight decay
-            p.data.mul_(1 - lr * wd)
+            p.data.mul_(1 - lr * weight_decay)
 
             # apply update
             p.data.add_(u, alpha=-adjusted_lr)
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
             g = buf
         return g
 
-    def _update_p(self, p, u, lr, wd):
+    def _update_p(self, p, u, lr, weight_decay):
         # scale update
         adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
         # apply weight decay
-        p.data.mul_(1 - lr * wd)
+        p.data.mul_(1 - lr * weight_decay)
         # apply update
         p.data.add_(u, alpha=-adjusted_lr)
 
-    def parallel(self, params, group, lr, wd, momentum):
+    def parallel(self, params, group, lr, weight_decay, momentum):
         """
         Perform a parallel optimization step using Muon.
         """
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
                 adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
+                _scatter(
+                    p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
+                )
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
 
             params = [p for p in group["params"] if self.state[p]["use_muon"]]
             lr = group["lr"]
-            wd = group["wd"]
+            weight_decay = group["weight_decay"]
             momentum = group["momentum"]
 
-            if isinstance(params[0].data, DTensor):
+            param_dtensors = []
+            param_tensors = []
+
+            for p in params:
+                if p is None or p.grad is None:
+                    continue
+                if isinstance(p.data, DTensor):
+                    if all(
+                        isinstance(placement, Replicate) for placement in p.placements
+                    ):
+                        param_tensors.append(p)
+                    else:
+                        param_dtensors.append(p)
+                elif isinstance(p.data, torch.Tensor):
+                    param_tensors.append(p)
+                else:
+                    raise TypeError(f"Unsupported parameter type: {type(p.data)}")
+
+            if self.debug:
+                print(
+                    f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
+                    flush=True,
+                )
+
+            if len(param_dtensors) > 0:
                 self.parallel(
-                    params,
+                    param_dtensors,
                     group,
                     lr=lr,
-                    wd=wd,
+                    weight_decay=weight_decay,
                     momentum=momentum,
                 )
-            else:
+
+            if len(param_tensors) > 0:
                 self.base(
-                    params,
+                    param_tensors,
                     group,
                     lr=lr,
-                    wd=wd,
+                    weight_decay=weight_decay,
                     momentum=momentum,
                 )
 
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
             lr = group["lr"]
             beta1, beta2 = group["adamw_betas"]
             eps = group["adamw_eps"]
-            weight_decay = group["wd"]
+            weight_decay = group["weight_decay"]
 
             for p in params:
                 g = p.grad
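
Beyond the argument rename, the substantive change is in step(): instead of branching on isinstance(params[0].data, DTensor) for the whole group, each parameter is classified on its own. Fully replicated DTensors are routed through the plain base path, sharded DTensors go to the parallel path, and any other type raises a TypeError, which also removes the old implicit assumption that every Muon parameter shares the type of the first one. A small hedged sketch of that classification in isolation; the helper name split_muon_params is mine, not part of the commit.

import torch
from torch.distributed._tensor import DTensor, Replicate


def split_muon_params(params):
    """Mirror of the classification introduced in Muon.step():
    sharded DTensors go to the parallel path, everything else
    (plain tensors and fully replicated DTensors) to the base path."""
    param_dtensors, param_tensors = [], []
    for p in params:
        if p is None or p.grad is None:
            continue  # nothing to update for this parameter
        if isinstance(p.data, DTensor):
            if all(isinstance(pl, Replicate) for pl in p.placements):
                param_tensors.append(p)   # replicated: treat like a local tensor
            else:
                param_dtensors.append(p)  # sharded: needs the gather/scatter path
        elif isinstance(p.data, torch.Tensor):
            param_tensors.append(p)
        else:
            raise TypeError(f"Unsupported parameter type: {type(p.data)}")
    return param_dtensors, param_tensors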
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py CHANGED
(identical to the _ops.py diff shown above)
build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:94ea66089cc8d9eda72b017733a9e05e4fee5a2f04c50658b690d2c19f0d3068
+oid sha256:2921aa2aa2587e261dc9ca4e5f60303b0d1c9a305d1584918a8c56b6dc79ebfb
 size 1824224
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py CHANGED
(identical to the muon.py diff shown above)
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
(identical to the _ops.py diff shown above)
build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:46e01e1d957ada2d485b30cd60bc3ef7230b8857dffc59f2e7924339761ec577
+oid sha256:a93530e6981fdac23236dd7e3657c5b47513cda4accec78293234ce5f233400b
 size 1824224
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
(identical to the muon.py diff shown above)
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py CHANGED
(identical to the _ops.py diff shown above)
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a825a0cd31d8c1b91aa9db4b24248d7fc0a506615f625a385b40e6002025c7dd
+oid sha256:caa40905ac8f209fecccae42c6892c3766ad5c7069382e60d2339e73da6ee7d6
 size 1749744
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py CHANGED
(identical to the muon.py diff shown above)
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py CHANGED
(identical to the _ops.py diff shown above)
build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:579e9ddf66a4f17ead9232c2f32e6327fe6a3f16dd235e2e73e6cb282de1797e
+oid sha256:6919551ed599e7e0dc1a750d1972bdb31605f57583b3617054cb70dd40d54d26
 size 1787192
build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py CHANGED
(identical to the muon.py diff shown above)
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py CHANGED
(identical to the _ops.py diff shown above)
build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:beacb4ba2d56463b6d444875728b3462cb3ff6c1449e3c9693cd665bfbbbbb73
+oid sha256:f07cc2637669130fc9e209cb2c4358caba1c4c2d5837a108043b073d7897c3a7
 size 1824184
build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py CHANGED
(identical to the muon.py diff shown above)
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py CHANGED
(identical to the _ops.py diff shown above)
build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9b04b011803d328d8dcd2edcf4c3840ddbb1bb2f093464c208f0ba2faf4f16bc
+oid sha256:8b9ef8fa2dd4d80cb3c1c3c2a72b99e0d76b3e676acd551f3a9ff4cdd21773eb
 size 1824184
build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py CHANGED
(identical to the muon.py diff shown above)
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
103
 
104
 
105
  @torch.no_grad()
106
+ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
 
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
+ p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
138
+ def default_is_muon(x, name):
139
+ return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
140
+
141
+
142
  class Muon(torch.optim.Optimizer):
143
  """
144
  Muon - MomentUm Orthogonalized by Newton-schulz
 
163
  adamw_lr: The learning rate for the internal AdamW.
164
  adamw_betas: The betas for the internal AdamW.
165
  adamw_eps: The epsilon for the internal AdamW.
166
+ weight_decay: The weight decay, applied to Muon-updated parameters and by the internal AdamW.
167
  """
168
 
169
  def __init__(
170
  self,
171
  model,
172
+ is_muon_func=default_is_muon,
173
  lr=1e-3,
174
  momentum=0.95,
175
  nesterov=True,
176
  ns_steps=5,
177
+ weight_decay=0.1,
178
  adamw_betas=(0.9, 0.95),
179
  adamw_eps=1e-8,
180
  none_grad=True,
 
182
  ):
183
  defaults = dict(
184
  lr=lr,
185
+ weight_decay=weight_decay,
186
  momentum=momentum,
187
  nesterov=nesterov,
188
  ns_steps=ns_steps,
 
276
 
277
  return param_to_state, ordered_params
278
 
279
+ def base(self, params, group, lr, weight_decay, momentum):
280
  # generate weight updates in distributed fashion
281
  for p in params:
282
  g = p.grad
 
303
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
304
 
305
  # apply weight decay
306
+ p.data.mul_(1 - lr * weight_decay)
307
 
308
  # apply update
309
  p.data.add_(u, alpha=-adjusted_lr)
 
321
  g = buf
322
  return g
323
 
324
+ def _update_p(self, p, u, lr, weight_decay):
325
  # scale update
326
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
327
  # apply weight decay
328
+ p.data.mul_(1 - lr * weight_decay)
329
  # apply update
330
  p.data.add_(u, alpha=-adjusted_lr)
331
 
332
+ def parallel(self, params, group, lr, weight_decay, momentum):
333
  """
334
  Perform a parallel optimization step using Muon.
335
  """
 
368
  for p in ordered_params[start_idx : start_idx + chunk_size]:
369
  state = param_to_state[id(p)]
370
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
371
+ _scatter(
372
+ p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
+ )
374
 
375
  chunk_size = params[0].device_mesh.mesh.numel()
376
 
 
404
 
405
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
406
  lr = group["lr"]
407
+ weight_decay = group["weight_decay"]
408
  momentum = group["momentum"]
409
 
410
+ param_dtensors = []
411
+ param_tensors = []
412
+
413
+ for p in params:
414
+ if p is None or p.grad is None:
415
+ continue
416
+ if isinstance(p.data, DTensor):
417
+ if all(
418
+ isinstance(placement, Replicate) for placement in p.placements
419
+ ):
420
+ param_tensors.append(p)
421
+ else:
422
+ param_dtensors.append(p)
423
+ elif isinstance(p.data, torch.Tensor):
424
+ param_tensors.append(p)
425
+ else:
426
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
427
+
428
+ if self.debug:
429
+ print(
430
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
431
+ flush=True,
432
+ )
433
+
434
+ if len(param_dtensors) > 0:
435
  self.parallel(
436
+ param_dtensors,
437
  group,
438
  lr=lr,
439
+ weight_decay=weight_decay,
440
  momentum=momentum,
441
  )
442
+
443
+ if len(param_tensors) > 0:
444
  self.base(
445
+ param_tensors,
446
  group,
447
  lr=lr,
448
+ weight_decay=weight_decay,
449
  momentum=momentum,
450
  )
451
 
 
457
  lr = group["lr"]
458
  beta1, beta2 = group["adamw_betas"]
459
  eps = group["adamw_eps"]
460
+ weight_decay = group["weight_decay"]
461
 
462
  for p in params:
463
  g = p.grad
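The final hunk in this file switches the internal AdamW path to read group["weight_decay"] (formerly group["wd"]) alongside `adamw_betas` and `adamw_eps`. For parameters that `is_muon_func` rejects, the update is ordinary AdamW; a rough stand-in using stock `torch.optim.AdamW` with the defaults shown in the diff (betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1) is sketched below. This is an illustration only, not the repository's code path, and the layer is made up.

import torch
import torch.nn as nn

# Hypothetical fallback parameter (e.g. an lm_head-style projection) that
# would be handled by AdamW instead of Muon; the layer is invented here.
head = nn.Linear(32, 1000)

adamw = torch.optim.AdamW(
    head.parameters(),
    lr=1e-3,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.1,
)

loss = head(torch.randn(4, 32)).sum()
loss.backward()
adamw.step()
adamw.zero_grad()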
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_036642a_dirty
3
- ops = torch.ops._optimizer_036642a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_036642a_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_64757cb_dirty
3
+ ops = torch.ops._optimizer_64757cb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_64757cb_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ad6c725009f2e776b99d3134c75f15e11dd7fe75fe4ba1fa94779018c7871f8c
3
  size 1787368
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8413f32011996384f13a985a99b4e2f863f8e4717acdb8439b63a10f77db6f15
3
  size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
103
 
104
 
105
  @torch.no_grad()
106
- def _scatter(p, state, lr, wd, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
- p.data.mul_(1 - lr * wd)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
 
 
 
 
138
  class Muon(torch.optim.Optimizer):
139
  """
140
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
159
  adamw_lr: The learning rate for the internal AdamW.
160
  adamw_betas: The betas for the internal AdamW.
161
  adamw_eps: The epsilon for the internal AdamW.
162
- adamw_wd: The weight decay for the internal AdamW.
163
  """
164
 
165
  def __init__(
166
  self,
167
  model,
168
- is_muon_func,
169
  lr=1e-3,
170
  momentum=0.95,
171
  nesterov=True,
172
  ns_steps=5,
173
- adamw_wd=0.1,
174
  adamw_betas=(0.9, 0.95),
175
  adamw_eps=1e-8,
176
  none_grad=True,
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
178
  ):
179
  defaults = dict(
180
  lr=lr,
181
- wd=adamw_wd,
182
  momentum=momentum,
183
  nesterov=nesterov,
184
  ns_steps=ns_steps,
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
272
 
273
  return param_to_state, ordered_params
274
 
275
- def base(self, params, group, lr, wd, momentum):
276
  # generate weight updates in distributed fashion
277
  for p in params:
278
  g = p.grad
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
299
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
300
 
301
  # apply weight decay
302
- p.data.mul_(1 - lr * wd)
303
 
304
  # apply update
305
  p.data.add_(u, alpha=-adjusted_lr)
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
317
  g = buf
318
  return g
319
 
320
- def _update_p(self, p, u, lr, wd):
321
  # scale update
322
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
323
  # apply weight decay
324
- p.data.mul_(1 - lr * wd)
325
  # apply update
326
  p.data.add_(u, alpha=-adjusted_lr)
327
 
328
- def parallel(self, params, group, lr, wd, momentum):
329
  """
330
  Perform a parallel optimization step using Muon.
331
  """
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
364
  for p in ordered_params[start_idx : start_idx + chunk_size]:
365
  state = param_to_state[id(p)]
366
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
367
- _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
 
368
 
369
  chunk_size = params[0].device_mesh.mesh.numel()
370
 
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
398
 
399
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
400
  lr = group["lr"]
401
- wd = group["wd"]
402
  momentum = group["momentum"]
403
 
404
- if isinstance(params[0].data, DTensor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  self.parallel(
406
- params,
407
  group,
408
  lr=lr,
409
- wd=wd,
410
  momentum=momentum,
411
  )
412
- else:
 
413
  self.base(
414
- params,
415
  group,
416
  lr=lr,
417
- wd=wd,
418
  momentum=momentum,
419
  )
420
 
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
426
  lr = group["lr"]
427
  beta1, beta2 = group["adamw_betas"]
428
  eps = group["adamw_eps"]
429
- weight_decay = group["wd"]
430
 
431
  for p in params:
432
  g = p.grad
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
103
 
104
 
105
  @torch.no_grad()
106
+ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
 
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
+ p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
138
+ def default_is_muon(x, name):
139
+ return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
140
+
141
+
142
  class Muon(torch.optim.Optimizer):
143
  """
144
  Muon - MomentUm Orthogonalized by Newton-schulz
 
163
  adamw_lr: The learning rate for the internal AdamW.
164
  adamw_betas: The betas for the internal AdamW.
165
  adamw_eps: The epsilon for the internal AdamW.
166
+ weight_decay: The weight decay, applied to Muon-updated parameters and by the internal AdamW.
167
  """
168
 
169
  def __init__(
170
  self,
171
  model,
172
+ is_muon_func=default_is_muon,
173
  lr=1e-3,
174
  momentum=0.95,
175
  nesterov=True,
176
  ns_steps=5,
177
+ weight_decay=0.1,
178
  adamw_betas=(0.9, 0.95),
179
  adamw_eps=1e-8,
180
  none_grad=True,
 
182
  ):
183
  defaults = dict(
184
  lr=lr,
185
+ weight_decay=weight_decay,
186
  momentum=momentum,
187
  nesterov=nesterov,
188
  ns_steps=ns_steps,
 
276
 
277
  return param_to_state, ordered_params
278
 
279
+ def base(self, params, group, lr, weight_decay, momentum):
280
  # generate weight updates in distributed fashion
281
  for p in params:
282
  g = p.grad
 
303
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
304
 
305
  # apply weight decay
306
+ p.data.mul_(1 - lr * weight_decay)
307
 
308
  # apply update
309
  p.data.add_(u, alpha=-adjusted_lr)
 
321
  g = buf
322
  return g
323
 
324
+ def _update_p(self, p, u, lr, weight_decay):
325
  # scale update
326
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
327
  # apply weight decay
328
+ p.data.mul_(1 - lr * weight_decay)
329
  # apply update
330
  p.data.add_(u, alpha=-adjusted_lr)
331
 
332
+ def parallel(self, params, group, lr, weight_decay, momentum):
333
  """
334
  Perform a parallel optimization step using Muon.
335
  """
 
368
  for p in ordered_params[start_idx : start_idx + chunk_size]:
369
  state = param_to_state[id(p)]
370
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
371
+ _scatter(
372
+ p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
+ )
374
 
375
  chunk_size = params[0].device_mesh.mesh.numel()
376
 
 
404
 
405
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
406
  lr = group["lr"]
407
+ weight_decay = group["weight_decay"]
408
  momentum = group["momentum"]
409
 
410
+ param_dtensors = []
411
+ param_tensors = []
412
+
413
+ for p in params:
414
+ if p is None or p.grad is None:
415
+ continue
416
+ if isinstance(p.data, DTensor):
417
+ if all(
418
+ isinstance(placement, Replicate) for placement in p.placements
419
+ ):
420
+ param_tensors.append(p)
421
+ else:
422
+ param_dtensors.append(p)
423
+ elif isinstance(p.data, torch.Tensor):
424
+ param_tensors.append(p)
425
+ else:
426
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
427
+
428
+ if self.debug:
429
+ print(
430
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
431
+ flush=True,
432
+ )
433
+
434
+ if len(param_dtensors) > 0:
435
  self.parallel(
436
+ param_dtensors,
437
  group,
438
  lr=lr,
439
+ weight_decay=weight_decay,
440
  momentum=momentum,
441
  )
442
+
443
+ if len(param_tensors) > 0:
444
  self.base(
445
+ param_tensors,
446
  group,
447
  lr=lr,
448
+ weight_decay=weight_decay,
449
  momentum=momentum,
450
  )
451
 
 
457
  lr = group["lr"]
458
  beta1, beta2 = group["adamw_betas"]
459
  eps = group["adamw_eps"]
460
+ weight_decay = group["weight_decay"]
461
 
462
  for p in params:
463
  g = p.grad
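This copy also picks up `default_is_muon`, which becomes the default for `is_muon_func`: parameters with two or more dimensions are routed to Muon unless their name contains "embed_tokens" or "lm_head". A small self-contained check of the predicate follows; the toy module is invented purely to exercise those name substrings.

import torch.nn as nn

def default_is_muon(x, name):
    # Same predicate as the one added in this commit.
    return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name

# Toy module whose attribute names deliberately match the substrings above.
model = nn.ModuleDict({
    "embed_tokens": nn.Embedding(1000, 32),
    "proj": nn.Linear(32, 32),
    "lm_head": nn.Linear(32, 1000),
})

for name, p in model.named_parameters():
    route = "muon" if default_is_muon(p, name) else "adamw"
    print(f"{name:>20s} -> {route}")
# proj.weight is the only parameter routed to Muon; the embedding, the head,
# and all 1-D biases fall back to the internal AdamW.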
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_036642a_dirty
3
- ops = torch.ops._optimizer_036642a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_036642a_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_64757cb_dirty
3
+ ops = torch.ops._optimizer_64757cb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_64757cb_dirty::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:50cb5819ff08a2179d78cd98164d07fd3cef1b66ee7703d599a310dfb140b9d1
3
  size 1824256
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9d303b11a0a82e9c51c7b32c7555bd351ec375b1879bf46eb64ea4aff32100f
3
  size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
103
 
104
 
105
  @torch.no_grad()
106
- def _scatter(p, state, lr, wd, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
- p.data.mul_(1 - lr * wd)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
 
 
 
 
138
  class Muon(torch.optim.Optimizer):
139
  """
140
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
159
  adamw_lr: The learning rate for the internal AdamW.
160
  adamw_betas: The betas for the internal AdamW.
161
  adamw_eps: The epsilon for the internal AdamW.
162
- adamw_wd: The weight decay for the internal AdamW.
163
  """
164
 
165
  def __init__(
166
  self,
167
  model,
168
- is_muon_func,
169
  lr=1e-3,
170
  momentum=0.95,
171
  nesterov=True,
172
  ns_steps=5,
173
- adamw_wd=0.1,
174
  adamw_betas=(0.9, 0.95),
175
  adamw_eps=1e-8,
176
  none_grad=True,
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
178
  ):
179
  defaults = dict(
180
  lr=lr,
181
- wd=adamw_wd,
182
  momentum=momentum,
183
  nesterov=nesterov,
184
  ns_steps=ns_steps,
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
272
 
273
  return param_to_state, ordered_params
274
 
275
- def base(self, params, group, lr, wd, momentum):
276
  # generate weight updates in distributed fashion
277
  for p in params:
278
  g = p.grad
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
299
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
300
 
301
  # apply weight decay
302
- p.data.mul_(1 - lr * wd)
303
 
304
  # apply update
305
  p.data.add_(u, alpha=-adjusted_lr)
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
317
  g = buf
318
  return g
319
 
320
- def _update_p(self, p, u, lr, wd):
321
  # scale update
322
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
323
  # apply weight decay
324
- p.data.mul_(1 - lr * wd)
325
  # apply update
326
  p.data.add_(u, alpha=-adjusted_lr)
327
 
328
- def parallel(self, params, group, lr, wd, momentum):
329
  """
330
  Perform a parallel optimization step using Muon.
331
  """
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
364
  for p in ordered_params[start_idx : start_idx + chunk_size]:
365
  state = param_to_state[id(p)]
366
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
367
- _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
 
368
 
369
  chunk_size = params[0].device_mesh.mesh.numel()
370
 
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
398
 
399
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
400
  lr = group["lr"]
401
- wd = group["wd"]
402
  momentum = group["momentum"]
403
 
404
- if isinstance(params[0].data, DTensor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  self.parallel(
406
- params,
407
  group,
408
  lr=lr,
409
- wd=wd,
410
  momentum=momentum,
411
  )
412
- else:
 
413
  self.base(
414
- params,
415
  group,
416
  lr=lr,
417
- wd=wd,
418
  momentum=momentum,
419
  )
420
 
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
426
  lr = group["lr"]
427
  beta1, beta2 = group["adamw_betas"]
428
  eps = group["adamw_eps"]
429
- weight_decay = group["wd"]
430
 
431
  for p in params:
432
  g = p.grad
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
103
 
104
 
105
  @torch.no_grad()
106
+ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
 
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
+ p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
138
+ def default_is_muon(x, name):
139
+ return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
140
+
141
+
142
  class Muon(torch.optim.Optimizer):
143
  """
144
  Muon - MomentUm Orthogonalized by Newton-schulz
 
163
  adamw_lr: The learning rate for the internal AdamW.
164
  adamw_betas: The betas for the internal AdamW.
165
  adamw_eps: The epsilon for the internal AdamW.
166
+ weight_decay: The weight decay, applied to Muon-updated parameters and by the internal AdamW.
167
  """
168
 
169
  def __init__(
170
  self,
171
  model,
172
+ is_muon_func=default_is_muon,
173
  lr=1e-3,
174
  momentum=0.95,
175
  nesterov=True,
176
  ns_steps=5,
177
+ weight_decay=0.1,
178
  adamw_betas=(0.9, 0.95),
179
  adamw_eps=1e-8,
180
  none_grad=True,
 
182
  ):
183
  defaults = dict(
184
  lr=lr,
185
+ weight_decay=weight_decay,
186
  momentum=momentum,
187
  nesterov=nesterov,
188
  ns_steps=ns_steps,
 
276
 
277
  return param_to_state, ordered_params
278
 
279
+ def base(self, params, group, lr, weight_decay, momentum):
280
  # generate weight updates in distributed fashion
281
  for p in params:
282
  g = p.grad
 
303
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
304
 
305
  # apply weight decay
306
+ p.data.mul_(1 - lr * weight_decay)
307
 
308
  # apply update
309
  p.data.add_(u, alpha=-adjusted_lr)
 
321
  g = buf
322
  return g
323
 
324
+ def _update_p(self, p, u, lr, weight_decay):
325
  # scale update
326
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
327
  # apply weight decay
328
+ p.data.mul_(1 - lr * weight_decay)
329
  # apply update
330
  p.data.add_(u, alpha=-adjusted_lr)
331
 
332
+ def parallel(self, params, group, lr, weight_decay, momentum):
333
  """
334
  Perform a parallel optimization step using Muon.
335
  """
 
368
  for p in ordered_params[start_idx : start_idx + chunk_size]:
369
  state = param_to_state[id(p)]
370
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
371
+ _scatter(
372
+ p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
+ )
374
 
375
  chunk_size = params[0].device_mesh.mesh.numel()
376
 
 
404
 
405
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
406
  lr = group["lr"]
407
+ weight_decay = group["weight_decay"]
408
  momentum = group["momentum"]
409
 
410
+ param_dtensors = []
411
+ param_tensors = []
412
+
413
+ for p in params:
414
+ if p is None or p.grad is None:
415
+ continue
416
+ if isinstance(p.data, DTensor):
417
+ if all(
418
+ isinstance(placement, Replicate) for placement in p.placements
419
+ ):
420
+ param_tensors.append(p)
421
+ else:
422
+ param_dtensors.append(p)
423
+ elif isinstance(p.data, torch.Tensor):
424
+ param_tensors.append(p)
425
+ else:
426
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
427
+
428
+ if self.debug:
429
+ print(
430
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
431
+ flush=True,
432
+ )
433
+
434
+ if len(param_dtensors) > 0:
435
  self.parallel(
436
+ param_dtensors,
437
  group,
438
  lr=lr,
439
+ weight_decay=weight_decay,
440
  momentum=momentum,
441
  )
442
+
443
+ if len(param_tensors) > 0:
444
  self.base(
445
+ param_tensors,
446
  group,
447
  lr=lr,
448
+ weight_decay=weight_decay,
449
  momentum=momentum,
450
  )
451
 
 
457
  lr = group["lr"]
458
  beta1, beta2 = group["adamw_betas"]
459
  eps = group["adamw_eps"]
460
+ weight_decay = group["weight_decay"]
461
 
462
  for p in params:
463
  g = p.grad
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_036642a_dirty
3
- ops = torch.ops._optimizer_036642a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_036642a_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_64757cb_dirty
3
+ ops = torch.ops._optimizer_64757cb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_64757cb_dirty::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9c75e42265f382addc71327ad5628e8a2414da5872791c975e384708c4acd549
3
  size 1883352
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4334fe8d7157c2a9c85217cb981692daf9eb4c6d3f205d0fd41d4b717daefa1
3
  size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
103
 
104
 
105
  @torch.no_grad()
106
- def _scatter(p, state, lr, wd, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
- p.data.mul_(1 - lr * wd)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
 
 
 
 
138
  class Muon(torch.optim.Optimizer):
139
  """
140
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
159
  adamw_lr: The learning rate for the internal AdamW.
160
  adamw_betas: The betas for the internal AdamW.
161
  adamw_eps: The epsilon for the internal AdamW.
162
- adamw_wd: The weight decay for the internal AdamW.
163
  """
164
 
165
  def __init__(
166
  self,
167
  model,
168
- is_muon_func,
169
  lr=1e-3,
170
  momentum=0.95,
171
  nesterov=True,
172
  ns_steps=5,
173
- adamw_wd=0.1,
174
  adamw_betas=(0.9, 0.95),
175
  adamw_eps=1e-8,
176
  none_grad=True,
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
178
  ):
179
  defaults = dict(
180
  lr=lr,
181
- wd=adamw_wd,
182
  momentum=momentum,
183
  nesterov=nesterov,
184
  ns_steps=ns_steps,
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
272
 
273
  return param_to_state, ordered_params
274
 
275
- def base(self, params, group, lr, wd, momentum):
276
  # generate weight updates in distributed fashion
277
  for p in params:
278
  g = p.grad
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
299
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
300
 
301
  # apply weight decay
302
- p.data.mul_(1 - lr * wd)
303
 
304
  # apply update
305
  p.data.add_(u, alpha=-adjusted_lr)
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
317
  g = buf
318
  return g
319
 
320
- def _update_p(self, p, u, lr, wd):
321
  # scale update
322
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
323
  # apply weight decay
324
- p.data.mul_(1 - lr * wd)
325
  # apply update
326
  p.data.add_(u, alpha=-adjusted_lr)
327
 
328
- def parallel(self, params, group, lr, wd, momentum):
329
  """
330
  Perform a parallel optimization step using Muon.
331
  """
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
364
  for p in ordered_params[start_idx : start_idx + chunk_size]:
365
  state = param_to_state[id(p)]
366
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
367
- _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
 
368
 
369
  chunk_size = params[0].device_mesh.mesh.numel()
370
 
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
398
 
399
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
400
  lr = group["lr"]
401
- wd = group["wd"]
402
  momentum = group["momentum"]
403
 
404
- if isinstance(params[0].data, DTensor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  self.parallel(
406
- params,
407
  group,
408
  lr=lr,
409
- wd=wd,
410
  momentum=momentum,
411
  )
412
- else:
 
413
  self.base(
414
- params,
415
  group,
416
  lr=lr,
417
- wd=wd,
418
  momentum=momentum,
419
  )
420
 
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
426
  lr = group["lr"]
427
  beta1, beta2 = group["adamw_betas"]
428
  eps = group["adamw_eps"]
429
- weight_decay = group["wd"]
430
 
431
  for p in params:
432
  g = p.grad
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
103
 
104
 
105
  @torch.no_grad()
106
+ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
 
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
+ p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
138
+ def default_is_muon(x, name):
139
+ return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
140
+
141
+
142
  class Muon(torch.optim.Optimizer):
143
  """
144
  Muon - MomentUm Orthogonalized by Newton-schulz
 
163
  adamw_lr: The learning rate for the internal AdamW.
164
  adamw_betas: The betas for the internal AdamW.
165
  adamw_eps: The epsilon for the internal AdamW.
166
+ weight_decay: The weight decay, applied to Muon-updated parameters and by the internal AdamW.
167
  """
168
 
169
  def __init__(
170
  self,
171
  model,
172
+ is_muon_func=default_is_muon,
173
  lr=1e-3,
174
  momentum=0.95,
175
  nesterov=True,
176
  ns_steps=5,
177
+ weight_decay=0.1,
178
  adamw_betas=(0.9, 0.95),
179
  adamw_eps=1e-8,
180
  none_grad=True,
 
182
  ):
183
  defaults = dict(
184
  lr=lr,
185
+ weight_decay=weight_decay,
186
  momentum=momentum,
187
  nesterov=nesterov,
188
  ns_steps=ns_steps,
 
276
 
277
  return param_to_state, ordered_params
278
 
279
+ def base(self, params, group, lr, weight_decay, momentum):
280
  # generate weight updates in distributed fashion
281
  for p in params:
282
  g = p.grad
 
303
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
304
 
305
  # apply weight decay
306
+ p.data.mul_(1 - lr * weight_decay)
307
 
308
  # apply update
309
  p.data.add_(u, alpha=-adjusted_lr)
 
321
  g = buf
322
  return g
323
 
324
+ def _update_p(self, p, u, lr, weight_decay):
325
  # scale update
326
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
327
  # apply weight decay
328
+ p.data.mul_(1 - lr * weight_decay)
329
  # apply update
330
  p.data.add_(u, alpha=-adjusted_lr)
331
 
332
+ def parallel(self, params, group, lr, weight_decay, momentum):
333
  """
334
  Perform a parallel optimization step using Muon.
335
  """
 
368
  for p in ordered_params[start_idx : start_idx + chunk_size]:
369
  state = param_to_state[id(p)]
370
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
371
+ _scatter(
372
+ p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
+ )
374
 
375
  chunk_size = params[0].device_mesh.mesh.numel()
376
 
 
404
 
405
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
406
  lr = group["lr"]
407
+ weight_decay = group["weight_decay"]
408
  momentum = group["momentum"]
409
 
410
+ param_dtensors = []
411
+ param_tensors = []
412
+
413
+ for p in params:
414
+ if p is None or p.grad is None:
415
+ continue
416
+ if isinstance(p.data, DTensor):
417
+ if all(
418
+ isinstance(placement, Replicate) for placement in p.placements
419
+ ):
420
+ param_tensors.append(p)
421
+ else:
422
+ param_dtensors.append(p)
423
+ elif isinstance(p.data, torch.Tensor):
424
+ param_tensors.append(p)
425
+ else:
426
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
427
+
428
+ if self.debug:
429
+ print(
430
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
431
+ flush=True,
432
+ )
433
+
434
+ if len(param_dtensors) > 0:
435
  self.parallel(
436
+ param_dtensors,
437
  group,
438
  lr=lr,
439
+ weight_decay=weight_decay,
440
  momentum=momentum,
441
  )
442
+
443
+ if len(param_tensors) > 0:
444
  self.base(
445
+ param_tensors,
446
  group,
447
  lr=lr,
448
+ weight_decay=weight_decay,
449
  momentum=momentum,
450
  )
451
 
 
457
  lr = group["lr"]
458
  beta1, beta2 = group["adamw_betas"]
459
  eps = group["adamw_eps"]
460
+ weight_decay = group["weight_decay"]
461
 
462
  for p in params:
463
  g = p.grad
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (252 Bytes).
 
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-312.pyc ADDED
Binary file (22 kB).
 
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _optimizer_036642a_dirty
3
- ops = torch.ops._optimizer_036642a_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_optimizer_036642a_dirty::{op_name}"
 
1
  import torch
2
+ from . import _optimizer_64757cb_dirty
3
+ ops = torch.ops._optimizer_64757cb_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_optimizer_64757cb_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so β†’ _optimizer_64757cb_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9a2363d4311d6a75fbcc03e6d4a71c73dae4d54e00a30135d25198d4078c6b0f
3
  size 1749648
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:272fcc69e3774fa43e222efefceeaca97a8c84ee3f1fe528a7478a8e80a70976
3
  size 1749648
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
103
 
104
 
105
  @torch.no_grad()
106
- def _scatter(p, state, lr, wd, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
- p.data.mul_(1 - lr * wd)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
 
 
 
 
138
  class Muon(torch.optim.Optimizer):
139
  """
140
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
159
  adamw_lr: The learning rate for the internal AdamW.
160
  adamw_betas: The betas for the internal AdamW.
161
  adamw_eps: The epsilon for the internal AdamW.
162
- adamw_wd: The weight decay for the internal AdamW.
163
  """
164
 
165
  def __init__(
166
  self,
167
  model,
168
- is_muon_func,
169
  lr=1e-3,
170
  momentum=0.95,
171
  nesterov=True,
172
  ns_steps=5,
173
- adamw_wd=0.1,
174
  adamw_betas=(0.9, 0.95),
175
  adamw_eps=1e-8,
176
  none_grad=True,
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
178
  ):
179
  defaults = dict(
180
  lr=lr,
181
- wd=adamw_wd,
182
  momentum=momentum,
183
  nesterov=nesterov,
184
  ns_steps=ns_steps,
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
272
 
273
  return param_to_state, ordered_params
274
 
275
- def base(self, params, group, lr, wd, momentum):
276
  # generate weight updates in distributed fashion
277
  for p in params:
278
  g = p.grad
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
299
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
300
 
301
  # apply weight decay
302
- p.data.mul_(1 - lr * wd)
303
 
304
  # apply update
305
  p.data.add_(u, alpha=-adjusted_lr)
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
317
  g = buf
318
  return g
319
 
320
- def _update_p(self, p, u, lr, wd):
321
  # scale update
322
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
323
  # apply weight decay
324
- p.data.mul_(1 - lr * wd)
325
  # apply update
326
  p.data.add_(u, alpha=-adjusted_lr)
327
 
328
- def parallel(self, params, group, lr, wd, momentum):
329
  """
330
  Perform a parallel optimization step using Muon.
331
  """
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
364
  for p in ordered_params[start_idx : start_idx + chunk_size]:
365
  state = param_to_state[id(p)]
366
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
367
- _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
 
368
 
369
  chunk_size = params[0].device_mesh.mesh.numel()
370
 
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
398
 
399
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
400
  lr = group["lr"]
401
- wd = group["wd"]
402
  momentum = group["momentum"]
403
 
404
- if isinstance(params[0].data, DTensor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  self.parallel(
406
- params,
407
  group,
408
  lr=lr,
409
- wd=wd,
410
  momentum=momentum,
411
  )
412
- else:
 
413
  self.base(
414
- params,
415
  group,
416
  lr=lr,
417
- wd=wd,
418
  momentum=momentum,
419
  )
420
 
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
426
  lr = group["lr"]
427
  beta1, beta2 = group["adamw_betas"]
428
  eps = group["adamw_eps"]
429
- weight_decay = group["wd"]
430
 
431
  for p in params:
432
  g = p.grad
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
103
 
104
 
105
  @torch.no_grad()
106
+ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
 
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
+ p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
138
+ def default_is_muon(x, name):
139
+ return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
140
+
141
+
142
  class Muon(torch.optim.Optimizer):
143
  """
144
  Muon - MomentUm Orthogonalized by Newton-schulz
 
163
  adamw_lr: The learning rate for the internal AdamW.
164
  adamw_betas: The betas for the internal AdamW.
165
  adamw_eps: The epsilon for the internal AdamW.
166
+ weight_decay: The weight decay, applied to Muon-updated parameters and by the internal AdamW.
167
  """
168
 
169
  def __init__(
170
  self,
171
  model,
172
+ is_muon_func=default_is_muon,
173
  lr=1e-3,
174
  momentum=0.95,
175
  nesterov=True,
176
  ns_steps=5,
177
+ weight_decay=0.1,
178
  adamw_betas=(0.9, 0.95),
179
  adamw_eps=1e-8,
180
  none_grad=True,
 
182
  ):
183
  defaults = dict(
184
  lr=lr,
185
+ weight_decay=weight_decay,
186
  momentum=momentum,
187
  nesterov=nesterov,
188
  ns_steps=ns_steps,
 
276
 
277
  return param_to_state, ordered_params
278
 
279
+ def base(self, params, group, lr, weight_decay, momentum):
280
  # generate weight updates in distributed fashion
281
  for p in params:
282
  g = p.grad
 
303
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
304
 
305
  # apply weight decay
306
+ p.data.mul_(1 - lr * weight_decay)
307
 
308
  # apply update
309
  p.data.add_(u, alpha=-adjusted_lr)
 
321
  g = buf
322
  return g
323
 
324
+ def _update_p(self, p, u, lr, weight_decay):
325
  # scale update
326
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
327
  # apply weight decay
328
+ p.data.mul_(1 - lr * weight_decay)
329
  # apply update
330
  p.data.add_(u, alpha=-adjusted_lr)
331
 
332
+ def parallel(self, params, group, lr, weight_decay, momentum):
333
  """
334
  Perform a parallel optimization step using Muon.
335
  """
 
368
  for p in ordered_params[start_idx : start_idx + chunk_size]:
369
  state = param_to_state[id(p)]
370
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
371
+ _scatter(
372
+ p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
+ )
374
 
375
  chunk_size = params[0].device_mesh.mesh.numel()
376
 
 
404
 
405
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
406
  lr = group["lr"]
407
+ weight_decay = group["weight_decay"]
408
  momentum = group["momentum"]
409
 
410
+ param_dtensors = []
411
+ param_tensors = []
412
+
413
+ for p in params:
414
+ if p is None or p.grad is None:
415
+ continue
416
+ if isinstance(p.data, DTensor):
417
+ if all(
418
+ isinstance(placement, Replicate) for placement in p.placements
419
+ ):
420
+ param_tensors.append(p)
421
+ else:
422
+ param_dtensors.append(p)
423
+ elif isinstance(p.data, torch.Tensor):
424
+ param_tensors.append(p)
425
+ else:
426
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
427
+
428
+ if self.debug:
429
+ print(
430
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
431
+ flush=True,
432
+ )
433
+
434
+ if len(param_dtensors) > 0:
435
  self.parallel(
436
+ param_dtensors,
437
  group,
438
  lr=lr,
439
+ weight_decay=weight_decay,
440
  momentum=momentum,
441
  )
442
+
443
+ if len(param_tensors) > 0:
444
  self.base(
445
+ param_tensors,
446
  group,
447
  lr=lr,
448
+ weight_decay=weight_decay,
449
  momentum=momentum,
450
  )
451
 
 
457
  lr = group["lr"]
458
  beta1, beta2 = group["adamw_betas"]
459
  eps = group["adamw_eps"]
460
+ weight_decay = group["weight_decay"]
461
 
462
  for p in params:
463
  g = p.grad
torch-ext/optimizer/muon.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
 
4
  import torch
5
  import torch.distributed as dist
6
- from torch.distributed._tensor import DTensor
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
103
 
104
 
105
  @torch.no_grad()
106
- def _scatter(p, state, lr, wd, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
- p.data.mul_(1 - lr * wd)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
 
 
 
 
138
  class Muon(torch.optim.Optimizer):
139
  """
140
  Muon - MomentUm Orthogonalized by Newton-schulz
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
159
  adamw_lr: The learning rate for the internal AdamW.
160
  adamw_betas: The betas for the internal AdamW.
161
  adamw_eps: The epsilon for the internal AdamW.
162
- adamw_wd: The weight decay for the internal AdamW.
163
  """
164
 
165
  def __init__(
166
  self,
167
  model,
168
- is_muon_func,
169
  lr=1e-3,
170
  momentum=0.95,
171
  nesterov=True,
172
  ns_steps=5,
173
- adamw_wd=0.1,
174
  adamw_betas=(0.9, 0.95),
175
  adamw_eps=1e-8,
176
  none_grad=True,
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
178
  ):
179
  defaults = dict(
180
  lr=lr,
181
- wd=adamw_wd,
182
  momentum=momentum,
183
  nesterov=nesterov,
184
  ns_steps=ns_steps,
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
272
 
273
  return param_to_state, ordered_params
274
 
275
- def base(self, params, group, lr, wd, momentum):
276
  # generate weight updates in distributed fashion
277
  for p in params:
278
  g = p.grad
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
299
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
300
 
301
  # apply weight decay
302
- p.data.mul_(1 - lr * wd)
303
 
304
  # apply update
305
  p.data.add_(u, alpha=-adjusted_lr)
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
317
  g = buf
318
  return g
319
 
320
- def _update_p(self, p, u, lr, wd):
321
  # scale update
322
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
323
  # apply weight decay
324
- p.data.mul_(1 - lr * wd)
325
  # apply update
326
  p.data.add_(u, alpha=-adjusted_lr)
327
 
328
- def parallel(self, params, group, lr, wd, momentum):
329
  """
330
  Perform a parallel optimization step using Muon.
331
  """
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
364
  for p in ordered_params[start_idx : start_idx + chunk_size]:
365
  state = param_to_state[id(p)]
366
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
367
- _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
 
 
368
 
369
  chunk_size = params[0].device_mesh.mesh.numel()
370
 
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
398
 
399
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
400
  lr = group["lr"]
401
- wd = group["wd"]
402
  momentum = group["momentum"]
403
 
404
- if isinstance(params[0].data, DTensor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  self.parallel(
406
- params,
407
  group,
408
  lr=lr,
409
- wd=wd,
410
  momentum=momentum,
411
  )
412
- else:
 
413
  self.base(
414
- params,
415
  group,
416
  lr=lr,
417
- wd=wd,
418
  momentum=momentum,
419
  )
420
 
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
426
  lr = group["lr"]
427
  beta1, beta2 = group["adamw_betas"]
428
  eps = group["adamw_eps"]
429
- weight_decay = group["wd"]
430
 
431
  for p in params:
432
  g = p.grad
 
3
 
4
  import torch
5
  import torch.distributed as dist
6
+ from torch.distributed._tensor import DTensor, Replicate
7
 
8
 
9
  # This code snippet is a modified version adapted from the following GitHub repositories:
 
103
 
104
 
105
  @torch.no_grad()
106
+ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
107
  u = state.computed_u
108
  mesh = p.device_mesh
109
 
 
131
  placements=p.placements,
132
  device_mesh=mesh,
133
  )
134
+ p.data.mul_(1 - lr * weight_decay)
135
  p.data.add_(u, alpha=-lr)
136
 
137
 
138
+ def default_is_muon(x, name):
139
+ return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
140
+
141
+
142
  class Muon(torch.optim.Optimizer):
143
  """
144
  Muon - MomentUm Orthogonalized by Newton-schulz
 
163
  adamw_lr: The learning rate for the internal AdamW.
164
  adamw_betas: The betas for the internal AdamW.
165
  adamw_eps: The epsilon for the internal AdamW.
166
+ weight_decay: The weight decay, applied to Muon-updated parameters and by the internal AdamW.
167
  """
168
 
169
  def __init__(
170
  self,
171
  model,
172
+ is_muon_func=default_is_muon,
173
  lr=1e-3,
174
  momentum=0.95,
175
  nesterov=True,
176
  ns_steps=5,
177
+ weight_decay=0.1,
178
  adamw_betas=(0.9, 0.95),
179
  adamw_eps=1e-8,
180
  none_grad=True,
 
182
  ):
183
  defaults = dict(
184
  lr=lr,
185
+ weight_decay=weight_decay,
186
  momentum=momentum,
187
  nesterov=nesterov,
188
  ns_steps=ns_steps,
 
276
 
277
  return param_to_state, ordered_params
278
 
279
+ def base(self, params, group, lr, weight_decay, momentum):
280
  # generate weight updates in distributed fashion
281
  for p in params:
282
  g = p.grad
 
303
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
304
 
305
  # apply weight decay
306
+ p.data.mul_(1 - lr * weight_decay)
307
 
308
  # apply update
309
  p.data.add_(u, alpha=-adjusted_lr)
 
321
  g = buf
322
  return g
323
 
324
+ def _update_p(self, p, u, lr, weight_decay):
325
  # scale update
326
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
327
  # apply weight decay
328
+ p.data.mul_(1 - lr * weight_decay)
329
  # apply update
330
  p.data.add_(u, alpha=-adjusted_lr)
331
 
332
+ def parallel(self, params, group, lr, weight_decay, momentum):
333
  """
334
  Perform a parallel optimization step using Muon.
335
  """
 
368
  for p in ordered_params[start_idx : start_idx + chunk_size]:
369
  state = param_to_state[id(p)]
370
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
371
+ _scatter(
372
+ p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
373
+ )
374
 
375
  chunk_size = params[0].device_mesh.mesh.numel()
376
 
 
404
 
405
  params = [p for p in group["params"] if self.state[p]["use_muon"]]
406
  lr = group["lr"]
407
+ weight_decay = group["weight_decay"]
408
  momentum = group["momentum"]
409
 
410
+ param_dtensors = []
411
+ param_tensors = []
412
+
413
+ for p in params:
414
+ if p is None or p.grad is None:
415
+ continue
416
+ if isinstance(p.data, DTensor):
417
+ if all(
418
+ isinstance(placement, Replicate) for placement in p.placements
419
+ ):
420
+ param_tensors.append(p)
421
+ else:
422
+ param_dtensors.append(p)
423
+ elif isinstance(p.data, torch.Tensor):
424
+ param_tensors.append(p)
425
+ else:
426
+ raise TypeError(f"Unsupported parameter type: {type(p.data)}")
427
+
428
+ if self.debug:
429
+ print(
430
+ f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
431
+ flush=True,
432
+ )
433
+
434
+ if len(param_dtensors) > 0:
435
  self.parallel(
436
+ param_dtensors,
437
  group,
438
  lr=lr,
439
+ weight_decay=weight_decay,
440
  momentum=momentum,
441
  )
442
+
443
+ if len(param_tensors) > 0:
444
  self.base(
445
+ param_tensors,
446
  group,
447
  lr=lr,
448
+ weight_decay=weight_decay,
449
  momentum=momentum,
450
  )
451
 
 
457
  lr = group["lr"]
458
  beta1, beta2 = group["adamw_betas"]
459
  eps = group["adamw_eps"]
460
+ weight_decay = group["weight_decay"]
461
 
462
  for p in params:
463
  g = p.grad
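The `step()` changes in `torch-ext/optimizer/muon.py` (the source the per-build copies above are generated from) replace the old single `isinstance(params[0].data, DTensor)` check with a per-parameter split: sharded DTensors take the parallel path, fully replicated DTensors and plain tensors take the base path, and anything else raises a TypeError. Below is a self-contained sketch of that classification; the helper name `split_params` is invented for illustration, and no real DTensor is constructed because that would require an initialized process group and DeviceMesh.

import torch
from torch.distributed._tensor import DTensor, Replicate

def split_params(params):
    # Mirrors the dispatch added to Muon.step(); not the actual method.
    dtensors, tensors = [], []
    for p in params:
        if p is None or p.grad is None:
            continue
        if isinstance(p.data, DTensor):
            if all(isinstance(pl, Replicate) for pl in p.placements):
                tensors.append(p)   # fully replicated -> treated like a local tensor
            else:
                dtensors.append(p)  # sharded -> parallel path
        elif isinstance(p.data, torch.Tensor):
            tensors.append(p)
        else:
            raise TypeError(f"Unsupported parameter type: {type(p.data)}")
    return dtensors, tensors

# Plain-tensor demo only (DTensor construction is out of scope here).
w = torch.nn.Parameter(torch.randn(8, 8))
w.grad = torch.zeros_like(w)
dt, t = split_params([w])
print(len(dt), len(t))  # 0 1 -> handled by the base (non-parallel) path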