New WR (-15 steps/-0.7s): Implement NorMuon on latest version #144
Conversation
I think it'll also work on the medium track, perhaps even better than here.
Glad to see NorMuon added to the latest run! I'll validate and plan to merge next week. The readme will preserve the 0.7s improvement, controlling for the typical 1-2s hardware differences.
The smaller gain here compared to #141 may partially be due to the 33% smaller batch size of 262144 instead of 393216. Now that the optimizer has been updated, this could be a promising hyperparameter to revisit.
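For reference, the 33% figure follows directly from the two batch-size values quoted above (treated simply as counts per step; their units are not restated here):

```python
# Ratio of the two batch sizes discussed above.
old_bs, new_bs = 393216, 262144        # 3 * 2**17 vs 2**18
print(f"{1 - new_bs / old_bs:.1%}")    # -> 33.3% smaller
```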
Thank you! I'll try to further tune that.
Brainstorming some thoughts for future explorations/PRs: I looked more into the loss curves from the log files, using the 5 logs from the last record and the first 10 new logs. NorMuon follows a different trajectory on the loss curve, which suggests to me there may be an opportunity to revisit certain params. Before learning rate decay kicks in around step 1250, the loss is quite a bit lower (though I'm cautious about reading too much into this on its own, because a smaller lr can produce lower early loss that doesn't translate to better final loss). Things then look fairly similar around step 2250. NorMuon appears to be performing slightly better per step over the last 50 steps, considering it stops 15 steps early.

Muon has a linear momentum warmup from steps 0 to 300 and a cooldown over the last 50 steps, which coincides with where NorMuon drops loss faster, so it could be worth looking at the momentum value and the warmup/cooldown schedule (the cooldown did not exist in #141); see the sketch after this comment. The other params of interest here are the NorMuon lr=0.06 and cooldown_frac=0.45. AdamW and Muon have always used the same cooldown_frac param, but maybe these should be on different schedules.

Another thing I am curious about is whether the impact is driven by a subset of layers, a subset of module types (attn/mlp/gate), or a subset of steps. There could also be interactions between the momentum terms and the window-size updates from 3 to 7 and 7 to 11; these are sharper jumps than the gradual 1, 2, 3, ..., 13 increase in #141. We would need to print out validation loss right before and after these jumps to compare.
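A minimal sketch of the kind of momentum warmup/cooldown schedule described above, just to make the knobs concrete. Only the 300-step warmup and 50-step cooldown come from the comment; the 0.85 → 0.95 endpoints and the linear interpolation are assumptions for illustration:

```python
def muon_momentum(step: int, num_iterations: int,
                  lo: float = 0.85, hi: float = 0.95,
                  warmup: int = 300, cooldown: int = 50) -> float:
    """Linear warmup over the first `warmup` steps, constant plateau,
    then linear cooldown over the final `cooldown` steps (illustrative)."""
    if step < warmup:                          # ramp lo -> hi
        return lo + (step / warmup) * (hi - lo)
    steps_left = num_iterations - step
    if steps_left < cooldown:                  # ramp hi -> lo at the end
        return lo + (steps_left / cooldown) * (hi - lo)
    return hi                                  # plateau in between
```

Sweeping `lo`/`hi` and the two window lengths independently of the LR schedule would be one way to probe the interaction noted above.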
@ClassicLarry @zichongli5 do you see those hyperparam changes + the potential batch-size exploration as part of this PR? Otherwise I'd like to contribute a PR "on top of this one" that changes some of the LR logic.
Thanks for the detailed loss-curve read—agree with the takeaways. I’m running some quick tuning accordingly (momentum warmup/cooldown, decoupling cooldown_frac for NorMuon). If these land clear wins, I think we can fold them in as a separate PR. @varunneal A new PR sounds great to me! I can incorporate my tuning changes (if any) with your new PR later.
Validated and merged. I bumped up the prior record in the readme to match the times here, since it is better to keep the runtimes close to what people are testing on, and these times are in line with what I got from a rerun. The prior record was set on an atypically fast machine.
I am noticing that the norm step is getting applied to the attention parameters after they are reshaped back to size (768, 768*4), whereas it ideally should be applied at size (4, 768, 768); see the sketch after this comment. This gets corrected in #146 during the Muon updates, and may partially explain the additional improvement there. Perhaps this is another reason why the gain here is smaller than in #141.
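To make the shape issue concrete, here is a hedged sketch of the difference, using a row-wise RMS normalization as a stand-in for the actual norm step (the real statistic and the repo's exact weight layout may differ):

```python
import torch

def neuron_norm(u: torch.Tensor) -> torch.Tensor:
    # Stand-in for the norm step: normalize each row to unit RMS.
    return u / (u.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8)

update = torch.randn(768, 768 * 4)   # four 768x768 attention matrices, flattened

# What this PR effectively does: normalize the flattened (768, 3072) matrix,
# so each row mixes statistics across all four matrices.
flat_normed = neuron_norm(update)

# What #146 does instead (per this thread): operate on (4, 768, 768), so each
# matrix gets its own per-row statistics, then flatten back.
per_mat = update.view(768, 4, 768).transpose(0, 1)              # (4, 768, 768)
restored = neuron_norm(per_mat).transpose(0, 1).reshape(768, 768 * 4)
```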
Thanks for merging! And good catch on the norm shape—that's more reasonable and should work better. Additionally, during tuning I noticed that in this PR the LR ends up as "const → linear decay to 0.1x", as the extension iterations are only used in `get_lr`:

```python
# get_lr (current)
x = min(0.9999, step / args.num_iterations)
```

It probably should mirror:

```python
# get_lr (proposed)
x = min(step / (1 + args.num_scheduled_iterations), 0.9999)
```

So it seems like the LR doesn't get the final flat phase, which could also be a reason for the #146 improvements.
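For context on the "final flat phase": a sketch of how `x` might feed the decay multiplier under a `cooldown_frac`-style schedule. The 0.1x floor and cooldown_frac=0.45 come from this thread; the step count mapping and the exact function in the repo are assumptions:

```python
num_scheduled_iterations = 2315   # this PR's step count (exact arg mapping assumed)
cooldown_frac = 0.45              # fraction of the schedule spent decaying
min_lr_ratio = 0.1                # decay target: 0.1x of the base LR

def lr_multiplier(step: int) -> float:
    # Proposed form: x is capped, so extension steps beyond the scheduled
    # count keep returning the final value -> a flat phase at ~0.1x.
    x = min(step / (1 + num_scheduled_iterations), 0.9999)
    if x < 1 - cooldown_frac:                    # constant phase at 1.0
        return 1.0
    w = (1 - x) / cooldown_frac                  # linear decay 1.0 -> 0.1
    return w * 1.0 + (1 - w) * min_lr_ratio
```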
Nice find. This was introduced in a refactor after the prior PR, so this PR should be the only one affected; it will get corrected in #146.

Implement NorMuon Optimizer (Neuron-wise Normalized Muon) on latest branch.
This PR adds NorMuon, also implemented in #141, on top of the latest branch (up to #140). All hyperparameters stay the same; only the step count changes from 2330 to 2315.
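For readers who have not seen #141, a minimal sketch of the neuron-wise normalization idea layered on top of Muon's orthogonalized update. The EMA constant, epsilon, and the final rescaling are assumptions for illustration and may differ from the implementation in this PR:

```python
import torch

def normuon_step(ortho_update: torch.Tensor, second_moment: torch.Tensor,
                 beta2: float = 0.95, eps: float = 1e-8) -> torch.Tensor:
    """Neuron-wise normalization of a Muon update (illustrative sketch).

    ortho_update: (out_features, in_features) update from Muon's
        Newton-Schulz orthogonalization.
    second_moment: persistent per-neuron buffer of shape (out_features, 1),
        updated in place across steps.
    """
    # EMA of each output neuron's mean squared update.
    row_ms = ortho_update.pow(2).mean(dim=-1, keepdim=True)
    second_moment.mul_(beta2).add_(row_ms, alpha=1 - beta2)
    # Damp neurons whose updates are persistently large, boost small ones.
    normed = ortho_update / (second_moment.sqrt() + eps)
    # Rescale so the overall update norm stays comparable to Muon's.
    return normed * (ortho_update.norm() / (normed.norm() + eps))
```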
Results
We ran 20 runs each of the current Muon and of NorMuon on 8xH100.
Muon:
NorMuon:
We notice that the mean runtime of Muon is 141.45s, slightly higher than previously reported, which could be due to hardware differences. Under a fair comparison, NorMuon is 0.7s faster and takes 15 fewer steps.