Skip to content

Commit 171c9be

Browse files
authored
Merge branch 'main' into main
2 parents 16647d4 + cccc2ef commit 171c9be

File tree

23 files changed

+759
-575
lines changed

23 files changed

+759
-575
lines changed

.github/workflows/vllm.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ jobs:
105105
- name: Test the latest vLLM
106106
run: |
107107
torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s tests/workers/rollout/rollout_vllm/test_vllm_spmd.py
108+
- name: Test the latest vLLM on model with rope scaling
109+
run: |
110+
torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py
108111
- name: Run Qwen 0.5B generation test
109112
run: |
110113
cd tests/special_e2e/generation

tests/special_sanity/test_config_docs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def test_trainer_config_doc():
6262
"verl/trainer/config/ppo_trainer.yaml",
6363
"verl/trainer/config/actor/actor.yaml",
6464
"verl/trainer/config/actor/dp_actor.yaml",
65+
"verl/trainer/config/ref/ref.yaml",
66+
"verl/trainer/config/ref/dp_ref.yaml",
67+
"verl/trainer/config/rollout/rollout.yaml",
6568
]
6669
success = True
6770
for yaml_to_inspect in yamls_to_inspect:

tests/trainer/config/test_legacy_config_on_cpu.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
import unittest
1617

18+
from hydra import compose, initialize_config_dir
19+
from hydra.core.global_hydra import GlobalHydra
1720
from omegaconf import OmegaConf
1821

1922

2023
class TestConfigComparison(unittest.TestCase):
2124
"""Test that current configs match their legacy counterparts exactly."""
2225

23-
def _compare_configs_recursively(self, current_config, legacy_config, path=""):
24-
"""Recursively compare two OmegaConf configs and assert they are identical."""
26+
def _compare_configs_recursively(self, current_config, legacy_config, path="", legacy_allow_missing=False):
27+
"""Recursively compare two OmegaConf configs and assert they are identical.
28+
29+
Args:
30+
legacy_allow_missing (bool): sometimes the legacy megatron config contains fewer keys and
31+
we allow that to happen
32+
"""
2533
if isinstance(current_config, dict) and isinstance(legacy_config, dict):
2634
current_keys = set(current_config.keys())
2735
legacy_keys = set(legacy_config.keys())
@@ -32,19 +40,29 @@ def _compare_configs_recursively(self, current_config, legacy_config, path=""):
3240
if missing_in_current:
3341
self.fail(f"Keys missing in current config at {path}: {missing_in_current}")
3442
if missing_in_legacy:
35-
self.fail(f"Keys missing in legacy config at {path}: {missing_in_legacy}")
43+
# if the legacy
44+
msg = f"Keys missing in legacy config at {path}: {missing_in_legacy}"
45+
if legacy_allow_missing:
46+
print(msg)
47+
else:
48+
self.fail(msg)
3649

3750
for key in current_keys:
3851
current_path = f"{path}.{key}" if path else key
39-
self._compare_configs_recursively(current_config[key], legacy_config[key], current_path)
52+
if key in legacy_config:
53+
self._compare_configs_recursively(
54+
current_config[key], legacy_config[key], current_path, legacy_allow_missing=legacy_allow_missing
55+
)
4056
elif isinstance(current_config, list) and isinstance(legacy_config, list):
4157
self.assertEqual(
4258
len(current_config),
4359
len(legacy_config),
4460
f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}",
4561
)
4662
for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config)):
47-
self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]")
63+
self._compare_configs_recursively(
64+
current_item, legacy_item, f"{path}[{i}]", legacy_allow_missing=legacy_allow_missing
65+
)
4866
else:
4967
self.assertEqual(
5068
current_config,
@@ -66,7 +84,6 @@ def test_ppo_trainer_config_matches_legacy(self):
6684
current_config = compose(config_name="ppo_trainer")
6785

6886
legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_trainer.yaml")
69-
7087
current_dict = OmegaConf.to_container(current_config, resolve=True)
7188
legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)
7289

@@ -79,29 +96,42 @@ def test_ppo_trainer_config_matches_legacy(self):
7996

8097
def test_ppo_megatron_trainer_config_matches_legacy(self):
8198
"""Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly."""
82-
import os
83-
84-
from hydra import compose, initialize_config_dir
85-
from hydra.core.global_hydra import GlobalHydra
8699

87100
GlobalHydra.instance().clear()
88101

89102
try:
90-
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config"), version_base=None):
103+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
91104
current_config = compose(config_name="ppo_megatron_trainer")
92105

93106
legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_megatron_trainer.yaml")
94-
95107
current_dict = OmegaConf.to_container(current_config, resolve=True)
96108
legacy_dict = OmegaConf.to_container(legacy_config, resolve=True)
97109

98110
if "defaults" in current_dict:
99111
del current_dict["defaults"]
100112

