diff --git a/docs/amd_tutorial/amd_build_dockerfile_page.rst b/docs/amd_tutorial/amd_build_dockerfile_page.rst index a7034f6d7bb..50ac1985a67 100644 --- a/docs/amd_tutorial/amd_build_dockerfile_page.rst +++ b/docs/amd_tutorial/amd_build_dockerfile_page.rst @@ -48,6 +48,7 @@ docker/Dockerfile.rocm liger-kernel \ numpy \ pandas \ + datasets \ peft \ "pyarrow>=15.0.0" \ pylatexenc \ diff --git a/docs/start/install.rst b/docs/start/install.rst index bf59654c43d..37fcf71f599 100644 --- a/docs/start/install.rst +++ b/docs/start/install.rst @@ -167,6 +167,7 @@ Find the docker for AMD ROCm: `docker/Dockerfile.rocm =15.0.0" \ pylatexenc \ diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index 26531dc335f..27ece75e4b5 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -184,7 +184,8 @@ def _create_dataloader(self): filter_prompts=True, return_raw_chat=self.config.data.get('return_raw_chat', False), truncation='error', - filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False)) + filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False), + num_workers=self.config.data.get('filter_overlong_prompts_workers', None)) # use sampler for better ckpt resume if self.config.data.shuffle: train_dataloader_generator = torch.Generator() @@ -207,7 +208,8 @@ def _create_dataloader(self): filter_prompts=True, return_raw_chat=self.config.data.get('return_raw_chat', False), truncation='error', - filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False)) + filter_overlong_prompts=self.config.data.get('filter_overlong_prompts', False), + num_workers=self.config.data.get('filter_overlong_prompts_workers', None)) self.val_dataloader = DataLoader(dataset=self.val_dataset, batch_size=len(self.val_dataset), shuffle=True, diff --git a/requirements.txt b/requirements.txt index 667dd98ce9d..596ec4cde5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ 
hydra-core liger-kernel numpy pandas +datasets peft pyarrow>=15.0.0 pybind11 diff --git a/requirements_sglang.txt b/requirements_sglang.txt index e859e1c987e..150f9d7c18b 100644 --- a/requirements_sglang.txt +++ b/requirements_sglang.txt @@ -7,6 +7,7 @@ flash-attn hydra-core numpy pandas +datasets peft pyarrow>=15.0.0 pybind11 diff --git a/setup.py b/setup.py index d090e5fed47..3d71833763d 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ 'hydra-core', 'numpy', 'pandas', + 'datasets', 'peft', 'pyarrow>=15.0.0', 'pybind11', diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 4d493942886..9b16bd225bd 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -10,7 +10,8 @@ data: return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs return_raw_chat: False shuffle: True - filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You should disable this and set `truncation='left' + filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be time-consuming. You can set filter_overlong_prompts_workers to use multiprocessing to speed this up. + filter_overlong_prompts_workers: 1 truncation: error actor_rollout_ref: diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 7dd270bbaa4..5241ea6d9cc 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -10,7 +10,8 @@ data: return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs return_raw_chat: False shuffle: True - filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. 
You should disable this and set `truncation='left' + filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be time-consuming. You can set filter_overlong_prompts_workers to use multiprocessing to speed this up. + filter_overlong_prompts_workers: 1 truncation: error image_key: images diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 115035a66e1..0548f134df2 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -403,7 +403,8 @@ def _create_dataloader(self): filter_prompts=True, return_raw_chat=self.config.data.get('return_raw_chat', False), truncation=self.config.data.get('truncation', 'error'), - filter_overlong_prompts=self.config.data.filter_overlong_prompts) + filter_overlong_prompts=self.config.data.filter_overlong_prompts, + num_workers=self.config.data.get('filter_overlong_prompts_workers', None)) assert self.train_dataset.truncation == self.config.data.get( 'truncation', 'error' ), f'dataset truncation {self.train_dataset.truncation} must be the same as config {self.config.data.get("truncation", "error")}' @@ -431,7 +432,8 @@ def _create_dataloader(self): filter_prompts=True, return_raw_chat=self.config.data.get('return_raw_chat', False), truncation=self.config.data.get('truncation', 'error'), - filter_overlong_prompts=self.config.data.filter_overlong_prompts) + filter_overlong_prompts=self.config.data.filter_overlong_prompts, + num_workers=self.config.data.get('filter_overlong_prompts_workers', None)) assert self.val_dataset.truncation == self.config.data.get( 'truncation', 'error' ), f'dataset truncation {self.val_dataset.truncation} must be the same as config {self.config.data.get("truncation", "error")}' diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index b9872072224..fe97d3bd7a3 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -14,9 +14,9 @@ from omegaconf import ListConfig import 
os -from typing import List, Union, Optional +from typing import List, Union, Optional, Callable import copy -import pandas as pd +import datasets from collections import defaultdict import torch @@ -81,15 +81,16 @@ def __init__(self, parquet_files: Union[str, List[str]], tokenizer: PreTrainedTokenizer, processor: Optional[ProcessorMixin] = None, - prompt_key='prompt', - image_key='images', - max_prompt_length=1024, + prompt_key: str = 'prompt', + image_key: str = 'images', + max_prompt_length: int = 1024, filter_prompts=True, - cache_dir='~/.cache/verl/rlhf', - chat_template_func=None, - return_raw_chat=False, - truncation='error', - filter_overlong_prompts=False): + cache_dir: str = '~/.cache/verl/rlhf', + chat_template_func: Optional[Callable] = None, + return_raw_chat: bool = False, + truncation: str = 'error', + filter_overlong_prompts: bool = False, + num_workers: Optional[int] = None): if not isinstance(parquet_files, (List, ListConfig)): parquet_files = [parquet_files] @@ -108,6 +109,10 @@ def __init__(self, self.chat_template_func = chat_template_func self.truncation = truncation self.filter_overlong_prompts = filter_overlong_prompts + if num_workers is None: + self.num_workers = max(1, os.cpu_count() // 4) + else: + self.num_workers = min(num_workers, os.cpu_count()) # whether to store the dataset in state_dict() # default not store @@ -125,9 +130,9 @@ def _read_files_and_tokenize(self): dataframes = [] for parquet_file in self.parquet_files: # read parquet files and cache - dataframe = pd.read_parquet(parquet_file) + dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] dataframes.append(dataframe) - self.dataframe = pd.concat(dataframes) + self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) print(f'dataset len: {len(self.dataframe)}') @@ -135,9 +140,11 @@ def _read_files_and_tokenize(self): if self.filter_overlong_prompts: tokenizer = self.tokenizer prompt_key = self.prompt_key - self.dataframe = 
self.dataframe[self.dataframe.apply(lambda doc: len( - tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) <= self.max_prompt_length, - axis=1)] + self.dataframe = self.dataframe.filter( + lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True) + ) <= self.max_prompt_length, + num_proc=self.num_workers, + desc=f"Filtering prompts longer than {self.max_prompt_length} tokens") print(f'filter dataset len: {len(self.dataframe)}') @@ -157,7 +164,7 @@ def __getitem__(self, item): """ Note that we also return the raw_input_ids so that it can be combined with other chat template """ - row_dict: dict = self.dataframe.iloc[item].to_dict() + row_dict: dict = self.dataframe[item] chat = row_dict.pop(self.prompt_key) @@ -214,7 +221,7 @@ def __getitem__(self, item): # encode prompts without chat template if self.return_raw_chat: - row_dict['raw_prompt'] = chat.tolist() + row_dict['raw_prompt'] = chat # add index for each prompt index = row_dict.get("extra_info", {}).get("index", 0)