Commit f1c6971

Authored by HJSang (Hejian Sang) and gemini-code-assist[bot]

[recipe, rollout] feat: enable gpt-oss training for tool agent, add gpt-oss for retool recipe (verl-project#3837)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

* add the tool response parsing logic for gpt-oss models
* add training recipe for retool

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: Hejian Sang <hsang@linkedin.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent aa731f3 commit f1c6971

File tree

4 files changed: +190 -23 lines changed

recipe/langgraph_agent/chat_model.py

Lines changed: 5 additions & 18 deletions

```diff
@@ -38,25 +38,12 @@
 from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager
 from verl.experimental.agent_loop.tool_parser import ToolParser
+from verl.experimental.agent_loop.utils import add_generation_prompt_for_gpt_oss, format_gpt_oss_tool_response_manually

 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


-def format_tool_response_manually(tool_message: dict, tool_call_name: str) -> str:
-    """Manually format tool response without using tokenizer template.
-
-    Args:
-        tool_message: Tool message dictionary with 'content' field
-        tool_call_name: Name of the tool that was called
-
-    Returns:
-        Formatted tool response string
-    """
-    content = tool_message["content"]
-    return f"<|start|>functions.{tool_call_name} to=assistant<|channel|>commentary<|message|>{content}<|end|>"
-
-
 class MaxTokenExceededError(Exception):
     """Indicate that history chat messages + tool message exceeds LLM max_tokens."""

@@ -235,14 +222,14 @@ async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple
             actual_tool_name = tool_msg.get("name", "unknown")
             if actual_tool_name == "unknown":
                 logger.error(f"actual_tool_name: {actual_tool_name}")
-            formatted = format_tool_response_manually(tool_msg, actual_tool_name)
+            formatted = format_gpt_oss_tool_response_manually(tool_msg["content"], actual_tool_name)
             tool_response_texts.append(formatted)
-            # need to add generation tokens for gpt-oss manually since add_generation_prompt is True
-            tool_response_texts.append("<|start|>assistant")

         # Tokenize the manually formatted tool responses
         tool_response_text = "".join(tool_response_texts)
-        print(f"tool_response_text: {tool_response_text}")
+        # need to add generation tokens for gpt-oss manually since add_generation_prompt is True
+        tool_response_text = add_generation_prompt_for_gpt_oss(tool_response_text)
+        logger.debug(f"tool_response_text: {tool_response_text}")

         tool_response_ids = await loop.run_in_executor(
             None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False)
```
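The net effect of this change can be checked in isolation: the old code appended `<|start|>assistant` once per tool response, while the fixed code appends it once after all responses are joined. A minimal sketch with two hypothetical tool outputs (`search` and `calc` are made-up names):

```python
# Sketch of the fix above: the generation prompt "<|start|>assistant" must
# appear exactly once, after all tool responses, not once per response.
# Tool names and contents here are hypothetical.
def fmt(content: str, name: str) -> str:
    # Same harmony-format template the diff uses for tool responses
    return f"<|start|>functions.{name} to=assistant<|channel|>commentary<|message|>{content}<|end|>"

parts = [fmt("ok", "search"), fmt("42", "calc")]
text = "".join(parts) + "<|start|>assistant"

# The joined text ends with a single generation prompt
print(text.count("<|start|>assistant"))  # → 1
```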
recipe/retool/run_gpt_oss_ppo.sh

Lines changed: 125 additions & 0 deletions (new file)

```bash
set -x

# ================= data/model/tool =================
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}

dapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k
aime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024
aime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025
actor_model_path=lmsys/gpt-oss-20b-bf16
critic_model_path=$actor_model_path

train_files="['$dapo_math_17k']"
test_files="['$aime_2025']"

# tool
tool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml

# wandb
project_name=wuxibin_retool
experiment_name=gpt-oss-20b-bf16_ppo
default_local_dir=$DATA_ROOT/checkpoint/$experiment_name

# ================= algorithm =================
adv_estimator=gae

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_turns=8
max_prompt_length=2048
max_response_length=16384
actor_lr=1e-6
critic_lr=2e-6
gae_gamma=1.0
gae_lam=1.0

critic_warmup=20

train_batch_size=512
ppo_mini_batch_size=512
n_resp_per_prompt_val=30

# ================= performance =================
infer_tp=4  # vllm
train_sp=4  # train

offload=True

actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 2 ))
critic_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))


python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=$adv_estimator \
    algorithm.use_kl_in_reward=$use_kl_in_reward \
    algorithm.kl_ctrl.kl_coef=$kl_coef \
    algorithm.gamma=$gae_gamma \
    algorithm.lam=$gae_lam \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.return_raw_chat=True \
    data.train_batch_size=$train_batch_size \
    data.max_prompt_length=$max_prompt_length \
    data.max_response_length=$max_response_length \
    data.filter_overlong_prompts=True \
    +data.apply_chat_template_kwargs.reasoning_effort=medium \
    data.truncation='error' \
    data.custom_cls.path=recipe/retool/retool.py \
    data.custom_cls.name=CustomRLHFDataset \
    custom_reward_function.path=recipe/retool/retool.py \
    custom_reward_function.name=compute_score \
    actor_rollout_ref.model.path=$actor_model_path \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
    actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    actor_rollout_ref.actor.optim.lr=$actor_lr \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \
    actor_rollout_ref.rollout.multi_turn.enable=True \
    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \
    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \
    actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \
    actor_rollout_ref.rollout.multi_turn.format=gpt-oss \
    +actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=triton \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.val_kwargs.top_p=1.0 \
    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
    critic.optim.lr=$critic_lr \
    critic.model.use_remove_padding=True \
    critic.model.path=$critic_model_path \
    critic.model.enable_gradient_checkpointing=True \
    critic.ppo_max_token_len_per_gpu=$critic_max_token_len_per_gpu \
    critic.ulysses_sequence_parallel_size=$train_sp \
    critic.model.fsdp_config.param_offload=$offload \
    critic.model.fsdp_config.optimizer_offload=$offload \
    trainer.critic_warmup=$critic_warmup \
    trainer.logger=['console','wandb'] \
    trainer.project_name=$project_name \
    trainer.experiment_name=$experiment_name \
    trainer.n_gpus_per_node=8 \
    trainer.val_before_train=True \
    trainer.log_val_generations=100 \
    trainer.nnodes=2 \
    trainer.save_freq=30 \
    trainer.default_local_dir=$default_local_dir \
    trainer.test_freq=5 \
    trainer.total_epochs=1 $@
```
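The two `*_max_token_len_per_gpu` values are pure arithmetic over the length settings; with the values above they work out as follows (a quick cross-check, not part of the recipe):

```python
# Cross-check of the token-budget arithmetic in the script above.
max_prompt_length = 2048
max_response_length = 16384

actor_max_token_len_per_gpu = (max_prompt_length + max_response_length) * 2
critic_max_token_len_per_gpu = (max_prompt_length + max_response_length) * 4

print(actor_max_token_len_per_gpu)   # → 36864
print(critic_max_token_len_per_gpu)  # → 73728
```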

verl/experimental/agent_loop/tool_agent_loop.py

Lines changed: 22 additions & 5 deletions

```diff
@@ -22,6 +22,7 @@
 from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register
 from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser
+from verl.experimental.agent_loop.utils import add_generation_prompt_for_gpt_oss, format_gpt_oss_tool_response_manually
 from verl.interactions.base import BaseInteraction
 from verl.interactions.utils.interaction_registry import initialize_interactions_from_config
 from verl.tools.schemas import ToolResponse
@@ -261,8 +262,10 @@ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentSt
         new_images_this_turn: list[Any] = []  # Local variable instead of agent_data attribute

         tasks = []
+        tool_call_names = []
         for tool_call in agent_data.tool_calls[: self.max_parallel_calls]:
             tasks.append(self._call_tool(tool_call, agent_data.tools_kwargs))
+            tool_call_names.append(tool_call.name)

         with simple_timer("tool_calls", agent_data.metrics):
             responses = await asyncio.gather(*tasks)
@@ -341,11 +344,25 @@ async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentSt
             model_inputs = self.processor(text=[raw_tool_response], images=current_images, return_tensors="pt")
             response_ids = model_inputs.pop("input_ids").squeeze(0).tolist()
         else:
-            response_ids = await self.loop.run_in_executor(
-                None,
-                lambda: self.tokenizer.apply_chat_template(add_messages, add_generation_prompt=True, tokenize=True),
-            )
-            response_ids = response_ids[len(self.system_prompt) :]
+            if self.tool_parser == "gpt-oss":
+                logger.info("manually format tool responses for gpt-oss")
+                # Format tool responses manually
+                tool_response_texts = []
+                for i, tool_msg in enumerate(add_messages):
+                    actual_tool_name = tool_call_names[i]
+                    formatted = format_gpt_oss_tool_response_manually(tool_msg["content"], actual_tool_name)
+                    tool_response_texts.append(formatted)
+
+                tool_response_text = add_generation_prompt_for_gpt_oss("".join(tool_response_texts))
+                response_ids = await self.loop.run_in_executor(
+                    None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False)
+                )
+            else:
+                response_ids = await self.loop.run_in_executor(
+                    None,
+                    lambda: self.tokenizer.apply_chat_template(add_messages, add_generation_prompt=True, tokenize=True),
+                )
+                response_ids = response_ids[len(self.system_prompt) :]
         if len(agent_data.response_mask) + len(response_ids) >= self.response_length:
             return AgentState.TERMINATED
         # Update prompt_ids and response_mask
```
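The gpt-oss branch above pairs each tool message with the tool name recorded at dispatch time, formats them manually, and appends a single generation prompt before tokenizing. A self-contained sketch of that pairing (tool names and contents are hypothetical; the helpers are restated here for illustration):

```python
# Sketch of the gpt-oss branch: pair tool messages with recorded call names,
# format each manually, then append a single generation prompt.
def format_gpt_oss_tool_response_manually(tool_response: str, tool_call_name: str) -> str:
    return f"<|start|>functions.{tool_call_name} to=assistant<|channel|>commentary<|message|>{tool_response}<|end|>"

def add_generation_prompt_for_gpt_oss(message_content: str) -> str:
    return message_content + "<|start|>assistant"

# Hypothetical parallel tool calls and their responses
tool_call_names = ["search", "calculator"]
add_messages = [{"role": "tool", "content": "3 results"},
                {"role": "tool", "content": "7"}]

tool_response_texts = [
    format_gpt_oss_tool_response_manually(msg["content"], name)
    for name, msg in zip(tool_call_names, add_messages)
]
tool_response_text = add_generation_prompt_for_gpt_oss("".join(tool_response_texts))

print(tool_response_text.count("<|end|>"))  # → 2 (one per tool response)
```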
verl/experimental/agent_loop/utils.py (new file; path inferred from the import statements in the other changed files)

Lines changed: 38 additions & 0 deletions

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# tokenizer.apply_chat_template is not working properly for the gpt-oss model,
# because the chat template requires tool call messages to parse tool response
# messages, so we need to format the tool response manually.
def format_gpt_oss_tool_response_manually(tool_response: str, tool_call_name: str) -> str:
    """Format tool response for gpt-oss model.

    Args:
        tool_response: Tool response string
        tool_call_name: Name of the tool that was called

    Returns:
        Formatted tool response string
    """
    return f"<|start|>functions.{tool_call_name} to=assistant<|channel|>commentary<|message|>{tool_response}<|end|>"


def add_generation_prompt_for_gpt_oss(message_content: str) -> str:
    """Add generation prompt for gpt-oss model.

    Args:
        message_content: Message content string

    Returns:
        Message content string with generation prompt
    """
    return message_content + "<|start|>assistant"
```
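Composed, the two helpers produce a complete harmony-format tool turn ending in a generation prompt. A usage sketch, assuming a hypothetical tool `code_interpreter` that returned `42` (the helpers are restated so the snippet runs standalone):

```python
# Usage sketch of the helpers defined above; "code_interpreter" and "42"
# are hypothetical example values.
def format_gpt_oss_tool_response_manually(tool_response: str, tool_call_name: str) -> str:
    return f"<|start|>functions.{tool_call_name} to=assistant<|channel|>commentary<|message|>{tool_response}<|end|>"

def add_generation_prompt_for_gpt_oss(message_content: str) -> str:
    return message_content + "<|start|>assistant"

formatted = format_gpt_oss_tool_response_manually("42", "code_interpreter")
prompt = add_generation_prompt_for_gpt_oss(formatted)
print(prompt)
# → <|start|>functions.code_interpreter to=assistant<|channel|>commentary<|message|>42<|end|><|start|>assistant
```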
