Skip to content

Commit cccc2ef

Browse files
[cfg] refactor: make the rollout & ref configs more modular (#2410)
### What does this PR do? move rollout and ref configs to standalone files. cc @ETOgaosion for dp_ref/rollout, default values are added to the yaml if actor_rollout_ref.actor does not exist, so that the yaml can be loaded independently. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test Relying on existing tests. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... 
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
1 parent ad33564 commit cccc2ef

File tree

12 files changed

+514
-544
lines changed

12 files changed

+514
-544
lines changed

tests/special_sanity/test_config_docs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def test_trainer_config_doc():
6262
"verl/trainer/config/ppo_trainer.yaml",
6363
"verl/trainer/config/actor/actor.yaml",
6464
"verl/trainer/config/actor/dp_actor.yaml",
65+
"verl/trainer/config/ref/ref.yaml",
66+
"verl/trainer/config/ref/dp_ref.yaml",
67+
"verl/trainer/config/rollout/rollout.yaml",
6568
]
6669
success = True
6770
for yaml_to_inspect in yamls_to_inspect:

tests/trainer/config/test_legacy_config_on_cpu.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
import unittest
1617

18+
from hydra import compose, initialize_config_dir
19+
from hydra.core.global_hydra import GlobalHydra
1720
from omegaconf import OmegaConf
1821

1922

2023
class TestConfigComparison(unittest.TestCase):
2124
"""Test that current configs match their legacy counterparts exactly."""
2225

23-
def _compare_configs_recursively(self, current_config, legacy_config, path=""):
24-
"""Recursively compare two OmegaConf configs and assert they are identical."""
26+
def _compare_configs_recursively(self, current_config, legacy_config, path="", legacy_allow_missing=False):
27+
"""Recursively compare two OmegaConf configs and assert they are identical.
28+
29+
Args:
30+
legacy_allow_missing (bool): whether the legacy config is allowed to contain fewer keys than
31+
the current config (e.g. the legacy megatron config); missing keys are logged instead of failing
32+
"""
2533
if isinstance(current_config, dict) and isinstance(legacy_config, dict):
2634
current_keys = set(current_config.keys())
2735
legacy_keys = set(legacy_config.keys())
@@ -32,19 +40,29 @@ def _compare_configs_recursively(self, current_config, legacy_config, path=""):
3240
if missing_in_current:
3341
self.fail(f"Keys missing in current config at {path}: {missing_in_current}")
3442
if missing_in_legacy:
35-
self.fail(f"Keys missing in legacy config at {path}: {missing_in_legacy}")
43+
# if the legacy config is allowed to be missing keys, log instead of failing
44+
msg = f"Keys missing in legacy config at {path}: {missing_in_legacy}"
45+
if legacy_allow_missing:
46+
print(msg)
47+
else:
48+
self.fail(msg)
3649

3750
for key in current_keys:
3851
current_path = f"{path}.{key}" if path else key
39-
self._compare_configs_recursively(current_config[key], legacy_config[key], current_path)
52+
if key in legacy_config:
53+
self._compare_configs_recursively(
54+
current_config[key], legacy_config[key], current_path, legacy_allow_missing=legacy_allow_missing
55+
)
4056
elif isinstance(current_config, list) and isinstance(legacy_config, list):
4157
self.assertEqual(
4258
len(current_config),
4359
len(legacy_config),
4460
f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}",
4561
)
4662
for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config)):
47-
self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]")
63+
self._compare_configs_recursively(
64+
current_item, legacy_item, f"{path}[{i}]", legacy_allow_missing=legacy_allow_missing
65+
)
4866
else:
4967
self.assertEqual(
5068
current_config,
@@ -66,7 +84,6 @@ def test_ppo_trainer_config_matches_legacy(self):
6684
current_config = compose(config_name="ppo_trainer")
6785

6886
legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_trainer.yaml")
69-
7087
current_dict = OmegaConf.to_container(current_config, resolve=True)
7188
legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)
7289

