7 changes: 6 additions & 1 deletion verl/trainer/ppo/ray_trainer.py
@@ -975,6 +975,11 @@ def fit(self):

  # load checkpoint before doing anything
  self._load_checkpoint()
+ # resume sampler state if needed
+ current_epoch = self.global_steps // len(self.train_dataloader)
+ for _ in range(current_epoch - 1):
+     for _ in iter(self.train_dataloader.sampler):
+         pass
Contributor:
critical

This approach to restoring the sampler state is very inefficient and seems to have a correctness issue.

  1. Correctness: The loop for _ in range(current_epoch - 1): appears to have an off-by-one error. If training resumes at a global_step corresponding to the start of current_epoch=1, this loop (range(0)) will not execute. The main training loop will then start at epoch 1, but the sampler will still be in its state for epoch 0. This would lead to incorrect data sampling and break the reproducibility you're trying to fix. To correctly advance the sampler to be ready for current_epoch, the loop should run current_epoch times.

  2. Inefficiency: The nested loop iterates through all the samples in the dataset for each epoch you want to skip. For large datasets, this can add a significant delay to your training startup time.

A more robust and efficient solution would be to use the set_epoch() method if the sampler supports it (which is standard for torch.utils.data.distributed.DistributedSampler). This avoids iterating through the dataset entirely.

I suggest replacing this block with a more efficient and correct implementation that handles both points.

Suggested change

- for _ in range(current_epoch - 1):
-     for _ in iter(self.train_dataloader.sampler):
-         pass
+ if hasattr(self.train_dataloader.sampler, "set_epoch"):
+     # More efficient and standard way for DistributedSampler
+     self.train_dataloader.sampler.set_epoch(current_epoch)
+ else:
+     # Fallback for samplers without set_epoch, with correction.
+     # This is inefficient and should be avoided if possible.
+     for _ in range(current_epoch):
+         for _ in self.train_dataloader.sampler:
+             pass

Collaborator:
Please address gemini review, set_epoch is more efficient.

Contributor Author:
The sampler appears to be one epoch ahead after resuming from _load_checkpoint (e.g., it resumes at step 3 instead of step 1). As a workaround, we only advance the sampler current_epoch - 1 times.

Contributor Author (@wlhgtc, Nov 17, 2025):
> Please address gemini review, set_epoch is more efficient.

@wuxibin89 Sorry for missing this.
For the inefficiency problem: the only sampler that currently satisfies hasattr(self.train_dataloader.sampler, "set_epoch") is DistributedSampler, and we do not use it in create_rl_sampler. set_epoch() only has an effect if we call it between epochs, which we are not currently doing.

For the correctness problem, see the comment below.

Collaborator:
It's torch.utils.data.RandomSampler's duty to skip the samples consumed so far in this epoch; there's no need to skip them manually.
https://github.com/meta-pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L73-L74
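The behavior in the linked code can be sketched in plain Python. This is a toy illustration of the idea, not the torchdata implementation — the class name `StatefulRandomSampler` and the shape of its `state_dict` are invented for the example: the sampler remembers how many indices it has already yielded this epoch and skips them when iteration resumes from a checkpoint.

```python
import random


class StatefulRandomSampler:
    """Toy sketch (invented name) of the idea behind torchdata's stateful
    RandomSampler: remember how many indices were already yielded this
    epoch, and skip them when iteration resumes."""

    def __init__(self, n, seed=0):
        self.n = n
        self.seed = seed
        self.yielded = 0  # indices consumed so far in the current epoch

    def state_dict(self):
        return {"yielded": self.yielded}

    def load_state_dict(self, state):
        self.yielded = state["yielded"]

    def __iter__(self):
        order = list(range(self.n))
        random.Random(self.seed).shuffle(order)  # same permutation each time
        for idx in order[self.yielded:]:         # skip already-consumed indices
            self.yielded += 1
            yield idx


# Consume two samples, checkpoint, then resume as if in a new process.
sampler = StatefulRandomSampler(5, seed=42)
it = iter(sampler)
first_two = [next(it), next(it)]
state = sampler.state_dict()          # {"yielded": 2}

resumed = StatefulRandomSampler(5, seed=42)
resumed.load_state_dict(state)
rest = list(resumed)                  # the remaining three indices

full_order = list(StatefulRandomSampler(5, seed=42))
assert first_two + rest == full_order  # resume continues the same epoch
```

With such a sampler, manually fast-forwarding past consumed samples after `_load_checkpoint` becomes unnecessary: restoring the state dict is enough.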

Contributor Author (@wlhgtc, Nov 17, 2025):
https://github.com/meta-pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L73-L74

Sorry for missing this comment again (network issues).
We used torch.utils.data.sampler.RandomSampler rather than torchdata.stateful_dataloader.sampler.RandomSampler. I am trying the latter; it seems to fit StatefulDataLoader.


# perform validation before training
# currently, we only support validation using the reward_function.
@@ -1006,7 +1011,7 @@ def fit(self):
)
next_step_profile = False

- for epoch in range(self.config.trainer.total_epochs):
+ for epoch in range(current_epoch, self.config.trainer.total_epochs):
Collaborator:
From https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler

In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used.

I think we're missing set_epoch() here.
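The reseeding the docs describe can be mimicked with a toy sampler (a pure-Python sketch; `EpochSeededSampler` is an invented name, not verl or PyTorch code): the shuffle seed is `seed + epoch`, so a caller that never invokes `set_epoch()` replays epoch 0's order forever.

```python
import random


class EpochSeededSampler:
    """Toy sketch (invented name) mirroring DistributedSampler's seeding:
    each epoch's shuffle uses seed + epoch, so callers must invoke
    set_epoch() before iterating or every epoch repeats epoch 0's order."""

    def __init__(self, n, seed=0):
        self.n, self.seed, self.epoch = n, seed, 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        order = list(range(self.n))
        random.Random(self.seed + self.epoch).shuffle(order)
        return iter(order)


sampler = EpochSeededSampler(8, seed=0)

# Without set_epoch(): the same order every "epoch".
repeated = [list(sampler) for _ in range(2)]
assert repeated[0] == repeated[1]

# With set_epoch(): each epoch gets its own shuffle.
orders = []
for epoch in range(2):
    sampler.set_epoch(epoch)  # the call the review says is missing
    orders.append(list(sampler))
assert orders[0] != orders[1]
```

This is why the resumed training loop would need `sampler.set_epoch(epoch)` at the top of each epoch if a DistributedSampler-style sampler were in use.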

Contributor Author:

It works only when we use DistributedSampler. Should I change the sampler from RandomSampler/SequentialSampler to DistributedSampler?

Collaborator:
No, it works for both RandomSampler/SequentialSampler and DistributedSampler, see comment above.

Contributor Author (@wlhgtc, Nov 17, 2025):

https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler

You can see in the code of DistributedSampler that it seeds the generator via g.manual_seed(self.seed + self.epoch).

But RandomSampler is based on the seed only; it doesn't have the self.epoch property.

In the verl code, create_rl_sampler builds the generator(seed) from data_config.

for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}