Initial implementation of power scheduler #974
Conversation
yeah it's probably time to make a more configurable schedule thing...
dlwh left a comment
seems good modulo some nits (I didn't check the math)
I don't love that you'd have to specify batch size twice (and it doesn't take into account batch schedule), but I really don't want to thread it through.
Do you want to give this a shot using the Marin "submodule" stuff and see how it does?
Mhmm. I think I can rewrite so that it takes into account BatchSchedule, but it's probably better to try and decouple the optimizer and scheduler first (which I can take on if you can tell me the direction on how to do it).
Sure
ok, you asked for it! So the scheduler refactor is going to be a bit delicate, since we need to support old-style configs (I don't want to break too much compatibility, and we probably want to keep the simple case simple). I'm going to describe it in its full glory, but I think we can split it into chunks, and you only really need to do the first one.

In an ideal world (based on our experience with Marin 8B), I think the schedule system would look something like:

```python
lr_schedule: str | LrSchedule | list[ScheduleStep[LrSchedule]]
```

If `lr_schedule` is a string, use the old behavior. Otherwise, use the new behavior.

**Chunk 1: LrSchedule**

This will enable `lr_schedule: str | LrSchedule`. `LrSchedule` should probably look a bit like how we do the `OptimConfig`, which is to say a draccus plugin registry. We should make an `LrSchedule` type that is a `PluginRegistry` and add the kinds we already have. `LrSchedule` should probably have all the settings needed to specify it, except for the schedule length, which should be figured out from context. A normal cosine schedule would be specified as either:

```yaml
optimizer:
  lr_schedule: "cosine"
  warmup: ...
  learning_rate: ...
```

or, for more fine-grained control:

```yaml
optimizer:
  lr_schedule:
    type: cosine
    learning_rate: <peak lr>
    exponent: ...
```

**Chunk 2: Staged Schedules**

We should support staged schedules too:

```yaml
optimizer:
  lr_schedule:
    # warmup
    - start: 0
      value:
        type: linear
        initial_lr: 0.0
        final_lr: 1e-4
    # stable
    - start: 1000
      value:
        type: constant
        learning_rate: 1e-4
    # decay
    - start: 100000
      value:
        type: inv
```

This is obviously a pain to specify for the common case, but sometimes we need the control.

This is a big lift, so lmk if this is too daunting! Like I said, we only really need the first chunk right now.
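To make the staged form concrete, here is a rough sketch of how a list of stages could be resolved into a per-step learning rate. The resolution logic and the `ScheduleStep` fields here are my illustration of the proposal, not the actual implementation; each stage's schedule is modeled simply as a callable from its local step to an LR.

```python
from dataclasses import dataclass
from typing import Callable

# Each stage starts at a global step and supplies its own schedule,
# modeled here as step -> lr, with step counted from the stage start.
@dataclass
class ScheduleStep:
    start: int
    value: Callable[[int], float]

def staged_lr(stages: list[ScheduleStep], step: int) -> float:
    # Pick the last stage whose start is <= step (stages sorted by start),
    # and evaluate its schedule relative to that stage's own start.
    active = stages[0]
    for stage in stages:
        if stage.start <= step:
            active = stage
    return active.value(step - active.start)

# Mirrors the YAML above: linear warmup, then constant, then an inverse decay.
stages = [
    ScheduleStep(start=0, value=lambda s: 1e-4 * s / 1000),              # warmup
    ScheduleStep(start=1000, value=lambda s: 1e-4),                      # stable
    ScheduleStep(start=100_000, value=lambda s: 1e-4 / (1 + s / 1000)),  # decay
]
```

Evaluating `staged_lr(stages, step)` at increasing steps walks through warmup, the stable plateau, and the decay tail.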
I implemented chunk 1. I passed a ctx: LrScheduleContext into LrSchedule so that the user can't re-declare learning_rate and min_lr_ratio, which could lead to hidden bugs. I will probably revisit and implement chunk 2 some other time.
Implementation based on this paper: https://arxiv.org/abs/2408.13359. @dlwh @Helw150 I'm not sure if this is the best way to go about implementing this as is. I had to add five more parameters to the parameter list of the OptimizerConfig class to implement Power Scheduler, as, based on my understanding, there is a tight coupling between the optimizer and the scheduler. Should there be an effort to decouple the optimizer and the scheduler, so that we don't have to pass schedule-specific parameters directly into OptimizerConfig?
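For readers unfamiliar with the paper, here is a rough sketch of the power schedule's overall shape as I understand it: warmup, then a learning rate that follows a power of the training progress capped at a maximum, then a final decay. The constants `a`, `b`, and `decay_frac` below are placeholders, and the exact parameterization should be checked against the paper; this is not the code in this PR.

```python
def power_lr(step: int, *, total_steps: int, warmup: int = 100,
             lr_max: float = 1e-3, a: float = 4.6, b: float = 0.51,
             decay_frac: float = 0.1) -> float:
    """Approximate power schedule: linear warmup, then lr ~ a * step**-b
    capped at lr_max, then a final ramp-down over the last decay_frac of
    training (all constants here are illustrative placeholders)."""
    decay_start = int(total_steps * (1 - decay_frac))
    if step < warmup:
        return lr_max * (step + 1) / warmup
    lr = min(lr_max, a * float(step) ** (-b))
    if step > decay_start:
        # simple linear ramp-down standing in for the paper's final decay
        lr *= (total_steps - step) / (total_steps - decay_start)
    return max(lr, 0.0)
```

A schedule like this needs `total_steps`, `a`, `b`, `decay_frac`, and `lr_max`, which illustrates why implementing it forced extra schedule-specific parameters onto OptimizerConfig.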