Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
4a4dbc0
add support for custom datagen class that allows for adding new data …
jwong8314 Jul 1, 2025
9d599ce
ruff
jwong8314 Jul 1, 2025
d91b626
ruff-format
jwong8314 Jul 1, 2025
84b3815
ruff-format
jwong8314 Jul 1, 2025
3c1cf80
Update license
jwong8314 Jul 1, 2025
9c65168
update license
jwong8314 Jul 1, 2025
8a04aca
fix: make sure if there's no data_generator it doesn't crash
jwong8314 Jul 1, 2025
87b89d0
ruff-format
jwong8314 Jul 1, 2025
732b184
Merge branch 'main' into main
zhaochenyang20 Jul 2, 2025
ffba50d
Merge branch 'main' into dynamic_dataset
jwong8314 Jul 4, 2025
c620bcb
undo change to import_utils
jwong8314 Jul 4, 2025
6b061f9
merging into dataset
jwong8314 Jul 4, 2025
13debde
Merge pull request #1 from jwong8314/dynamic_dataset
jwong8314 Jul 4, 2025
19201e2
rename variables
jwong8314 Jul 4, 2025
72b223e
is_train rename
jwong8314 Jul 4, 2025
383cf61
Merge pull request #2 from jwong8314/dynamic_dataset
jwong8314 Jul 4, 2025
5070088
rename
jwong8314 Jul 4, 2025
3250d1d
rename to Generator
jwong8314 Jul 4, 2025
8c48f09
Merge pull request #3 from jwong8314/dynamic_dataset
jwong8314 Jul 4, 2025
0a5cabf
Merge branch 'main' into main
zhaochenyang20 Jul 4, 2025
4aac878
add parameter for batch information
jwong8314 Jul 8, 2025
e3bbd57
add comments and placed files in experimental
jwong8314 Jul 8, 2025
122e817
move to experimental subdir
jwong8314 Jul 8, 2025
a74ff75
ruff
jwong8314 Jul 8, 2025
2e44ead
ruff
jwong8314 Jul 8, 2025
6126b96
Merge branch 'volcengine:main' into main
jwong8314 Jul 8, 2025
16647d4
patch CI
jwong8314 Jul 8, 2025
171c9be
Merge branch 'main' into main
jwong8314 Jul 9, 2025
314d350
resolve conflicts new yaml
jwong8314 Jul 9, 2025
4bd1452
typo
jwong8314 Jul 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/special_sanity/check_license.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
license_head_individual = "Copyright 2025 Individual Contributor:"
license_head_sglang = "Copyright 2023-2024 SGLang Team"
license_head_modelbest = "Copyright 2025 ModelBest Inc. and/or its affiliates"
license_head_amazon = "Copyright 2025 Amazon.com Inc and/or its affiliates"
license_headers = [
license_head_bytedance,
license_head_bytedance_25,
license_head_prime,
license_head_individual,
license_head_sglang,
license_head_modelbest,
license_head_amazon,
]


Expand Down
11 changes: 11 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ data:
# The name of the dataset class within the specified file.
name: null

# Data generation configuration for augmenting the dataset.
datagen:

# The path to the file containing your customized data generation class.
# E.g. 'pkg://verl.utils.dataset.dynamicgen_dataset'
path: null

# The class name of the data generation class within the specified file.
# E.g. 'MockDataGenerator'
name: null

# config for actor, rollout and reference model
actor_rollout_ref:

Expand Down
13 changes: 10 additions & 3 deletions verl/trainer/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def run(self, config):
from verl.utils.dataset.rl_dataset import collate_fn

# Create training and validation datasets.
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False)
train_sampler = create_rl_sampler(config.data, train_dataset)

# Initialize the PPO trainer.
Expand All @@ -214,7 +214,7 @@ def run(self, config):
trainer.fit()


def create_rl_dataset(data_paths, data_config, tokenizer, processor):
def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True):
"""Create a dataset.

Arguments:
Expand Down Expand Up @@ -243,6 +243,13 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor):
f"The custom dataset class '{data_config.custom_cls.name}' from "
f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset"
)
elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train:
# If a data generation strategy is specified, use the DynamicGenDataset class
from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset

dataset_cls = DynamicGenDataset
print("Using DynamicGenDataset for data generation.")

else:
# Use the default RLHFDataset class if no custom class is specified
dataset_cls = RLHFDataset
Expand Down
4 changes: 4 additions & 0 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,3 +1367,7 @@ def fit(self):
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return

if hasattr(self.train_dataset, "on_batch_end"):
# The dataset may be changed after each training batch
self.train_dataset.on_batch_end()
108 changes: 108 additions & 0 deletions verl/utils/dataset/dynamicgen_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2025 Amazon.com Inc and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agonistic model initialization with huggingface
"""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import datasets
from omegaconf import DictConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin

