
Commit 10fd9ba

[feat] add validation shuffle (verl-project#1886)
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

When multiple validation sets differ significantly in difficulty, and therefore in generated response length, the order in which the validation sets are processed can have a substantial impact on validation speed. This PR adds an option to shuffle the validation data.

### High-Level Design

Add a `validation_shuffle` flag to the data config and pass it to the validation dataloader.

### Usage Example

```yaml
data:
  validation_shuffle: True
```

### Test

Validation speed increased by over 10%.

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [ ] New CI unit test(s) are added to cover the code path.
- [ ] Rely on existing unit tests on CI that cover the code path.
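The speed effect described above can be illustrated with a toy model (all names and numbers below are hypothetical, not from verl): suppose validation examples are split into contiguous shards across parallel workers, and each worker's wall time is the sum of its generation lengths. Without shuffling, one worker can end up with all the long generations while the others sit idle; shuffling mixes easy and hard examples across shards:

```python
import random

def wall_time(lengths, n_workers):
    # Split examples into contiguous per-worker shards; the wall-clock time
    # is that of the slowest worker, since workers run in parallel.
    chunk = len(lengths) // n_workers
    shards = [lengths[i * chunk:(i + 1) * chunk] for i in range(n_workers)]
    return max(sum(s) for s in shards)

# An easy set (short generations) concatenated with a hard set (long ones).
lengths = [1] * 64 + [10] * 64

random.seed(0)
shuffled = lengths[:]
random.shuffle(shuffled)

# Unshuffled, one worker gets the entire hard set; shuffled, the load is
# roughly balanced and the slowest worker finishes much sooner.
print(wall_time(lengths, n_workers=2), wall_time(shuffled, n_workers=2))
```

This is only a sketch of the load-imbalance intuition; the actual speedup depends on batching and serving details.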
1 parent 4858ae4 commit 10fd9ba

File tree

2 files changed: +2 −1 lines changed


verl/trainer/config/ppo_trainer.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,6 +13,7 @@ data:
   return_raw_chat: False
   return_full_prompt: False
   shuffle: True
+  validation_shuffle: False
   filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts can be time-consuming. You can set filter_overlong_prompts_workers to use multiprocessing to speed this up.
   filter_overlong_prompts_workers: 1
   truncation: error
```

verl/trainer/ppo/ray_trainer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -550,7 +550,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
             dataset=self.val_dataset,
             batch_size=val_batch_size,
             num_workers=self.config.data.get("dataloader_num_workers", 8),
-            shuffle=False,
+            shuffle=self.config.data.get("validation_shuffle", True),
             drop_last=False,
             collate_fn=collate_fn,
         )
```
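A minimal standalone sketch of the pattern this diff uses: the dataloader reads an optional key from the data config with a `.get(...)` default, so older configs that predate the flag keep working. Here a plain dict stands in for verl's config object, and `load_validation` is a hypothetical helper, not verl code. Note that the in-code default is `True` while the shipped YAML sets `validation_shuffle: False`.

```python
import random

def load_validation(data_cfg: dict, samples: list) -> list:
    # Mirror the diff's pattern: read an optional config key with a default,
    # so configs written before the flag existed still work unchanged.
    items = list(samples)
    if data_cfg.get("validation_shuffle", True):
        random.shuffle(items)
    return items

# Setting the flag to False preserves the original validation order;
# omitting the key falls back to the in-code default (shuffle on).
ordered = load_validation({"validation_shuffle": False}, [1, 2, 3, 4])
```

Keeping the `.get` default aligned with the YAML default avoids surprises when the key is absent from a user's config.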

0 commit comments
