Commit 02ac540
1 Parent(s): 64757cb

refactor(muon): change argument adam_wd to weight_decay and handle params' type
Files changed:

- build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py +52 -21
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py +52 -21
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py +52 -21
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py +52 -21
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py +52 -21
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-312.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-312.pyc +0 -0
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so} +1 -1
- build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py +52 -21
- torch-ext/optimizer/muon.py +52 -21
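For downstream users, the visible API change is in the Muon constructor: the weight-decay argument adam_wd becomes weight_decay (shared with the internal AdamW path), and is_muon_func gains default_is_muon as a default. A minimal usage sketch; the toy model and the optimizer.muon import path are assumptions based on the build layout above, not part of this commit:

import torch.nn as nn
from optimizer.muon import Muon  # import path assumed from the build tree

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# Before this commit: Muon(model, my_is_muon_func, adam_wd=0.1, ...)
# After it, is_muon_func can be omitted (defaults to default_is_muon)
# and the decay is a plain weight_decay keyword:
opt = Muon(model, lr=1e-3, momentum=0.95, weight_decay=0.1)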
build/torch26-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_036642a_dirty
-ops = torch.ops._optimizer_036642a_dirty
+from . import _optimizer_64757cb_dirty
+ops = torch.ops._optimizer_64757cb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_036642a_dirty::{op_name}"
+    return f"_optimizer_64757cb_dirty::{op_name}"
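The helper pins op lookups to the per-build extension namespace; a quick sketch of how it resolves (the op name "newton_schulz" is illustrative, not taken from this commit):

from optimizer._ops import add_op_namespace_prefix  # import path assumed

qualified = add_op_namespace_prefix("newton_schulz")
# qualified == "_optimizer_64757cb_dirty::newton_schulz"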
build/torch26-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED: LFS pointer updated to oid sha256:f1f5df341112d93e43c0801e285abd66e79bfbe399d228f8be09ff26ece7421b (size 1787272, unchanged).
build/torch26-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import DTensor, Replicate
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
 
 
 @torch.no_grad()
-def _scatter(p, state, lr, wd, rank, comm_stream):
+def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
         placements=p.placements,
         device_mesh=mesh,
     )
-    p.data.mul_(1 - lr * wd)
+    p.data.mul_(1 - lr * weight_decay)
     p.data.add_(u, alpha=-lr)
 
 
+def default_is_muon(x, name):
+    return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
+
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
         adamw_lr: The learning rate for the internal AdamW.
         adamw_betas: The betas for the internal AdamW.
         adamw_eps: The epsilon for the internal AdamW.
-        adam_wd: The weight decay for the internal AdamW.
+        adamw_weight_decay: The weight decay for the internal AdamW.
     """
 
     def __init__(
         self,
         model,
-        is_muon_func,
+        is_muon_func=default_is_muon,
         lr=1e-3,
         momentum=0.95,
         nesterov=True,
         ns_steps=5,
-        adam_wd=0.1,
+        weight_decay=0.1,
         adamw_betas=(0.9, 0.95),
         adamw_eps=1e-8,
         none_grad=True,
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
     ):
         defaults = dict(
             lr=lr,
-            adam_wd=adam_wd,
+            weight_decay=weight_decay,
             momentum=momentum,
             nesterov=nesterov,
             ns_steps=ns_steps,
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
 
         return param_to_state, ordered_params
 
-    def base(self, params, group, lr, wd, momentum):
+    def base(self, params, group, lr, weight_decay, momentum):
         # generate weight updates in distributed fashion
         for p in params:
             g = p.grad
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
             adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
             # apply weight decay
-            p.data.mul_(1 - lr * wd)
+            p.data.mul_(1 - lr * weight_decay)
 
             # apply update
             p.data.add_(u, alpha=-adjusted_lr)
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
             g = buf
         return g
 
-    def _update_p(self, p, u, lr, wd):
+    def _update_p(self, p, u, lr, weight_decay):
         # scale update
         adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
         # apply weight decay
-        p.data.mul_(1 - lr * wd)
+        p.data.mul_(1 - lr * weight_decay)
         # apply update
         p.data.add_(u, alpha=-adjusted_lr)
 
-    def parallel(self, params, group, lr, wd, momentum):
+    def parallel(self, params, group, lr, weight_decay, momentum):
         """
         Perform a parallel optimization step using Muon.
         """
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
             for p in ordered_params[start_idx : start_idx + chunk_size]:
                 state = param_to_state[id(p)]
                 adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-                _scatter(p, state, adjusted_lr, wd, self.rank, self.comm_stream)
+                _scatter(
+                    p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
+                )
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
 
         params = [p for p in group["params"] if self.state[p]["use_muon"]]
         lr = group["lr"]
-        wd = group["adam_wd"]
+        weight_decay = group["weight_decay"]
         momentum = group["momentum"]
 
-        if isinstance(params[0].data, DTensor):
+        param_dtensors = []
+        param_tensors = []
+
+        for p in params:
+            if p is None or p.grad is None:
+                continue
+            if isinstance(p.data, DTensor):
+                if all(
+                    isinstance(placement, Replicate) for placement in p.placements
+                ):
+                    param_tensors.append(p)
+                else:
+                    param_dtensors.append(p)
+            elif isinstance(p.data, torch.Tensor):
+                param_tensors.append(p)
+            else:
+                raise TypeError(f"Unsupported parameter type: {type(p.data)}")
+
+        if self.debug:
+            print(
+                f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
+                flush=True,
+            )
+
+        if len(param_dtensors) > 0:
             self.parallel(
-                params,
+                param_dtensors,
                 group,
                 lr=lr,
-                wd=wd,
+                weight_decay=weight_decay,
                 momentum=momentum,
             )
-        else:
+
+        if len(param_tensors) > 0:
             self.base(
-                params,
+                param_tensors,
                 group,
                 lr=lr,
-                wd=wd,
+                weight_decay=weight_decay,
                 momentum=momentum,
             )
 
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
         lr = group["lr"]
         beta1, beta2 = group["adamw_betas"]
         eps = group["adamw_eps"]
-        weight_decay = group["adam_wd"]
+        weight_decay = group["weight_decay"]
 
         for p in params:
             g = p.grad
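Beyond the rename, the substantive change is in step(): parameters are bucketed one by one rather than dispatching on params[0] alone, so groups mixing DTensors and plain tensors now work, and fully replicated DTensors take the cheaper non-parallel path. A standalone sketch of that classification rule, lifted from the diff above (not a public API of the module):

import torch
from torch.distributed._tensor import DTensor, Replicate


def partition_params(params):
    # Mirrors Muon.step(): sharded DTensors go to the parallel path;
    # fully replicated DTensors and plain tensors go to the base path.
    dtensors, tensors = [], []
    for p in params:
        if p is None or p.grad is None:
            continue
        if isinstance(p.data, DTensor):
            if all(isinstance(pl, Replicate) for pl in p.placements):
                tensors.append(p)
            else:
                dtensors.append(p)
        elif isinstance(p.data, torch.Tensor):
            tensors.append(p)
        else:
            raise TypeError(f"Unsupported parameter type: {type(p.data)}")
    return dtensors, tensors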
build/torch26-cxx11-cu124-x86_64-linux/optimizer/_ops.py
CHANGED: same +3/-3 namespace rename (_optimizer_036642a_dirty → _optimizer_64757cb_dirty) as the torch26-cxx11-cu118 copy of _ops.py above.
build/torch26-cxx11-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED: LFS pointer updated to oid sha256:2921aa2aa2587e261dc9ca4e5f60303b0d1c9a305d1584918a8c56b6dc79ebfb (size 1824224, unchanged).
build/torch26-cxx11-cu124-x86_64-linux/optimizer/muon.py
CHANGED: identical +52/-21 diff as the torch26-cxx11-cu118 copy of muon.py above.
build/torch26-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED: same +3/-3 namespace rename (_optimizer_036642a_dirty → _optimizer_64757cb_dirty) as the torch26-cxx11-cu118 copy of _ops.py above.
build/torch26-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED: LFS pointer updated to oid sha256:a93530e6981fdac23236dd7e3657c5b47513cda4accec78293234ce5f233400b (size 1824224, unchanged).
build/torch26-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED: identical +52/-21 diff as the torch26-cxx11-cu118 copy of muon.py above.
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/_ops.py
CHANGED: same +3/-3 namespace rename (_optimizer_036642a_dirty → _optimizer_64757cb_dirty) as the torch26-cxx11-cu118 copy of _ops.py above.
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED: LFS pointer updated to oid sha256:caa40905ac8f209fecccae42c6892c3766ad5c7069382e60d2339e73da6ee7d6 (size 1749744, unchanged).
build/torch26-cxx11-rocm62-x86_64-linux/optimizer/muon.py
CHANGED: identical +52/-21 diff as the torch26-cxx11-cu118 copy of muon.py above.
build/torch26-cxx98-cu118-x86_64-linux/optimizer/_ops.py
CHANGED: same +3/-3 namespace rename (_optimizer_036642a_dirty → _optimizer_64757cb_dirty) as the torch26-cxx11-cu118 copy of _ops.py above.
build/torch26-cxx98-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED: LFS pointer updated to oid sha256:6919551ed599e7e0dc1a750d1972bdb31605f57583b3617054cb70dd40d54d26 (size 1787192, unchanged).
build/torch26-cxx98-cu118-x86_64-linux/optimizer/muon.py
CHANGED: identical +52/-21 diff as the torch26-cxx11-cu118 copy of muon.py above.
build/torch26-cxx98-cu124-x86_64-linux/optimizer/_ops.py
CHANGED: same +3/-3 namespace rename (_optimizer_036642a_dirty → _optimizer_64757cb_dirty) as the torch26-cxx11-cu118 copy of _ops.py above.
build/torch26-cxx98-cu124-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED: LFS pointer updated to oid sha256:f07cc2637669130fc9e209cb2c4358caba1c4c2d5837a108043b073d7897c3a7 (size 1824184, unchanged).
build/torch26-cxx98-cu124-x86_64-linux/optimizer/muon.py
CHANGED: identical +52/-21 diff as the torch26-cxx11-cu118 copy of muon.py above.
build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _optimizer_036642a_dirty
-ops = torch.ops._optimizer_036642a_dirty
+from . import _optimizer_64757cb_dirty
+ops = torch.ops._optimizer_64757cb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_optimizer_036642a_dirty::{op_name}"
+    return f"_optimizer_64757cb_dirty::{op_name}"
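The only change in every regenerated `_ops.py` is the build hash baked into the extension module name. A minimal sketch of how that name is consumed, assuming the wheel is importable as `optimizer` (the op name `foo` is hypothetical, not from this repo):

```python
# Minimal sketch, not part of the commit. Assumes the built wheel is
# importable as `optimizer`; the op name "foo" is hypothetical.
from optimizer._ops import ops, add_op_namespace_prefix

# add_op_namespace_prefix() qualifies an op name with the namespace that
# the renamed shared object registers its kernels under.
assert add_op_namespace_prefix("foo") == "_optimizer_64757cb_dirty::foo"

# `ops` is the torch.ops namespace object for that module; once the .so is
# loaded, registered kernels resolve as attributes, e.g. ops.foo (if defined).
```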
build/torch26-cxx98-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
+oid sha256:8b9ef8fa2dd4d80cb3c1c3c2a72b99e0d76b3e676acd551f3a9ff4cdd21773eb
 size 1824184
build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass
 
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor import DTensor
+from torch.distributed._tensor import DTensor, Replicate
 
 
 # This code snippet is a modified version adapted from the following GitHub repositories:
@@ -103,7 +103,7 @@ def _compute_u(state, steps, rank, compute_stream):
 
 
 @torch.no_grad()
-def _scatter(p, state, lr, wd, rank, comm_stream):
+def _scatter(p, state, lr, weight_decay, rank, comm_stream):
     u = state.computed_u
     mesh = p.device_mesh
 
@@ -131,10 +131,14 @@ def _scatter(p, state, lr, wd, rank, comm_stream):
         placements=p.placements,
         device_mesh=mesh,
     )
-    p.data.mul_(1 - lr * wd)
+    p.data.mul_(1 - lr * weight_decay)
     p.data.add_(u, alpha=-lr)
 
 
+def default_is_muon(x, name):
+    return x.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
+
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
@@ -159,18 +163,18 @@ class Muon(torch.optim.Optimizer):
         adamw_lr: The learning rate for the internal AdamW.
         adamw_betas: The betas for the internal AdamW.
         adamw_eps: The epsilon for the internal AdamW.
-        …
+        adamw_weight_decay: The weight decay for the internal AdamW.
     """
 
     def __init__(
         self,
         model,
-        is_muon_func,
+        is_muon_func=default_is_muon,
         lr=1e-3,
         momentum=0.95,
         nesterov=True,
         ns_steps=5,
-        …
+        weight_decay=0.1,
         adamw_betas=(0.9, 0.95),
         adamw_eps=1e-8,
         none_grad=True,
@@ -178,7 +182,7 @@ class Muon(torch.optim.Optimizer):
     ):
         defaults = dict(
             lr=lr,
-            …
+            weight_decay=weight_decay,
             momentum=momentum,
             nesterov=nesterov,
             ns_steps=ns_steps,
@@ -272,7 +276,7 @@ class Muon(torch.optim.Optimizer):
 
         return param_to_state, ordered_params
 
-    def base(self, params, group, lr, …
+    def base(self, params, group, lr, weight_decay, momentum):
         # generate weight updates in distributed fashion
         for p in params:
             g = p.grad
@@ -299,7 +303,7 @@ class Muon(torch.optim.Optimizer):
             adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
 
             # apply weight decay
-            p.data.mul_(1 - lr * …
+            p.data.mul_(1 - lr * weight_decay)
 
             # apply update
             p.data.add_(u, alpha=-adjusted_lr)
@@ -317,15 +321,15 @@ class Muon(torch.optim.Optimizer):
             g = buf
         return g
 
-    def _update_p(self, p, u, lr, …
+    def _update_p(self, p, u, lr, weight_decay):
         # scale update
         adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
         # apply weight decay
-        p.data.mul_(1 - lr * …
+        p.data.mul_(1 - lr * weight_decay)
         # apply update
         p.data.add_(u, alpha=-adjusted_lr)
 
-    def parallel(self, params, group, lr, …
+    def parallel(self, params, group, lr, weight_decay, momentum):
         """
         Perform a parallel optimization step using Muon.
         """
@@ -364,7 +368,9 @@ class Muon(torch.optim.Optimizer):
         for p in ordered_params[start_idx : start_idx + chunk_size]:
             state = param_to_state[id(p)]
             adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
-            _scatter(…
+            _scatter(
+                p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
+            )
 
         chunk_size = params[0].device_mesh.mesh.numel()
 
@@ -398,23 +404,48 @@ class Muon(torch.optim.Optimizer):
 
         params = [p for p in group["params"] if self.state[p]["use_muon"]]
         lr = group["lr"]
-        …
+        weight_decay = group["weight_decay"]
         momentum = group["momentum"]
 
-        …
+        param_dtensors = []
+        param_tensors = []
+
+        for p in params:
+            if p is None or p.grad is None:
+                continue
+            if isinstance(p.data, DTensor):
+                if all(
+                    isinstance(placement, Replicate) for placement in p.placements
+                ):
+                    param_tensors.append(p)
+                else:
+                    param_dtensors.append(p)
+            elif isinstance(p.data, torch.Tensor):
+                param_tensors.append(p)
+            else:
+                raise TypeError(f"Unsupported parameter type: {type(p.data)}")
+
+        if self.debug:
+            print(
+                f"[Muon] {len(param_dtensors)} DTensors, {len(param_tensors)} Tensors",
+                flush=True,
+            )
+
+        if len(param_dtensors) > 0:
             self.parallel(
-            …
+                param_dtensors,
                 group,
                 lr=lr,
-            …
+                weight_decay=weight_decay,
                 momentum=momentum,
             )
-        …
+
+        if len(param_tensors) > 0:
             self.base(
-            …
+                param_tensors,
                 group,
                 lr=lr,
-            …
+                weight_decay=weight_decay,
                 momentum=momentum,
             )
 
@@ -426,7 +457,7 @@ class Muon(torch.optim.Optimizer):
         lr = group["lr"]
         beta1, beta2 = group["adamw_betas"]
         eps = group["adamw_eps"]
-        weight_decay = group["…
+        weight_decay = group["weight_decay"]
 
         for p in params:
             g = p.grad
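For callers, the visible effect of this diff is the renamed `weight_decay` argument and the new `default_is_muon` fallback, which routes every parameter with `ndim >= 2` (except embedding and `lm_head` weights) through Muon without an explicit predicate. A minimal single-process usage sketch, assuming the package exposes `Muon` from `optimizer.muon` (the import path and toy model are assumptions; only the constructor signature comes from the diff):

```python
# Minimal usage sketch, not part of the commit. The import path and the toy
# model are assumptions; only the Muon(...) signature is taken from the diff.
import torch
import torch.nn as nn

from optimizer.muon import Muon  # assumed import path

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))

# is_muon_func now defaults to default_is_muon, and the decoupled weight
# decay is passed as `weight_decay`: each step scales p by
# 1 - lr * weight_decay before the orthogonalized update is applied.
opt = Muon(model, lr=1e-3, momentum=0.95, weight_decay=0.1)

for _ in range(3):
    loss = model(torch.randn(8, 64)).sum()
    loss.backward()
    opt.step()  # none_grad=True suggests the optimizer clears grads itself
```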
build/torch27-cxx11-cu118-x86_64-linux/optimizer/_ops.py
CHANGED
(identical diff to build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-cu118-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
+oid sha256:8413f32011996384f13a985a99b4e2f863f8e4717acdb8439b63a10f77db6f15
 size 1787368
build/torch27-cxx11-cu118-x86_64-linux/optimizer/muon.py
CHANGED
(identical diff to build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py above)
build/torch27-cxx11-cu126-x86_64-linux/optimizer/_ops.py
CHANGED
(identical diff to build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-cu126-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
+oid sha256:c9d303b11a0a82e9c51c7b32c7555bd351ec375b1879bf46eb64ea4aff32100f
 size 1824256
build/torch27-cxx11-cu126-x86_64-linux/optimizer/muon.py
CHANGED
(identical diff to build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py above)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/_ops.py
CHANGED
(identical diff to build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-cu128-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
+oid sha256:b4334fe8d7157c2a9c85217cb981692daf9eb4c6d3f205d0fd41d4b717daefa1
 size 1883352
build/torch27-cxx11-cu128-x86_64-linux/optimizer/muon.py
CHANGED
(identical diff to build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py above)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (252 Bytes)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/__pycache__/muon.cpython-312.pyc
ADDED
Binary file (22 kB)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/_ops.py
CHANGED
(identical diff to build/torch26-cxx98-cu126-x86_64-linux/optimizer/_ops.py above)
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/{_optimizer_036642a_dirty.abi3.so → _optimizer_64757cb_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
+oid sha256:272fcc69e3774fa43e222efefceeaca97a8c84ee3f1fe528a7478a8e80a70976
 size 1749648
build/torch27-cxx11-rocm63-x86_64-linux/optimizer/muon.py
CHANGED
(identical diff to build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py above)
torch-ext/optimizer/muon.py
CHANGED
(identical diff to build/torch26-cxx98-cu126-x86_64-linux/optimizer/muon.py above; torch-ext is the source file from which the build variants are generated)
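The second half of the commit message, "handle params' type", is the new routing in `step()`: fully replicated DTensors and plain tensors take the single-device `base` path, sharded DTensors take the distributed `parallel` path, and anything else fails fast. Restated as a standalone sketch (logic transcribed from the diff above; the helper name `route_params` is ours, not the repo's):

```python
# Standalone restatement of the routing logic added in step(); transcribed
# from the diff, with a hypothetical helper name.
import torch
from torch.distributed._tensor import DTensor, Replicate

def route_params(params):
    param_dtensors, param_tensors = [], []
    for p in params:
        if p is None or p.grad is None:
            continue  # nothing to update this step
        if isinstance(p.data, DTensor):
            if all(isinstance(pl, Replicate) for pl in p.placements):
                # fully replicated DTensor: every rank holds the whole
                # tensor, so the single-device `base` path is safe
                param_tensors.append(p)
            else:
                # sharded DTensor: needs the gather/compute/scatter path
                param_dtensors.append(p)
        elif isinstance(p.data, torch.Tensor):
            param_tensors.append(p)
        else:
            raise TypeError(f"Unsupported parameter type: {type(p.data)}")
    return param_dtensors, param_tensors
```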