@@ -79,29 +96,42 @@ def test_ppo_trainer_config_matches_legacy(self):
7996

8097
def test_ppo_megatron_trainer_config_matches_legacy(self):
8198
"""Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly."""
82-
import os
83-
84-
from hydra import compose, initialize_config_dir
85-
from hydra.core.global_hydra import GlobalHydra
8699

87100
GlobalHydra.instance().clear()
88101

89102
try:
90-
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config"), version_base=None):
103+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
91104
current_config = compose(config_name="ppo_megatron_trainer")
92105

93106
legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_megatron_trainer.yaml")
94-
95107
current_dict = OmegaConf.to_container(current_config, resolve=True)
96108
legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)
97109

98110
if "defaults" in current_dict:
99111
del current_dict["defaults"]
100112

101-
self._compare_configs_recursively(current_dict, legacy_dict)
113+
self._compare_configs_recursively(current_dict, legacy_dict, legacy_allow_missing=True)
102114
finally:
103115
GlobalHydra.instance().clear()
104116

117+
def test_load_component(self):
118+
"""Test that each component config (actor/ref/rollout) can be loaded standalone."""
119+
120+
GlobalHydra.instance().clear()
121+
configs_to_load = [
122+
("verl/trainer/config/actor", "dp_actor"),
123+
("verl/trainer/config/actor", "megatron_actor"),
124+
("verl/trainer/config/ref", "dp_ref"),
125+
("verl/trainer/config/ref", "megatron_ref"),
126+
("verl/trainer/config/rollout", "rollout"),
127+
]
128+
for config_dir, config_file in configs_to_load:
129+
try:
130+
with initialize_config_dir(config_dir=os.path.abspath(config_dir)):
131+
compose(config_name=config_file)
132+
finally:
133+
GlobalHydra.instance().clear()
134+
105135

106136
if __name__ == "__main__":
107137
unittest.main()

verl/trainer/config/actor/actor.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ ppo_micro_batch_size: null
1818
ppo_micro_batch_size_per_gpu: null
1919

2020
# Whether to automatically adjust batch size at runtime
21+
# oc.select: the default val for ref.log_prob_use_dynamic_bsz
2122
use_dynamic_bsz: false
2223

2324
# Max tokens per GPU in one PPO batch; affects gradient accumulation
2425
# Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length}
26+
# oc.select: the default val for ref.log_prob_max_token_len_per_gpu
2527
ppo_max_token_len_per_gpu: 16384
2628

2729
# PPO clip ratio
@@ -67,6 +69,7 @@ entropy_coeff: 0
6769
use_kl_loss: false
6870

6971
# Whether to use torch.compile()
72+
# oc.select: the default val for ref.use_torch_compile
7073
use_torch_compile: true
7174

7275
# KL loss coefficient when use_kl_loss is enabled. For GRPO
@@ -89,7 +92,8 @@ checkpoint:
8992
save_contents: ['model', 'optimizer', 'extra']
9093

9194
# For more flexibility, you can specify the contents to load from the checkpoint.
92-
load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}
95+
# .xxx refers to the local variable xxx from the same level of hierarchy similar to python pkg
96+
load_contents: ${.save_contents}
9397

9498
# optimizer configs
9599
optim:

verl/trainer/config/actor/dp_actor.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ strategy: fsdp
2020
grad_clip: 1.0
2121

2222
# Sequence parallelism size for Ulysses-style model parallelism
23+
# oc.select: the default val for ref.ulysses_sequence_parallel_size
2324
ulysses_sequence_parallel_size: 1
2425

2526
# calculate entropy with chunking to reduce memory peak

verl/trainer/config/actor/megatron_actor.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,14 @@ megatron:
7373

7474
dist_checkpointing_path: null
7575

76+
# oc.select: default val for ref.megatron.seed
7677
seed: 42
7778