from verl.utils.dataset import RLHFDataset
from verl.utils.import_utils import load_extern_type

logger = logging.getLogger(__name__)


class AbstractDataGenerator(ABC):
    """Base interface for pluggable data-generation strategies.

    Concrete strategies implement :meth:`generate`, which receives the
    current dataset and returns new rows to be appended to it.
    """

    def __init__(self, config: "DictConfig"):
        # Keep the strategy's config section available to subclasses.
        self.config = config

    @abstractmethod
    def generate(self, dataset: "Dataset") -> "datasets.Dataset":
        """Produce new data derived from ``dataset``.

        Args:
            dataset: The dataset to generate from.

        Returns:
            Processed data or result as implemented by the subclass.
        """


class MockDataGenerator(AbstractDataGenerator):
    """No-op data generator used as a placeholder and for testing.

    Rather than synthesizing anything new, it simply re-selects the first
    datapoint of the incoming dataset.
    """

    def __init__(self, config: "DictConfig" = None):
        super().__init__(config)

    def generate(self, dataset: "Dataset") -> "datasets.Dataset":
        print("MockDataGenerator: No operation performed on the dataset.")
        # Re-append only the first row of the underlying dataframe.
        first_row = dataset.dataframe.select([0])
        return first_row


class DynamicGenDataset(RLHFDataset):
    """
    A dataset class that uses a data generation strategy to process data.
    This class extends RLHFDataset and uses an AbstractDataGenerator instance
    to append newly generated rows to the dataset after each training batch.
    """

    def __init__(
        self,
        data_files: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ):
        super().__init__(data_files, tokenizer, config, processor)
        # Validate the datagen section *before* dereferencing it, so a
        # missing/empty config fails with this message instead of an
        # attribute error.
        assert "datagen" in config and config.datagen.get("path", None) is not None, (
            f"datagen path is not set in config: {config}"
        )
        # The `datagen` config section (path/name of the generator class) —
        # the instantiated generator itself lives in `self.data_generator`.
        self.datagen: DictConfig = config.datagen

        # Dynamically load the custom datagen class
        datagen_cls = load_extern_type(config.datagen.path, config.datagen.name)

        # Verify that the custom datagen class inherits from AbstractDataGenerator
        abs_cls = AbstractDataGenerator
        if not issubclass(datagen_cls, abs_cls):
            raise TypeError(
                f"The custom datagen class '{config.datagen.name}' from '{config.datagen.path}'"
                f" must inherit from {abs_cls}"
            )

        self.data_generator = datagen_cls(config.datagen)
        # Seed the dataset with one round of generated data up front.
        self.on_batch_end()

    def append_dataframe(self, new_dataframe: datasets.Dataset):
        """Filter overlong prompts out of ``new_dataframe`` and append the rest."""
        new_dataframe = self.maybe_filter_out_long_prompts(new_dataframe)
        self.dataframe = datasets.concatenate_datasets([self.dataframe, new_dataframe])

        logger.info(f"new dataset len: {len(self.dataframe)}")

    def on_batch_end(self) -> None:
        """
        Generate data using the provided data generation strategy.
        Note: This method is intended to change the dataset after each training batch.
        """
        new_data = self.data_generator.generate(self)
        self.append_dataframe(new_data)
8 changes: 6 additions & 2 deletions verl/utils/dataset/rl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def _read_files_and_tokenize(self):

print(f"dataset len: {len(self.dataframe)}")

self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe)

def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None):
# filter out too long prompts
if self.filter_overlong_prompts:
tokenizer = self.tokenizer
Expand Down Expand Up @@ -165,13 +168,14 @@ def doc2len(doc) -> int:
def doc2len(doc) -> int:
return len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))

self.dataframe = self.dataframe.filter(
dataframe = dataframe.filter(
lambda doc: doc2len(doc) <= self.max_prompt_length,
num_proc=self.num_workers,
desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
)

print(f"filter dataset len: {len(self.dataframe)}")
print(f"filter dataset len: {len(dataframe)}")
return dataframe

def resume_dataset_state(self):
self.serialize_dataset = not hasattr(self, "original_data_files")
Expand Down