[trainer] fix: reproducible problem when resume training #4156

wuxibin89 merged 6 commits into verl-project:main from
Conversation
Code Review
This pull request aims to fix a reproducibility issue when resuming training by correctly restoring the sampler's state. The changes correctly calculate the epoch to resume from and modify the training loop to start at that epoch. However, the implementation for restoring the sampler's state is highly inefficient and contains a correctness bug. My review includes a critical comment with a suggested code change that addresses the inefficiency by using `set_epoch` where available and fixes an off-by-one error in the loop logic that would cause incorrect data sampling upon resuming.
verl/trainer/ppo/ray_trainer.py
Outdated
```python
for _ in range(current_epoch - 1):
    for _ in iter(self.train_dataloader.sampler):
        pass
```
This approach to restoring the sampler state is very inefficient and seems to have a correctness issue.
- **Correctness:** The loop `for _ in range(current_epoch - 1):` appears to have an off-by-one error. If training resumes at a `global_step` corresponding to the start of `current_epoch=1`, this loop (`range(0)`) will not execute. The main training loop will then start at epoch 1, but the sampler will still be in its state for epoch 0. This would lead to incorrect data sampling and break the reproducibility you're trying to fix. To correctly advance the sampler to be ready for `current_epoch`, the loop should run `current_epoch` times.
- **Inefficiency:** The nested loop iterates through all the samples in the dataset for each epoch you want to skip. For large datasets, this can add a significant delay to your training startup time.
A more robust and efficient solution would be to use the set_epoch() method if the sampler supports it (which is standard for torch.utils.data.distributed.DistributedSampler). This avoids iterating through the dataset entirely.
I suggest replacing this block with a more efficient and correct implementation that handles both points.
```diff
-for _ in range(current_epoch - 1):
-    for _ in iter(self.train_dataloader.sampler):
-        pass
+if hasattr(self.train_dataloader.sampler, "set_epoch"):
+    # More efficient and standard way for DistributedSampler
+    self.train_dataloader.sampler.set_epoch(current_epoch)
+else:
+    # Fallback for samplers without set_epoch, with correction.
+    # This is inefficient and should be avoided if possible.
+    for _ in range(current_epoch):
+        for _ in self.train_dataloader.sampler:
+            pass
```
Please address the Gemini review, `set_epoch` is more efficient.
The sampler appears to be one epoch ahead after resuming from `_load_checkpoint` (e.g., it resumes at step 3 instead of step 1). As a workaround, we only advance the sampler `current_epoch - 1` times.
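For clarity, a minimal sketch of that workaround, written as a hypothetical helper rather than the actual `_load_checkpoint` code:

```python
def fast_forward_sampler(sampler, epochs_to_skip: int) -> None:
    # Drain the sampler once per already-completed epoch so its RNG state
    # matches where the previous run left off (O(dataset) per skipped epoch).
    for _ in range(epochs_to_skip):
        for _ in iter(sampler):
            pass

# Because the sampler is already one epoch ahead after load_state_dict,
# only `current_epoch - 1` epochs are skipped here, e.g.:
# fast_forward_sampler(self.train_dataloader.sampler, current_epoch - 1)
```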
> Please address the Gemini review, `set_epoch` is more efficient.
@wuxibin89 Sorry for missing that.

For the inefficiency problem: the only sampler that currently matches the condition `hasattr(self.train_dataloader.sampler, "set_epoch")` is `DistributedSampler`, and we do not use it in `create_rl_sampler`. It would only have an effect if we called `set_epoch()` between epochs, which we are not doing currently.

For the correctness problem, see the comment below.
It's `torch.utils.data.RandomSampler`'s duty to skip the samples already consumed in this epoch so far; there's no need to skip them manually.
https://github.com/meta-pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L73-L74
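For intuition, here is a toy, simplified illustration of the idea the linked torchdata sampler implements (not its real code): the sampler itself tracks how many indices it has yielded and skips them on resume, so the caller never has to fast-forward manually.

```python
import torch

class ResumableRandomSampler:
    """Toy sampler that remembers how many indices it has already yielded."""

    def __init__(self, n: int, seed: int = 42):
        self.n, self.seed, self.yielded = n, seed, 0

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed)
        order = torch.randperm(self.n, generator=g).tolist()
        for idx in order[self.yielded:]:  # skip what was consumed before the checkpoint
            self.yielded += 1
            yield idx
        self.yielded = 0  # full pass completed; start fresh next epoch

    def state_dict(self):
        return {"yielded": self.yielded}

    def load_state_dict(self, state):
        self.yielded = state["yielded"]

s = ResumableRandomSampler(10)
it = iter(s)
consumed = [next(it) for _ in range(3)]
ckpt = s.state_dict()           # {"yielded": 3}

s2 = ResumableRandomSampler(10)
s2.load_state_dict(ckpt)
resumed = list(s2)              # the remaining 7 indices, in the same order
```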
Sorry for missing this comment again (flaky network).

We used `torch.utils.data.sampler.RandomSampler` rather than `torchdata.stateful_dataloader.sampler.RandomSampler`. I am trying the latter; it seems to fit `StatefulDataLoader`.
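For reference, a rough sketch of what that switch could look like in `create_rl_sampler`; the real function's signature and config fields in verl may differ, so treat the names here as illustrative:

```python
import torch
from torch.utils.data import SequentialSampler
# Stateful drop-in for torch.utils.data.RandomSampler, so StatefulDataLoader
# can checkpoint and restore the sampler's progress within an epoch.
from torchdata.stateful_dataloader.sampler import RandomSampler


def create_rl_sampler(data_config, dataset):
    if data_config.get("shuffle", True):
        generator = torch.Generator()
        generator.manual_seed(data_config.get("seed", 1))
        return RandomSampler(dataset, generator=generator)
    return SequentialSampler(dataset)
```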
@eric-haibin-lin @vermouth1992 @PeterSH6 @tongyx361 Could anyone help me review this?
```diff
 next_step_profile = False

-for epoch in range(self.config.trainer.total_epochs):
+for epoch in range(current_epoch, self.config.trainer.total_epochs):
```
From https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler:

> In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used.

I think we are missing `set_epoch()` here.
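For reference, the pattern the docs describe looks like this (a single-process sketch; `num_replicas`/`rank` are passed explicitly so it runs without `init_process_group`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=42)
loader = DataLoader(dataset, sampler=sampler, batch_size=4)

for epoch in range(3):
    sampler.set_epoch(epoch)  # shuffle order is derived from seed + epoch
    for batch in loader:
        pass
```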
That only works when we use `DistributedSampler`. Should I change the sampler from `RandomSampler`/`SequentialSampler` to `DistributedSampler`?
No, it works for both `RandomSampler`/`SequentialSampler` and `DistributedSampler`, see the comment above.
https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
You can see in the code of `DistributedSampler` that it seeds its generator with `g.manual_seed(self.seed + self.epoch)`.
But `RandomSampler` is based only on the seed; it doesn't have a `self.epoch` property.
In verl's code, `create_rl_sampler` gets the generator (seed) from `data_config`.
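A tiny illustration of that difference, mirroring the seeding logic rather than calling the samplers directly:

```python
import torch

seed, epoch = 42, 3

# DistributedSampler reseeds every epoch with seed + epoch (via set_epoch),
# so the shuffle order changes across epochs but stays reproducible.
g_dist = torch.Generator()
g_dist.manual_seed(seed + epoch)
print(torch.randperm(10, generator=g_dist).tolist())

# RandomSampler only ever sees the base seed; with a fixed generator it has
# no notion of which epoch it is in, so the order is the same every epoch.
g_rand = torch.Generator()
g_rand.manual_seed(seed)
print(torch.randperm(10, generator=g_rand).tolist())
```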
@wlhgtc I think we should use `torchdata.stateful_dataloader.sampler.RandomSampler`:

```python
import os

import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

if os.environ.get("USE_TORCHDATA", "0") == "1":
    from torchdata.stateful_dataloader.sampler import RandomSampler

    print("Using torchdata.stateful_dataloader.sampler.RandomSampler")
else:
    from torch.utils.data import RandomSampler

    print("Using torch.utils.data.RandomSampler")

# Set seed for reproducibility
torch.manual_seed(42)

dataset = TensorDataset(torch.arange(10))
sampler = RandomSampler(dataset)
dataloader = StatefulDataLoader(dataset, num_workers=2, batch_size=4, drop_last=True, sampler=sampler)

global_step = 0
total_epoch = 3
for epoch in range(total_epoch):
    should_break = False
    for i, batch in enumerate(dataloader):
        global_step += 1
        print(f"{epoch=}, {i=}, {global_step=}: {batch}")
        # Save state_dict at step 0 of epoch 1
        if epoch == 1 and i == 0:
            state_dict = dataloader.state_dict()
            should_break = True
            break
    if should_break:
        break

start_epoch = global_step // len(dataloader)
print(f"recover from epoch {start_epoch}")

del sampler, dataloader

# Training run resumes with the previous checkpoint
sampler2 = RandomSampler(dataset)
dataloader2 = StatefulDataLoader(dataset, num_workers=2, batch_size=4, drop_last=True, sampler=sampler2)
# Resume state with DataLoader
dataloader2.load_state_dict(state_dict)

for epoch in range(start_epoch, total_epoch):
    for i, batch in enumerate(dataloader2):
        global_step += 1
        print(f"{epoch=}, {i=}, {global_step=}: {batch}")
```
Thanks for your advice, I've updated my code. Any other suggestions? @wuxibin89

@wlhgtc Could you please verify the change on your task?

### What does this PR do?

Fix problem in #3457

When resuming training, we didn't calculate the correct epoch number to restart from. The StatefulDataLoader only saves the training step within an epoch, not which epoch that step belongs to. This caused the sampler to always default to epoch 0 upon resuming, which breaks reproducibility.

This PR fixes this by making two changes:

1. Calculating and restoring the correct current epoch number.
2. Restoring the sampler's state correctly.

As mentioned by @wuxibin89, we should use `torchdata.stateful_dataloader.sampler` instead of `torch.utils.data.RandomSampler`.
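For reference, a minimal sketch of the resume-epoch calculation discussed in this PR; the names below are illustrative and the actual trainer code may differ:

```python
def compute_resume_epoch(global_steps: int, steps_per_epoch: int) -> int:
    # global_steps counts training steps across all epochs, so integer
    # division recovers how many full epochs were already consumed.
    return global_steps // steps_per_epoch

# e.g. resuming at global step 7 with 3 batches per epoch -> restart in epoch 2
assert compute_resume_epoch(7, 3) == 2

# The training loop then starts from that epoch instead of 0:
# for epoch in range(current_epoch, total_epochs): ...
```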

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ……
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

### API and Usage Example

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ……
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)