7879
# additional transformer config like: num_layers_in_first(/last)_pipeline_stage
80+
# oc.select: default val for ref.megatron.override_transformer_config
7981
override_transformer_config: {}
8082

83+
# oc.select: default val for ref.megatron.use_mbridge
8184
use_mbridge: False
8285

8386
# profile the actor model in `update_policy`
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Tokenizer class or path. If null, it will be inferred from the model.
2+
tokenizer: null
3+
4+
# Whether to use shared memory for data loading.
5+
use_shm: False
6+
7+
# Training set parquet. Can be a list or a single file.
8+
# The program will read all files into memory, so it can't be too large (< 100GB).
9+
# The path can be either a local path or an HDFS path.
10+
# For HDFS path, we provide utils to download it to DRAM and convert it to a local path.
11+
train_files: ~/data/rlhf/gsm8k/train.parquet
12+
13+
# Validation parquet. Can be a list or a single file.
14+
val_files: ~/data/rlhf/gsm8k/test.parquet
15+
16+
# The field in the dataset where the prompt is located. Default is 'prompt'.
17+
prompt_key: prompt
18+
19+
# The field used to select the reward function (if using different ones per example).
20+
reward_fn_key: data_source
21+
22+
# Maximum prompt length. All prompts will be left-padded to this length.
23+
# An error will be reported if the length is too long.
24+
# oc.select: default val for rollout.prompt_length
25+
max_prompt_length: 512
26+
27+
# Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length.
28+
# oc.select: default val for rollout.response_length
29+
max_response_length: 512
30+
31+
# Batch size sampled for one training iteration of different RL algorithms.
32+
train_batch_size: 1024
33+
34+
# Batch size used during validation. Can be null.
35+
val_batch_size: null
36+
37+
# Whether to return the original input_ids without adding chat template.
38+
# This is used when the reward model's chat template differs from the policy.
39+
# If using a model-based RM with different templates, this should be True.
40+
return_raw_input_ids: False
41+
42+
# Whether to return the original chat (prompt) without applying chat template.
43+
return_raw_chat: False
44+
45+
# Whether to return the full prompt with chat template.
46+
return_full_prompt: False
47+
48+
# Whether to shuffle the data in the dataloader.
49+
shuffle: True
50+
51+
# num dataloader workers
52+
dataloader_num_workers: 8
53+
54+
# Whether to shuffle the validation set.
55+
validation_shuffle: False
56+
57+
# Whether to filter overlong prompts.
58+
filter_overlong_prompts: False
59+
60+
# Number of workers for filtering overlong prompts.
61+
# For large-scale datasets, filtering can be time-consuming.
62+
# Use multiprocessing to speed up. Default is 1.
63+
filter_overlong_prompts_workers: 1
64+
65+
# Truncate the input_ids or prompt if they exceed max_prompt_length.
66+
# Options: 'error', 'left', or 'right'. Default is 'error'.
67+
truncation: error
68+
69+
# The field in the multi-modal dataset where the image is located. Default is 'images'.
70+
image_key: images
71+
72+
# The field in the multi-modal dataset where the video is located.
73+
video_key: videos
74+
75+
# If the remote tokenizer has a Python file, this flag determines whether to allow using it.
76+
trust_remote_code: False
77+
78+
# Optional: specify a custom dataset class path and name if overriding default loading behavior.
79+
custom_cls:
80+
81+
# The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used.
82+
path: null
83+
84+
# The name of the dataset class within the specified file.
85+
name: null
86+
87+
# Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs.
88+
return_multi_modal_inputs: True
89+
90+
# settings related to data sampler
91+
sampler:
92+
93+
# the path to the module containing a curriculum class which implements the
94+
# AbstractSampler interface
95+
class_path: null
96+
97+
# the name of the curriculum class like `MySampler`
98+
class_name: null

0 commit comments

Comments
 (0)