Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/amd_tutorial/amd_build_dockerfile_page.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ docker/Dockerfile.rocm
liger-kernel \
numpy \
pandas \
datasets \
peft \
"pyarrow>=15.0.0" \
pylatexenc \
Expand Down
1 change: 1 addition & 0 deletions docs/start/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ Find the docker for AMD ROCm: `docker/Dockerfile.rocm <https://github.com/volcen
liger-kernel \
numpy \
pandas \
datasets \
peft \
"pyarrow>=15.0.0" \
pylatexenc \
Expand Down
6 changes: 4 additions & 2 deletions recipe/prime/prime_ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def _create_dataloader(self):
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
truncation='error',
filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False))
filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False),
num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
# use sampler for better ckpt resume
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
Expand All @@ -207,7 +208,8 @@ def _create_dataloader(self):
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
truncation='error',
filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False))
filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False),
num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
self.val_dataloader = DataLoader(dataset=self.val_dataset,
batch_size=len(self.val_dataset),
shuffle=True,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ hydra-core
liger-kernel
numpy
pandas
datasets
peft
pyarrow>=15.0.0
pybind11
Expand Down
1 change: 1 addition & 0 deletions requirements_sglang.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ flash-attn
hydra-core
numpy
pandas
datasets
peft
pyarrow>=15.0.0
pybind11
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
'hydra-core',
'numpy',
'pandas',
'datasets',
'peft',
'pyarrow>=15.0.0',
'pybind11',
Expand Down
3 changes: 2 additions & 1 deletion verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ data:
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: False
shuffle: True
filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You should disable this and set `truncation='left'
filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up.
filter_overlong_prompts_workers: 1
truncation: error

actor_rollout_ref:
Expand Down
3 changes: 2 additions & 1 deletion verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ data:
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: False
shuffle: True
filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You should disable this and set `truncation='left'
filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up.
filter_overlong_prompts_workers: 1
truncation: error
image_key: images

Expand Down
6 changes: 4 additions & 2 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ def _create_dataloader(self):
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
truncation=self.config.data.get('truncation', 'error'),
filter_overlong_prompts=self.config.data.filter_overlong_prompts)
filter_overlong_prompts=self.config.data.filter_overlong_prompts,
num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
assert self.train_dataset.truncation == self.config.data.get(
'truncation', 'error'
), f'dataset truncation {self.train_dataset.truncation} must be the same as config {self.config.data.get("truncation", "error")}'
Expand Down Expand Up @@ -431,7 +432,8 @@ def _create_dataloader(self):
filter_prompts=True,
return_raw_chat=self.config.data.get('return_raw_chat', False),
truncation=self.config.data.get('truncation', 'error'),
filter_overlong_prompts=self.config.data.filter_overlong_prompts)
filter_overlong_prompts=self.config.data.filter_overlong_prompts,
num_workers=self.config.data.get('filter_overlong_prompts_workers', None))
assert self.val_dataset.truncation == self.config.data.get(
'truncation', 'error'
), f'dataset truncation {self.val_dataset.truncation} must be the same as config {self.config.data.get("truncation", "error")}'
Expand Down
41 changes: 24 additions & 17 deletions verl/utils/dataset/rl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from omegaconf import ListConfig
import os
from typing import List, Union, Optional
from typing import List, Union, Optional, Callable
import copy
import pandas as pd
import datasets
from collections import defaultdict

import torch
Expand Down Expand Up @@ -81,15 +81,16 @@ def __init__(self,
parquet_files: Union[str, List[str]],
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin] = None,
prompt_key='prompt',
image_key='images',
max_prompt_length=1024,
prompt_key: str = 'prompt',
image_key: str = 'images',
max_prompt_length: int = 1024,
filter_prompts=True,
cache_dir='~/.cache/verl/rlhf',
chat_template_func=None,
return_raw_chat=False,
truncation='error',
filter_overlong_prompts=False):
cache_dir: str = '~/.cache/verl/rlhf',
chat_template_func: Optional[Callable] = None,
return_raw_chat: bool = False,
truncation: str = 'error',
filter_overlong_prompts: bool = False,
num_workers: Optional[int] = None):
if not isinstance(parquet_files, (List, ListConfig)):
parquet_files = [parquet_files]

Expand All @@ -108,6 +109,10 @@ def __init__(self,
self.chat_template_func = chat_template_func
self.truncation = truncation
self.filter_overlong_prompts = filter_overlong_prompts
if num_workers is None:
self.num_workers = max(1, os.cpu_count() // 4)
else:
self.num_workers = min(num_workers, os.cpu_count())

# whether to store the dataset in state_dict()
# default not store
Expand All @@ -125,19 +130,21 @@ def _read_files_and_tokenize(self):
dataframes = []
for parquet_file in self.parquet_files:
# read parquet files and cache
dataframe = pd.read_parquet(parquet_file)
dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
dataframes.append(dataframe)
self.dataframe = pd.concat(dataframes)
self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

print(f'dataset len: {len(self.dataframe)}')

# filter out too long prompts
if self.filter_overlong_prompts:
tokenizer = self.tokenizer
prompt_key = self.prompt_key
self.dataframe = self.dataframe[self.dataframe.apply(lambda doc: len(
tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length,
axis=1)]
self.dataframe = self.dataframe.filter(
lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)
) <= self.max_prompt_length,
num_proc=self.num_workers,
desc=f"Filtering prompts longer than {self.max_prompt_length} tokens")

print(f'filter dataset len: {len(self.dataframe)}')

Expand All @@ -157,7 +164,7 @@ def __getitem__(self, item):
"""
Note that we also return the raw_input_ids so that it can be combined with other chat template
"""
row_dict: dict = self.dataframe.iloc[item].to_dict()
row_dict: dict = self.dataframe[item]

chat = row_dict.pop(self.prompt_key)

Expand Down Expand Up @@ -214,7 +221,7 @@ def __getitem__(self, item):

# encode prompts without chat template
if self.return_raw_chat:
row_dict['raw_prompt'] = chat.tolist()
row_dict['raw_prompt'] = chat

# add index for each prompt
index = row_dict.get("extra_info", {}).get("index", 0)
Expand Down