Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions verl/utils/dataset/rl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,19 @@ def _read_files_and_tokenize(self):
dataframes = []
for parquet_file in self.data_files:
# read parquet files and cache
dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
if os.path.isdir(parquet_file):
dataframe = datasets.load_dataset("parquet", data_dir=parquet_file, split="train")
else:
dataframe = datasets.load_dataset("parquet", data_files=parquet_file, split="train")
dataframes.append(dataframe)
self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

print(f'dataset len: {len(self.dataframe)}')

# filter out too long prompts
if self.filter_overlong_prompts:
tokenizer = self.tokenizer
prompt_key = self.prompt_key
self.dataframe = self.dataframe.filter(
lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)
) <= self.max_prompt_length,
self._filter_overlong_prompts,
num_proc=self.num_workers,
desc=f"Filtering prompts longer than {self.max_prompt_length} tokens")

Expand All @@ -148,6 +148,34 @@ def resume_dataset_state(self):
else:
print(r'old dataloader ckpt file is used, please train from scratch for better ckpt performance')

def _filter_overlong_prompts(self, example: dict) -> bool:
chat_messsages = example[self.prompt_key]
prompt_with_chat_template = self.tokenizer.apply_chat_template(chat_messsages,
add_generation_prompt=True,
tokenize=False)
if self.image_key in example and self.processor is not None:
# for multi modal, we need to process the prompt_with_chat_template
images = [process_image(image) for image in example[self.image_key]]
image_inputs = self.processor.image_processor(images, return_tensors='pt')
image_grid_thw = image_inputs['image_grid_thw']

if image_grid_thw is not None:
merge_length = self.processor.image_processor.merge_size**2
index = 0
while '<image>' in prompt_with_chat_template:
prompt_with_chat_template = prompt_with_chat_template.replace(
'<image>',
'<|vision_start|>' + '<|placeholder|>' * (image_grid_thw[index].prod() // merge_length) +
'<|vision_end|>',
1,
)
index += 1

prompt_with_chat_template = prompt_with_chat_template.replace('<|placeholder|>',
self.processor.image_token)
input_ids = self.tokenizer.encode(prompt_with_chat_template, add_special_tokens=False)
return len(input_ids) <= self.max_prompt_length

def __len__(self):
return len(self.dataframe)

Expand Down