[trainer] fix: reproducible problem when resume training #4156
Changes from 5 commits: d8721e9, 555e36f, f8a1bdc, ccff31a, a1658ed, 59dd68e
```diff
@@ -975,6 +975,11 @@ def fit(self):
         # load checkpoint before doing anything
         self._load_checkpoint()
+        # resume sampler state if needed
+        current_epoch = self.global_steps // len(self.train_dataloader)
+        for _ in range(current_epoch - 1):
+            for _ in iter(self.train_dataloader.sampler):
+                pass

         # perform validation before training
         # currently, we only support validation using the reward_function.

@@ -1006,7 +1011,7 @@ def fit(self):
         )
         next_step_profile = False

-        for epoch in range(self.config.trainer.total_epochs):
+        for epoch in range(current_epoch, self.config.trainer.total_epochs):
```
Collaborator:

From https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler:

> In distributed mode, calling the `set_epoch()` method at the beginning of each epoch before creating the `DataLoader` iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.

I think we are missing a call to `set_epoch()` here.
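For reference, the docs' pattern looks like this (a minimal sketch with placeholder dataset and epoch bounds; `DistributedSampler` also assumes an initialized process group, omitted here):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))
sampler = DistributedSampler(dataset)  # assumes torch.distributed is initialized
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

start_epoch, total_epochs = 2, 5  # hypothetical resume point
for epoch in range(start_epoch, total_epochs):
    # set_epoch() reseeds the shuffle deterministically per epoch, so a
    # resumed run reproduces the original ordering in O(1), with no need
    # to replay earlier epochs through the sampler
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass
```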
Contributor (Author):

It works only when we use `DistributedSampler`.
Collaborator:

No, it works for both.
Contributor (Author):

You could see the code of `DistributedSampler.set_epoch`. But in verl code, `create_rl_sampler` builds the `RandomSampler` with a seeded `generator`, so `set_epoch` does not apply.
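A sketch of what the author is describing, assuming the `create_rl_sampler` construction above (all other names are illustrative):

```python
import torch
from torch.utils.data import RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(100))

# verl-style construction as described: a RandomSampler driven by a
# seeded generator (config plumbing omitted)
gen = torch.Generator()
gen.manual_seed(1)
sampler = RandomSampler(data_source=dataset, generator=gen)

# RandomSampler has no set_epoch(); each full pass draws a fresh
# permutation from the generator's advancing state, so skipping epochs
# means consuming whole passes of indices
skipped_epochs = 2  # hypothetical resume point
for _ in range(skipped_epochs):
    for _ in iter(sampler):
        pass
```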
```diff
             for batch_dict in self.train_dataloader:
                 metrics = {}
                 timing_raw = {}
```
Reviewer (gemini):

This approach to restoring the sampler state is very inefficient and seems to have a correctness issue.

Correctness: the loop `for _ in range(current_epoch - 1):` appears to have an off-by-one error. If training resumes at a `global_step` corresponding to the start of `current_epoch = 1`, this loop (`range(0)`) will not execute. The main training loop will then start at epoch 1, but the sampler will still be in its state for epoch 0. This would lead to incorrect data sampling and break the reproducibility you're trying to fix. To correctly advance the sampler to be ready for `current_epoch`, the loop should run `current_epoch` times.

Inefficiency: the nested loop iterates through all the samples in the dataset for each epoch you want to skip. For large datasets, this can add a significant delay to your training startup time.

A more robust and efficient solution would be to use the `set_epoch()` method if the sampler supports it (which is standard for `torch.utils.data.distributed.DistributedSampler`). This avoids iterating through the dataset entirely. I suggest replacing this block with a more efficient and correct implementation that handles both points.
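A hedged sketch of what the suggested replacement could look like (not the PR's final code; it assumes `self.global_steps` and `self.train_dataloader` as in the diff above):

```python
# advance the sampler to current_epoch on resume
current_epoch = self.global_steps // len(self.train_dataloader)
sampler = self.train_dataloader.sampler
if hasattr(sampler, "set_epoch"):
    # DistributedSampler path: O(1) deterministic reseed
    sampler.set_epoch(current_epoch)
else:
    # generator-driven RandomSampler path: consume the skipped epochs'
    # indices so the generator state matches current_epoch
    for _ in range(current_epoch):
        for _ in iter(sampler):
            pass
```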
Please address the gemini review; `set_epoch` is more efficient.
The sampler appears to be one epoch ahead after resuming from `_load_checkpoint` (e.g., it resumes at step 3 instead of step 1). As a workaround, we only advance the sampler `current_epoch - 1` times.
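A worked example of the arithmetic behind the workaround (hypothetical numbers, not from the PR):

```python
steps_per_epoch = 10                              # len(self.train_dataloader)
global_steps = 25                                 # restored by _load_checkpoint
current_epoch = global_steps // steps_per_epoch   # -> 2

# If _load_checkpoint already leaves the sampler one epoch ahead (the
# observation above), advancing it current_epoch more times would
# overshoot by one, hence range(current_epoch - 1).
```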
For the correctness problem, see the comment below.
It's `torch.utils.data.RandomSampler`'s duty to skip the samples consumed in this epoch so far; there's no need to skip them manually:
https://github.com/meta-pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L73-L74
Sorry for missing this comment again (flaky network). We used `torch.utils.data.sampler.RandomSampler` rather than `torchdata.stateful_dataloader.sampler.RandomSampler`. I am trying the second one; it seems it will fit `StatefulDataLoader`.
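If that works out, the swap might look roughly like this (a sketch, assuming the import paths from the comments above and mirroring the seeded-generator construction from `create_rl_sampler`):

```python
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import RandomSampler  # stateful variant

dataset = TensorDataset(torch.arange(100))

gen = torch.Generator()
gen.manual_seed(1)
sampler = RandomSampler(dataset, generator=gen)  # same signature as the torch one

loader = StatefulDataLoader(dataset, sampler=sampler, batch_size=8)
# loader.state_dict() / load_state_dict() now round-trip the sampler's
# in-epoch position, so resume no longer needs manual fast-forwarding
```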