From e0910d9117cea15df3b29b4cac01a96f20d6c924 Mon Sep 17 00:00:00 2001
From: 0x404 <871206929@qq.com>
Date: Sun, 13 Apr 2025 03:54:30 +0000
Subject: [PATCH] fix: filter overlong prompts should also consider multimodal
 inputs

Also, if the passed train/val file is a directory, read the parquet files
from that directory. This keeps RLHFDataset compatible with the
pd.read_parquet behaviour we used before.
---
 verl/utils/dataset/rl_dataset.py | 38 +++++++++++++++++++++++++++-----
 1 file changed, 33 insertions(+), 5 deletions(-)

diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py
index 1be4ee98c9d..59534779e5c 100644
--- a/verl/utils/dataset/rl_dataset.py
+++ b/verl/utils/dataset/rl_dataset.py
@@ -121,7 +121,10 @@ def _read_files_and_tokenize(self):
         dataframes = []
         for parquet_file in self.data_files:
             # read parquet files and cache
-            dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
+            if os.path.isdir(parquet_file):
+                dataframe = datasets.load_dataset("parquet", data_dir=parquet_file, split="train")
+            else:
+                dataframe = datasets.load_dataset("parquet", data_files=parquet_file, split="train")
             dataframes.append(dataframe)
         self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

@@ -129,11 +132,8 @@

         # filter out too long prompts
         if self.filter_overlong_prompts:
-            tokenizer = self.tokenizer
-            prompt_key = self.prompt_key
             self.dataframe = self.dataframe.filter(
-                lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)
-                            ) <= self.max_prompt_length,
+                self._filter_overlong_prompts,
                 num_proc=self.num_workers,
                 desc=f"Filtering prompts longer than {self.max_prompt_length} tokens")

@@ -148,6 +148,34 @@ def resume_dataset_state(self):
         else:
             print(r'old dataloader ckpt file is used, please train from scratch for better ckpt performance')

+    def _filter_overlong_prompts(self, example: dict) -> bool:
+        chat_messages = example[self.prompt_key]
+        prompt_with_chat_template = self.tokenizer.apply_chat_template(chat_messages,
+                                                                       add_generation_prompt=True,
+                                                                       tokenize=False)
+        if self.image_key in example and self.processor is not None:
+            # for multimodal inputs, expand each image placeholder to its real
+            # token length before measuring the prompt
+            images = [process_image(image) for image in example[self.image_key]]
+            image_inputs = self.processor.image_processor(images, return_tensors='pt')
+            image_grid_thw = image_inputs['image_grid_thw']
+
+            if image_grid_thw is not None:
+                merge_length = self.processor.image_processor.merge_size**2
+                index = 0
+                while '<image>' in prompt_with_chat_template:
+                    prompt_with_chat_template = prompt_with_chat_template.replace(
+                        '<image>',
+                        '<|vision_start|>' + '<|placeholder|>' * (image_grid_thw[index].prod() // merge_length) +
+                        '<|vision_end|>',
+                        1,
+                    )
+                    index += 1
+
+                prompt_with_chat_template = prompt_with_chat_template.replace('<|placeholder|>',
+                                                                              self.processor.image_token)
+        input_ids = self.tokenizer.encode(prompt_with_chat_template, add_special_tokens=False)
+        return len(input_ids) <= self.max_prompt_length
+
     def __len__(self):
         return len(self.dataframe)
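
The directory branch in the first hunk relies on the Hugging Face datasets
parquet builder accepting either data_dir (a directory of shards) or
data_files (a single file). A minimal sketch of the two call shapes, outside
RLHFDataset; the helper name and paths are hypothetical:

    import os
    import datasets

    def load_parquet_split(path: str) -> datasets.Dataset:
        # hypothetical helper mirroring the patched branch
        if os.path.isdir(path):
            # a directory of *.parquet shards, as pd.read_parquet also accepts
            return datasets.load_dataset("parquet", data_dir=path, split="train")
        # a single parquet file
        return datasets.load_dataset("parquet", data_files=path, split="train")

    # train_ds = load_parquet_split("data/gsm8k/train.parquet")  # single file
    # val_ds = load_parquet_split("data/gsm8k/val/")             # directory of shards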
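
The new _filter_overlong_prompts method matters because counting only the text
tokens of the raw chat template would undercount multimodal prompts whose
images expand to many tokens. A minimal standalone sketch of the expansion
arithmetic, assuming a Qwen2-VL-style processor where each image yields a
(t, h, w) patch grid and merge_size x merge_size patches collapse into one
token; the prompt and grid values below are hypothetical:

    from math import prod

    # Hypothetical values: a real run gets image_grid_thw from
    # processor.image_processor and merge_size from the processor config.
    prompt = "<image> Describe this picture."
    image_grid_thw = [(1, 4, 6)]    # (t, h, w) patch grid for one image
    merge_size = 2                  # typical Qwen2-VL setting

    merge_length = merge_size ** 2  # 4 patches merge into 1 token
    index = 0
    while '<image>' in prompt:
        num_image_tokens = prod(image_grid_thw[index]) // merge_length  # 24 // 4 = 6
        prompt = prompt.replace(
            '<image>',
            '<|vision_start|>' + '<|placeholder|>' * num_image_tokens + '<|vision_end|>',
            1,
        )
        index += 1

    # In the patch, each '<|placeholder|>' is then swapped for the processor's
    # real image token before the tokenizer measures the prompt length.
    print(prompt.count('<|placeholder|>'))  # 6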