dongseokmotif commited on
Commit
671b033
·
verified ·
1 Parent(s): 8447fd1

fix(optimizer): resolve bug where weight decay was multiplied by wrong lr value (#5)

Browse files

- fix(optimizer): resolve bug where weight decay was multiplied by wrong lr value (33568272536def07cf6095dadff78ca2d4b182b5)

Files changed (1) hide show
  1. torch-ext/optimizer/muon.py +3 -3
torch-ext/optimizer/muon.py CHANGED
@@ -104,7 +104,7 @@ def _compute_u(state, steps, rank, compute_stream):
104
 
105
 
106
  @torch.no_grad()
107
- def _scatter(p, state, lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
109
 
110
  with torch.cuda.stream(comm_stream):
@@ -133,7 +133,7 @@ def _scatter(p, state, lr, weight_decay, rank, comm_stream):
133
  device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
- p.data.add_(u, alpha=-lr)
137
 
138
 
139
  def default_is_muon(x, name):
@@ -387,7 +387,7 @@ class Muon(torch.optim.Optimizer):
387
  state = param_to_state[id(p)]
388
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
389
  _scatter(
390
- p, state, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
  chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)
 
104
 
105
 
106
  @torch.no_grad()
107
+ def _scatter(p, state, lr, adjusted_lr, weight_decay, rank, comm_stream):
108
  u = state.computed_u
109
 
110
  with torch.cuda.stream(comm_stream):
 
133
  device_mesh=p.device_mesh,
134
  )
135
  p.data.mul_(1 - lr * weight_decay)
136
+ p.data.add_(u, alpha=-adjusted_lr)
137
 
138
 
139
  def default_is_muon(x, name):
 
387
  state = param_to_state[id(p)]
388
  adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
389
  _scatter(
390
+ p, state, lr, adjusted_lr, weight_decay, self.rank, self.comm_stream
391
  )
392
 
393
  chunk_size = dist.get_world_size(param_to_state[id(params[0])].process_group)