101-
self._compare_configs_recursively(current_dict, legacy_dict)
113+
self._compare_configs_recursively(current_dict, legacy_dict, legacy_allow_missing=True)
102114
finally:
103115
GlobalHydra.instance().clear()
104116

117+
def test_load_component(self):
118+
"""Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly."""
119+
120+
GlobalHydra.instance().clear()
121+
configs_to_load = [
122+
("verl/trainer/config/actor", "dp_actor"),
123+
("verl/trainer/config/actor", "megatron_actor"),
124+
("verl/trainer/config/ref", "dp_ref"),
125+
("verl/trainer/config/ref", "megatron_ref"),
126+
("verl/trainer/config/rollout", "rollout"),
127+
]
128+
for config_dir, config_file in configs_to_load:
129+
try:
130+
with initialize_config_dir(config_dir=os.path.abspath(config_dir)):
131+
compose(config_name=config_file)
132+
finally:
133+
GlobalHydra.instance().clear()
134+
105135

106136
if __name__ == "__main__":
107137
unittest.main()

tests/utils/dataset/test_create_rl_sampler_on_cpu.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from omegaconf import DictConfig, OmegaConf
2323
from torch.utils.data import Dataset, RandomSampler
2424

25+
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
2526
from verl.trainer.main_ppo import create_rl_sampler
26-
from verl.utils.dataset.sampler import AbstractCurriculumSampler
2727

2828

2929
class RandomCurriculumSampler(AbstractCurriculumSampler):
@@ -77,10 +77,11 @@ def __len__(self):
7777
def test_create_custom_curriculum_samper():
7878
data_config = OmegaConf.create(
7979
{
80+
"dataloader_num_workers": 0,
8081
"sampler": {
8182
"class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu",
8283
"class_name": "RandomCurriculumSampler",
83-
}
84+
},
8485
}
8586
)
8687

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import gc
16+
17+
import torch
18+
import torch.distributed
19+
import torch.distributed as dist
20+
from omegaconf import OmegaConf
21+
from transformers import AutoConfig, AutoTokenizer
22+
23+
from verl import DataProto
24+
from verl.utils.distributed import initialize_global_process_group
25+
from verl.utils.model import compute_position_id_with_mask
26+
from verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import vLLMRollout
27+
28+
29+
def test_vllm_rollout_with_yarn_position_embeddings():
30+
"""
31+
Test the vLLM rollout with yarn position embeddings.
32+
"""
33+
34+
local_rank, rank, world_size = initialize_global_process_group()
35+
config = OmegaConf.create(
36+
{
37+
"model_path": "OldKingMeister/Qwen2.5-1.5B-Instruct-YaRN",
38+
"prompt_length": 35000,
39+
"response_length": 512,
40+
"dtype": "bfloat16",
41+
"enforce_eager": True,
42+
"gpu_memory_utilization": 0.4,
43+
"enable_chunked_prefill": False,
44+
"free_cache_engine": False,
45+
"disable_log_stats": True,
46+
"max_model_len": 35000 + 512,
47+
"load_format": "auto",
48+
"val_kwargs": {
49+
"top_k": -1,
50+
"top_p": 1.0,
51+
"temperature": 0,
52+
"n": 1,
53+
"do_sample": False,
54+
},
55+
"tensor_model_parallel_size": 4,
56+
"trust_remote_code": True,
57+
"calculate_log_probs": False,
58+
"do_sample": False,
59+
"temperature": 0.0,
60+
"max_num_batched_tokens": 35000 + 512,
61+
}
62+
)
63+
64+
tokenizer = AutoTokenizer.from_pretrained(config.model_path, trust_remote_code=True, padding_side="left")
65+
tokenizer.pad_token = tokenizer.eos_token
66+
model_hf_config = AutoConfig.from_pretrained(config.model_path)
67+
68+
# do_sample=False for temperate=0 deterministic
69+
input_dataproto = prepare_input_dataproto(tokenizer, config, validate=True, do_sample=False)
70+
71+
vllm_rollout = vLLMRollout(
72+
model_path=config.model_path,
73+
config=config,
74+
tokenizer=tokenizer,
75+
model_hf_config=model_hf_config,
76+
)
77+
# rollout
78+
rollout_response = vllm_rollout.generate_sequences(
79+
prompts=input_dataproto,
80+
)
81+
if rank == 0:
82+
print("VLLM Rollout Outputs:")
83+
print(tokenizer.batch_decode(rollout_response.batch["responses"][:], skip_special_tokens=False))
84+
for response in rollout_response.batch["responses"]:
85+
assert "<|im_end|>" in tokenizer.decode(response, skip_special_tokens=False), (
86+
"Response should contain <|im_end|> token"
87+
)
88+
print("Checks passed.")
89+
90+
del vllm_rollout
91+
gc.collect()
92+
torch.cuda.empty_cache()
93+
torch.cuda.ipc_collect()
94+
dist.barrier()
95+
torch.distributed.destroy_process_group()
96+
97+
98+
def prepare_input_dataproto(tokenizer, config, validate, do_sample=False):
99+
base_phrase = "Roses are red, sky is blue. " * 4096
100+
preencode_prompts = [
101+
# 32810 tokens > 32768 tokens
102+
[{"role": "user", "content": base_phrase + "Who won the Champions League in 2019?"}],
103+
[{"role": "user", "content": base_phrase + "The founder of Apple is"}],
104+
[{"role": "user", "content": base_phrase + "What's your name"}],
105+
]
106+
formatted_prompts = [
107+
tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
108+
for conversation in preencode_prompts
109+
]
110+
prompts = tokenizer(formatted_prompts, return_tensors="pt", padding="max_length", max_length=config.prompt_length)
111+
input_dataproto = DataProto.from_dict(
112+
{
113+
"input_ids": prompts["input_ids"],
114+
"attention_mask": prompts["attention_mask"],
115+
"position_ids": compute_position_id_with_mask(prompts["attention_mask"]),
116+
},
117+
meta_info={
118+
"bos_token_id": tokenizer.bos_token_id,
119+
"eos_token_id": tokenizer.eos_token_id,
120+
"pad_token_id": tokenizer.pad_token_id,
121+
"validate": validate,
122+
"do_sample": do_sample,
123+
"response_length": config.response_length,
124+
"temperature": config.temperature,
125+
},
126+
)
127+
return input_dataproto
128+
129+
130+
if __name__ == "__main__":
131+
test_vllm_rollout_with_yarn_position_embeddings()

