fix(optimizer): resolve bug where weight decay was multiplied by wrong lr value (#5)
Browse files- fix(optimizer): resolve bug where weight decay was multiplied by wrong lr value (33568272536def07cf6095dadff78ca2d4b182b5)
torch-ext/optimizer/muon.py
CHANGED
@@ -104,7 +104,7 @@ def _compute_u(state, steps, rank, compute_stream):
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
-
def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
108 |
u = state.computed_u
|
109 |
|
110 |
with torch.cuda.stream(comm_stream):
|
@@ -133,7 +133,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
|
|
133 |
device_mesh=p.device_mesh,
|
134 |
)
|
135 |
p.data.mul_(1 - lr * weight_decay)
|
136 |
-
p.data.add_(u, alpha=-
|
137 |
|
138 |
|
139 |
def default_is_muon(x, name):
|
@@ -387,7 +387,7 @@ class Muon(torch.optim.Optimizer):
|
|
387 |
state = param_to_state[id(p)]
|
388 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
389 |
_scatter(
|
390 |
-
p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
391 |
)
|
392 |
|
393 |
chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
|
|
|
104 |
|
105 |
|
106 |
@torch.no_grad()
|
107 |
+
def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
|
108 |
u = state.computed_u
|
109 |
|
110 |
with torch.cuda.stream(comm_stream):
|
|
|
133 |
device_mesh=p.device_mesh,
|
134 |
)
|
135 |
p.data.mul_(1 - lr * weight_decay)
|
136 |
+
p.data.add_(u, alpha=-adjusted_lr)
|
137 |
|
138 |
|
139 |
def default_is_muon(x, name):
|
|
|
387 |
state = param_to_state[id(p)]
|
388 |
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
|
389 |
_scatter(
|
390 |
+
p, state, lr, adjusted_lr, weight_decay, self.rank, self.comm_stream
|
391 |
)
|
392 |
|
393 |
chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
|