Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
fd7a569
[FEAT] Support async multi-turn rollout with simulation feedback
kinza99 May 21, 2025
1c13235
[DOC] Update sglang multi-turn rollout doc
kinza99 May 22, 2025
6c2cf1b
[Update] update user interaction design
kinza99 May 28, 2025
c0176c1
[Update] add testing and fix bugs
kinza99 May 29, 2025
d560cf3
[Fix] fix some problems
kinza99 May 30, 2025
5fbdd11
Fix unit-test and separate examples from previous tool
SwordFaith May 30, 2025
cb4baa7
Fix megatron workers and formatting
SwordFaith May 30, 2025
ed070cc
[Update] merge the latest main version
kinza99 Jun 4, 2025
fbfdcd0
Add training script
SwordFaith Jun 4, 2025
4ea2f1a
Fix assertion
SwordFaith Jun 4, 2025
4b18b69
Fix max_turns
SwordFaith Jun 4, 2025
878b1aa
Fix init interaction missing
SwordFaith Jun 4, 2025
8023dcb
Fix interface
SwordFaith Jun 4, 2025
dc3157e
Lower gpu mem foot print
SwordFaith Jun 4, 2025
3104159
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 4, 2025
cc31550
Fix init interaction missing issue
SwordFaith Jun 4, 2025
ae9217a
[Fix] fix problem with exceeding max_new_tokens
kinza99 Jun 5, 2025
cbddd2e
Update training data path
SwordFaith Jun 5, 2025
0c4338f
Fix prompt in preprocess interaction
SwordFaith Jun 6, 2025
eec48a9
Add 0.5b train script
SwordFaith Jun 7, 2025
564c832
Fix gsm8k reward in multi-turn scene
SwordFaith Jun 7, 2025
b75af4e
Try fix race condition in sampling params update
SwordFaith Jun 9, 2025
ca47f66
Fix sglang rollout sampling params
SwordFaith Jun 9, 2025
64a8ad7
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 10, 2025
34378ce
Fix bug and redundant error
SwordFaith Jun 10, 2025
ddb4881
Remove format config
SwordFaith Jun 10, 2025
9592979
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 17, 2025
92a47ba
Fix interaction config default value bug
SwordFaith Jun 17, 2025
8f45437
Fix default value judge logic
SwordFaith Jun 17, 2025
98adc05
Fix arg issue
SwordFaith Jun 17, 2025
2d43631
Fix format error
SwordFaith Jun 17, 2025
9a86042
Fix other format errors
SwordFaith Jun 17, 2025
5cb7a60
Clean training scripts
SwordFaith Jun 17, 2025
7e83b72
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 18, 2025
7f426f3
Merge branch 'main' into duhe/multi_turns_with_feedback
SwordFaith Jun 18, 2025
13a9615
Fix sf tool test
SwordFaith Jun 18, 2025
3c4a351
Fix ci error
SwordFaith Jun 18, 2025
ffc9366
Fix aglang tests
SwordFaith Jun 18, 2025
c0a035c
Merge upstream/main into multi_turns_with_feedback
SwordFaith Jun 20, 2025
21110df
Add test and doc for interaction
SwordFaith Jun 20, 2025
fd11f7b
Try fix mcp tool test
SwordFaith Jun 20, 2025
a1a021f
Fix sglang mcp tools test
SwordFaith Jun 20, 2025
df5de70
Fix pre-commit run issue
SwordFaith Jun 20, 2025
2abda2e
Fix doc and test
SwordFaith Jun 20, 2025
432eeff
Merge branch 'main' into duhe/multi_turns_with_feedback
zhaochenyang20 Jun 20, 2025
c5fe07a
Try fix ci issues
SwordFaith Jun 21, 2025
df42678
Fix chat completion arg bug
SwordFaith Jun 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix max_turns
  • Loading branch information
SwordFaith committed Jun 4, 2025
commit 4b18b693bb56df43e72a62fddc5bafc6fa2f2e38
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
hydra:
searchpath:
- file://verl/trainer/config

defaults:
- ppo_trainer
- _self_

data:
max_prompt_length: 1024
max_response_length: 1024
train_batch_size: 256
return_raw_chat: True

actor_rollout_ref:
hybrid_engine: True
rollout:
name: sglang
multi_turn:
enable: True
max_user_turns: 5
format: qwen
# tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml"
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ actor_rollout_ref:
name: sglang_async
multi_turn:
enable: True
max_turns: 5
max_assistant_turns: 5
tool_config_path: "./config/tool_config/sandbox_fusion_tool_config.yaml"
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ actor_rollout_ref:
name: sglang_async
multi_turn:
enable: True
max_turns: 2
max_assistant_turns: 2
format: qwen
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ HOME=/user/longxiang1

