[trainer] fix: reproducible problem when resume training #4156

Merged
wuxibin89 merged 6 commits into verl-project:main from wlhgtc:main
Nov 17, 2025
Conversation

@wlhgtc
Contributor

@wlhgtc wlhgtc commented Nov 17, 2025

What does this PR do?

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Fix problem in #3457
When resuming training, we didn't calculate the correct epoch number to restart from. The StatefulDataLoader only saves the step within the current epoch, not which epoch that step belongs to.

This caused the sampler to always default to epoch 0 upon resuming, which breaks reproducibility.

This PR fixes this by making two changes:

  1. Calculating and restoring the correct current epoch number.
  2. Restoring the sampler's state correctly.

As mentioned by @wuxibin89, we should use torchdata.stateful_dataloader.sampler.RandomSampler instead of torch.utils.data.RandomSampler.
CleanShot 2025-11-17 at 18 54 56@2x
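
For context, a minimal sketch of the resume logic described above (names such as `self.global_steps`, `self.train_dataloader`, and `self.config.trainer.total_epochs` are assumed from the verl trainer; the exact offset depends on how global steps are counted):

```python
# Sketch only: StatefulDataLoader.load_state_dict restores the position *within*
# an epoch, but the epoch index itself has to be recomputed from the global step.
current_epoch = self.global_steps // len(self.train_dataloader)

# Resume the outer loop from that epoch instead of always restarting at epoch 0.
for epoch in range(current_epoch, self.config.trainer.total_epochs):
    for batch_dict in self.train_dataloader:
        ...  # one training step
```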

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ……
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request aims to fix a reproducibility issue when resuming training by correctly restoring the sampler's state. The changes correctly calculate the epoch to resume from and modify the training loop to start at that epoch. However, the implementation for restoring the sampler's state is highly inefficient and contains a correctness bug. My review includes a critical comment with a suggested code change to address both the inefficiency by using set_epoch where available, and to fix an off-by-one error in the loop logic that would cause incorrect data sampling upon resuming.

Comment on lines 980 to 982
for _ in range(current_epoch - 1):
    for _ in iter(self.train_dataloader.sampler):
        pass
Contributor
critical

This approach to restoring the sampler state is very inefficient and seems to have a correctness issue.

  1. Correctness: The loop for _ in range(current_epoch - 1): appears to have an off-by-one error. If training resumes at a global_step corresponding to the start of current_epoch=1, this loop (range(0)) will not execute. The main training loop will then start at epoch 1, but the sampler will still be in its state for epoch 0. This would lead to incorrect data sampling and break the reproducibility you're trying to fix. To correctly advance the sampler to be ready for current_epoch, the loop should run current_epoch times.

  2. Inefficiency: The nested loop iterates through all the samples in the dataset for each epoch you want to skip. For large datasets, this can add a significant delay to your training startup time.

A more robust and efficient solution would be to use the set_epoch() method if the sampler supports it (which is standard for torch.utils.data.distributed.DistributedSampler). This avoids iterating through the dataset entirely.

I suggest replacing this block with a more efficient and correct implementation that handles both points.

Suggested change
-for _ in range(current_epoch - 1):
-    for _ in iter(self.train_dataloader.sampler):
-        pass
+if hasattr(self.train_dataloader.sampler, "set_epoch"):
+    # More efficient and standard way for DistributedSampler
+    self.train_dataloader.sampler.set_epoch(current_epoch)
+else:
+    # Fallback for samplers without set_epoch, with correction.
+    # This is inefficient and should be avoided if possible.
+    for _ in range(current_epoch):
+        for _ in self.train_dataloader.sampler:
+            pass

Collaborator
Please address gemini review, set_epoch is more efficient.

Contributor Author
The sampler appears to be one epoch ahead after resuming from _load_checkpoint (e.g., it resumes at step 3 instead of step 1). As a workaround, we only advance the sampler current_epoch - 1 times.

Contributor Author

@wlhgtc wlhgtc Nov 17, 2025
> Please address gemini review, set_epoch is more efficient.

@wuxibin89 Sorry for missing that.
For the inefficiency problem, the only sampler that currently matches the condition hasattr(self.train_dataloader.sampler, "set_epoch") is DistributedSampler, and we do not use it in create_rl_sampler. set_epoch() would only have an effect if we called it between epochs, which we currently don't.

For the correctness problem, see the comment below.

Collaborator
It's torch.utils.data.RandomSampler's duty to skip the samples already consumed in this epoch, so there's no need to skip them manually:
https://github.com/meta-pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L73-L74
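
A minimal sketch of what that buys us (assumed toy setup; with torchdata's stateful RandomSampler, the dataloader state itself records how far into the epoch we are):

```python
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import RandomSampler

dataset = TensorDataset(torch.arange(8))
loader = StatefulDataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset))

it = iter(loader)
print(next(it))              # consume the first batch of the epoch
state = loader.state_dict()  # checkpoint mid-epoch

# A fresh dataloader resumed from the state continues with the *remaining*
# batches of the same epoch; no manual fast-forwarding of the sampler needed.
resumed = StatefulDataLoader(dataset, batch_size=2, sampler=RandomSampler(dataset))
resumed.load_state_dict(state)
for batch in resumed:
    print(batch)
```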

Contributor Author

@wlhgtc wlhgtc Nov 17, 2025
> https://github.com/meta-pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L73-L74

Sorry for missing this comment again (unreliable network).
We used torch.utils.data.sampler.RandomSampler rather than torchdata.stateful_dataloader.sampler.RandomSampler. I am trying the latter; it seems to fit StatefulDataLoader.

@wlhgtc
Contributor Author

wlhgtc commented Nov 17, 2025

@eric-haibin-lin @vermouth1992 @PeterSH6 @tongyx361 Could anyone help me review this?

next_step_profile = False

-for epoch in range(self.config.trainer.total_epochs):
+for epoch in range(current_epoch, self.config.trainer.total_epochs):
Collaborator
From https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler:

> In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used.

I think we're missing set_epoch() here.
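
For reference, the standard pattern the docs describe (a generic sketch, not verl code; num_replicas/rank are set explicitly so it runs without initializing torch.distributed):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=42)
loader = DataLoader(dataset, batch_size=4, sampler=sampler, drop_last=True)

start_epoch, total_epochs = 1, 3
for epoch in range(start_epoch, total_epochs):
    sampler.set_epoch(epoch)  # shuffling is seeded with (seed + epoch) internally
    for batch in loader:
        pass
```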

Contributor Author
It only works when we use DistributedSampler. Should I change the sampler from RandomSampler/SequentialSampler to DistributedSampler?

Collaborator
No, it works for both RandomSampler/SequentialSampler and DistributedSampler, see comment above.

Contributor Author

@wlhgtc wlhgtc Nov 17, 2025
> https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler

You can see in the code of DistributedSampler that it seeds its generator with g.manual_seed(self.seed + self.epoch).

But RandomSampler is based only on the seed; it doesn't have a self.epoch property.

In verl's code, create_rl_sampler builds its generator from the seed in data_config.
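
To illustrate the difference (a standalone sketch, not verl code): DistributedSampler derives each epoch's ordering from seed + epoch, while a plain RandomSampler draws from one generator whose state simply advances as samples are consumed.

```python
import torch

seed, n = 42, 10

# DistributedSampler-style: a fresh generator per epoch, seeded with seed + epoch,
# so a resumed run that knows its epoch reproduces the same ordering.
def epoch_permutation(epoch: int) -> torch.Tensor:
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    return torch.randperm(n, generator=g)

print(epoch_permutation(0))
print(epoch_permutation(1))

# RandomSampler-style: one generator seeded once; later epochs depend on how far
# the generator has advanced, which is why its state must be checkpointed.
g = torch.Generator()
g.manual_seed(seed)
print(torch.randperm(n, generator=g))  # "epoch 0"
print(torch.randperm(n, generator=g))  # "epoch 1"
```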

@wuxibin89
Collaborator

@wlhgtc I think we should use torchdata.stateful_dataloader.sampler.RandomSampler instead of torch.utils.data.RandomSampler

import os
import torch

from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

if os.environ.get("USE_TORCHDATA", "0") == "1":
    from torchdata.stateful_dataloader.sampler import RandomSampler
    print(f"Using torchdata.stateful_dataloader.sampler.RandomSampler")
else:
    from torch.utils.data import RandomSampler
    print(f"Using torch.utils.data.RandomSampler")

# Set seed for reproducibility
torch.manual_seed(42)

dataset = TensorDataset(torch.arange(10))
sampler = RandomSampler(dataset)
dataloader = StatefulDataLoader(dataset, num_workers=2, batch_size=4, drop_last=True, sampler=sampler)

global_step = 0
total_epoch = 3

for epoch in range(total_epoch):
    should_break = False
    for i, batch in enumerate(dataloader):
        global_step += 1
        print(f"{epoch=}, {i=}, {global_step=}: {batch}")

        # Save state_dict at step 0 of epoch 1
        if epoch == 1 and i == 0:
            state_dict = dataloader.state_dict()
            should_break = True
            break

    if should_break:
        break

start_epoch = global_step // len(dataloader)
print(f"recover from epoch {start_epoch}")

del sampler, dataloader

# Training run resumes with the previous checkpoint
sampler2 = RandomSampler(dataset)
dataloader2 = StatefulDataLoader(dataset, num_workers=2, batch_size=4, drop_last=True, sampler=sampler2)

# Resume state with DataLoader
dataloader2.load_state_dict(state_dict)
for epoch in range(start_epoch, total_epoch):
    for i, batch in enumerate(dataloader2):
        global_step += 1
        print(f"{epoch=}, {i=}, {global_step=}: {batch}")

The torch.utils.data.RandomSampler generator is not being recovered properly:

<Trial 66568621 worker_0> open_verl [wuxibin/rl_model_engine] $ python3 test_dataloader.py 
Using torch.utils.data.RandomSampler
epoch=0, i=0, global_step=1: [tensor([8, 0, 9, 1])]
epoch=0, i=1, global_step=2: [tensor([4, 5, 6, 3])]
epoch=1, i=0, global_step=3: [tensor([4, 6, 9, 3])]
recover from epoch 1
epoch=1, i=0, global_step=4: [tensor([3, 8, 5, 9])]
epoch=2, i=0, global_step=5: [tensor([4, 3, 8, 9])]
epoch=2, i=1, global_step=6: [tensor([2, 0, 5, 1])]
<Trial 66568621 worker_0> open_verl [wuxibin/rl_model_engine] $ USE_TORCHDATA=1 python3 test_dataloader.py 
Using torchdata.stateful_dataloader.sampler.RandomSampler
epoch=0, i=0, global_step=1: [tensor([8, 3, 0, 1])]
epoch=0, i=1, global_step=2: [tensor([7, 2, 9, 5])]
epoch=1, i=0, global_step=3: [tensor([1, 8, 2, 7])]
recover from epoch 1
epoch=1, i=0, global_step=4: [tensor([5, 0, 9, 3])]
epoch=2, i=0, global_step=5: [tensor([6, 3, 8, 0])]
epoch=2, i=1, global_step=6: [tensor([1, 5, 4, 7])]

@wlhgtc
Contributor Author

wlhgtc commented Nov 17, 2025

> @wlhgtc I think we should use torchdata.stateful_dataloader.sampler.RandomSampler instead of torch.utils.data.RandomSampler

Thanks for your advice; I've updated my code. Any other suggestions? @wuxibin89

@wuxibin89 wuxibin89 changed the title [fix] fix reproducible problem when resume training [trainer] fix: reproducible problem when resume training Nov 17, 2025
@wuxibin89
Collaborator

@wlhgtc Could you please verify the change on your task?

@wuxibin89 wuxibin89 merged commit ef44dcc into verl-project:main Nov 17, 2025
80 of 84 checks passed
@wlhgtc
Contributor Author

wlhgtc commented Nov 17, 2025

> @wlhgtc Could you please verify the change on your task?

It worked.
image

chenhaiq pushed a commit to The-Hierophant/verl-1 that referenced this pull request Nov 18, 2025
…t#4156)

wuwendyy pushed a commit to wuwendyy/verl that referenced this pull request Nov 19, 2025
…t#4156)

TimurTaepov pushed a commit to giorgossideris/verl that referenced this pull request Dec 20, 2025
…t#4156)

vyomakesh0728 added a commit to vyomakesh0728/verl that referenced this pull request Jan 22, 2026
…t#4156)