tests/workers/rollout/test_sglang_multi_interaction.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ def test_initialize_multiple_interactions(self):
127127
# Mock model config
128128
mock_model_config = MagicMock()
129129
mock_model_config.max_position_embeddings = 2048
130+
# since this is a mock, we can set any rope scaling config
131+
# to test the rope_scaling logic at the same time of this test
132+
mock_model_config.rope_scaling = {
133+
"factor": 4.0,
134+
"original_max_position_embeddings": 32768,
135+
"type": "yarn",
136+
}
130137

131138
# Create SGLangRollout instance
132139
rollout = SGLangRollout(
@@ -173,6 +180,11 @@ def test_interaction_selection_by_name(self):
173180

174181
mock_model_config = MagicMock()
175182
mock_model_config.max_position_embeddings = 2048
183+
mock_model_config.rope_scaling = {
184+
"factor": 4.0,
185+
"original_max_position_embeddings": 32768,
186+
"type": "yarn",
187+
}
176188

177189
rollout = SGLangRollout(
178190
actor_module="mock_model",
@@ -278,6 +290,11 @@ def test_fallback_to_default_interaction(self):
278290

279291
mock_model_config = MagicMock()
280292
mock_model_config.max_position_embeddings = 2048
293+
mock_model_config.rope_scaling = {
294+
"factor": 4.0,
295+
"original_max_position_embeddings": 32768,
296+
"type": "yarn",
297+
}
281298

282299
rollout = SGLangRollout(
283300
actor_module="mock_model",
@@ -312,6 +329,11 @@ def test_error_on_missing_interaction(self):
312329

313330
mock_model_config = MagicMock()
314331
mock_model_config.max_position_embeddings = 2048
332+
mock_model_config.rope_scaling = {
333+
"factor": 4.0,
334+
"original_max_position_embeddings": 32768,
335+
"type": "yarn",
336+
}
315337

316338
rollout = SGLangRollout(
317339
actor_module="mock_model",
@@ -374,6 +396,11 @@ def test_backward_compatibility_no_interaction_config(self):
374396

375397
mock_model_config = MagicMock()
376398
mock_model_config.max_position_embeddings = 2048
399+
mock_model_config.rope_scaling = {
400+
"factor": 4.0,
401+
"original_max_position_embeddings": 32768,
402+
"type": "yarn",
403+
}
377404

378405
rollout = SGLangRollout(
379406
actor_module="mock_model",

verl/experimental/agent_loop/agent_loop.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ def _postprocess(self, inputs: List[AgentLoopOutput]) -> DataProto:
351351

352352

353353
async def get_trajectory_info(step, index):
354+
"""Get the trajectory info (step, sample_index, rollout_n) asynchrously"""
354355
trajectory_info = []
355356
rollout_n = 0
356357
for i in range(len(index)):
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,20 @@
2121

2222

2323
class AbstractSampler(Sampler[int]):
24+
"""Abstract interface for custom samplers."""
25+
2426
@abstractmethod
2527
def __init__(
2628
self,
2729
data_source: Sized,
28-
config: DictConfig,
30+
data_config: DictConfig,
2931
):
3032
pass
3133

3234

3335
class AbstractCurriculumSampler(AbstractSampler):
36+
"""Experimental interface for curriculum learning samplers."""
37+
3438
@abstractmethod
3539
def update(self, batch: DataProto) -> None:
3640
pass

0 commit comments

Comments
 (0)