python3 -m verl.trainer.main_ppo \
--config-path="$CONFIG_PATH" \
--config-name='gsm8k_multiturn_grpo' \
--config-name='gsm8k_multiturn_grpo_w_interaction' \
algorithm.adv_estimator=grpo \
data.train_batch_size=$TRAIN_BATCH_SIZE \
data.max_prompt_length=1024 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.rollout.name=sglang_async \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.rollout.multi_turn.max_turns=2 \
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_over_size_case(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data):
search_rollout_config.multi_turn.max_turns = 1
search_rollout_config.multi_turn.max_assistant_turns = 1
rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
req = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0]
req = MagicMock(wraps=req, spec=AsyncRolloutRequest)
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_tool_call_basic_case(self, mock_sampling, mock_engine, mock_env, mock_e
# Mock search tool execution to return predefined responses
mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array]

search_rollout_config.multi_turn.max_turns = 10
search_rollout_config.multi_turn.max_assistant_turns = 10
rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)

rollout._tool_map["search"].retrieval_service_url = "mock://dummy"
Expand Down Expand Up @@ -284,7 +284,7 @@ def test_tool_call_batch_case(self, mock_sampling, mock_engine, mock_env, mock_e
(tool_return_array[1], 0.0, {"status": "success"}),
] * 100

search_rollout_config.multi_turn.max_turns = 10
search_rollout_config.multi_turn.max_assistant_turns = 10
rollout = SGLangRollout(
actor_module="",
config=search_rollout_config,
Expand Down
6 changes: 3 additions & 3 deletions tests/workers/rollout/test_sglang_async_rollout_sf_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, sandbo
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_over_size_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data):
sandbox_fusion_rollout_config.multi_turn.max_turns = 1
sandbox_fusion_rollout_config.multi_turn.max_assistant_turns = 1
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]
req = MagicMock(wraps=req, spec=AsyncRolloutRequest)
Expand Down Expand Up @@ -283,7 +283,7 @@ def test_over_size_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusi
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_tool_call_basic_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data):
sandbox_fusion_rollout_config.multi_turn.max_turns = 10
sandbox_fusion_rollout_config.multi_turn.max_assistant_turns = 10
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
self._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url
req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_tool_call_basic_case(self, mock_env, mock_engine, mock_sampling, sandbo
@patch.object(SGLangRollout, "_init_inference_engine", return_value=None)
@patch.object(SGLangRollout, "_init_sampling_params", return_value=None)
def test_tool_call_batch_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data):
sandbox_fusion_rollout_config.multi_turn.max_turns = 10
sandbox_fusion_rollout_config.multi_turn.max_assistant_turns = 10
rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config)
self._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url
req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0]
Expand Down
2 changes: 1 addition & 1 deletion verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ actor_rollout_ref:
do_sample: False # default eager for validation
multi_turn:
enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well
max_turns: null # null for no limit (default max_length // 3)
max_assistant_turns: null # null means use the default limit of max_model_len // 3
tool_config_path: null # null for no tool
format: chatml # chatml, more formats will be supported in the future

Expand Down
8 changes: 5 additions & 3 deletions verl/workers/rollout/sglang_rollout/async_sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,11 @@ def _verify_config(self, model_hf_config):
assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length):
{self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}"""
assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length"
# currently max_turns stand for max number of tool calls
if self.config.multi_turn.max_turns is None:
self.config.multi_turn.max_turns = self.config.max_model_len // 3
# currently max_assistant_turns stands for the max number of tool calls
if self.config.multi_turn.max_assistant_turns is None:
self.config.multi_turn.max_assistant_turns = self.config.max_model_len // 3
if self.config.multi_turn.max_user_turns is None:
self.config.multi_turn.max_user_turns = self.config.max_model_len // 3

def _init_inference_engine(self, trust_remote_code, actor_module, port):
# initialize the inference engine
Expand Down
10 changes: 6 additions & 4 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,11 @@ def _verify_config(self, model_hf_config):
assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length):
{self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}"""
assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length"
# currently max_turns stand for max number of tool calls
if self.config.multi_turn.max_turns is None:
self.config.multi_turn.max_turns = self.config.max_model_len // 3
# currently max_assistant_turns stands for the max number of tool calls
if self.config.multi_turn.max_assistant_turns is None:
self.config.multi_turn.max_assistant_turns = self.config.max_model_len // 3
if self.config.multi_turn.max_user_turns is None:
self.config.multi_turn.max_user_turns = self.config.max_model_len // 3

def _init_inference_engine(self, trust_remote_code, actor_module, port):
# initialize the inference engine
Expand Down Expand Up @@ -657,7 +659,7 @@ async def _async_rollout_a_request(
user_turns = 0
user_turn_rewards = []

while current_turns < self.config.multi_turn.max_turns:
while current_turns < self.config.multi_turn.max_assistant_turns:
if _req.state == AsyncRolloutRequestStateEnum.PENDING:
await self._handle_pending_state(_req)
_req.state = AsyncRolloutRequestStateEnum.RUNNING
Expand Down