Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
fd7a569
[FEAT] Support async multi-turn rollout with simulation feedback
kinza99 May 21, 2025
1c13235
[DOC] Update sglang multi-turn rollout doc
kinza99 May 22, 2025
6c2cf1b
[Update] update user interaction design
kinza99 May 28, 2025
c0176c1
[Update] add testing and fix bugs
kinza99 May 29, 2025
d560cf3
[Fix] fix some problems
kinza99 May 30, 2025
5fbdd11
Fix unit-test and separate examples from previous tool
SwordFaith May 30, 2025
cb4baa7
Fix megatron workers and formatting
SwordFaith May 30, 2025
ed070cc
[Update] merge the latest main version
kinza99 Jun 4, 2025
fbfdcd0
Add training script
SwordFaith Jun 4, 2025
4ea2f1a
Fix assertion
SwordFaith Jun 4, 2025
4b18b69
Fix max_turns
SwordFaith Jun 4, 2025
878b1aa
Fix init interaction missing
SwordFaith Jun 4, 2025
8023dcb
Fix interface
SwordFaith Jun 4, 2025
dc3157e
Lower gpu mem foot print
SwordFaith Jun 4, 2025
3104159
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 4, 2025
cc31550
Fix init interaction missing issue
SwordFaith Jun 4, 2025
ae9217a
[Fix] fix problem with exceeding max_new_tokens
kinza99 Jun 5, 2025
cbddd2e
Update training data path
SwordFaith Jun 5, 2025
0c4338f
Fix prompt in preprocess interaction
SwordFaith Jun 6, 2025
eec48a9
Add 0.5b train script
SwordFaith Jun 7, 2025
564c832
Fix gsm8k reward in multi-turn scene
SwordFaith Jun 7, 2025
b75af4e
Try fix race condition in sampling params update
SwordFaith Jun 9, 2025
ca47f66
Fix sglang rollout sampling params
SwordFaith Jun 9, 2025
64a8ad7
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 10, 2025
34378ce
Fix bug and redundant error
SwordFaith Jun 10, 2025
ddb4881
Remove format config
SwordFaith Jun 10, 2025
9592979
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 17, 2025
92a47ba
Fix interaction config default value bug
SwordFaith Jun 17, 2025
8f45437
Fix default value judge logic
SwordFaith Jun 17, 2025
98adc05
Fix arg issue
SwordFaith Jun 17, 2025
2d43631
Fix format error
SwordFaith Jun 17, 2025
9a86042
Fix other format errors
SwordFaith Jun 17, 2025
5cb7a60
Clean training scripts
SwordFaith Jun 17, 2025
7e83b72
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 18, 2025
7f426f3
Merge branch 'main' into duhe/multi_turns_with_feedback
SwordFaith Jun 18, 2025
13a9615
Fix sf tool test
SwordFaith Jun 18, 2025
3c4a351
Fix ci error
SwordFaith Jun 18, 2025
ffc9366
Fix sglang tests
SwordFaith Jun 18, 2025
c0a035c
Merge upstream/main into multi_turns_with_feedback
SwordFaith Jun 20, 2025
21110df
Add test and doc for interaction
SwordFaith Jun 20, 2025
fd11f7b
Try fix mcp tool test
SwordFaith Jun 20, 2025
a1a021f
Fix sglang mcp tools test
SwordFaith Jun 20, 2025
df5de70
Fix pre-commit run issue
SwordFaith Jun 20, 2025
2abda2e
Fix doc and test
SwordFaith Jun 20, 2025
432eeff
Merge branch 'main' into duhe/multi_turns_with_feedback
zhaochenyang20 Jun 20, 2025
c5fe07a
Try fix ci issues
SwordFaith Jun 21, 2025
df42678
Fix chat completion arg bug
SwordFaith Jun 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[Fix] fix problem with exceeding max_new_tokens
  • Loading branch information
kinza99 committed Jun 5, 2025
commit ae9217a92f20c3e817b69fe74eaffde3e9f091b8
107 changes: 107 additions & 0 deletions examples/data_preprocess/gsm8k_multiturn_w_interaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the GSM8k dataset to parquet format
"""

import argparse
import os
import re

import datasets

from verl.utils.hdfs_io import copy, makedirs


def extract_solution(solution_str):
    """Return the final GSM8k answer as a string.

    GSM8k reference answers terminate with a line like ``#### 1,234``;
    this pulls out the number after the ``#### `` marker and strips the
    thousands separators. Raises AssertionError if no marker is found.
    """
    match = re.search(r"#### (\-?[0-9\.\,]+)", solution_str)
    assert match is not None
    # Capture group 1 is just the number, so no post-split is needed.
    return match.group(1).replace(",", "")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/gsm8k")
parser.add_argument("--hdfs_dir", default=None)

args = parser.parse_args()

data_source = "openai/gsm8k"
dataset = datasets.load_dataset(data_source, "main")

train_dataset = dataset["train"]
test_dataset = dataset["test"]

instruction_following = "Let's think step by step and output the final answer after `####`."

# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
question_raw = example.pop("question")

question = question_raw + " " + instruction_following

answer_raw = example.pop("answer")
solution = extract_solution(answer_raw)
data = {
"data_source": data_source,
"prompt": [
{
"role": "system",
"content": (
"You are a math expert. You are given a question and you need to solve it step by step. "
"Reasoning step by step before any tool call. "
"You should use the `calc_gsm8k_reward` tool after step by step solving the question, "
"before generate final answer at least once and refine your answer if necessary. "
"Put your final answer in the format of `#### <answer>`."
),
},
{
"role": "user",
"content": question,
},
],
"ability": "math",
"reward_model": {"style": "rule", "ground_truth": solution},
"extra_info": {
"split": split,
"index": idx,
"answer": answer_raw,
"question": question_raw,
"interaction_kwargs": {
"query": question,
"ground_truth": solution,
},
},
}
return data

return process_fn

train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)

local_dir = args.local_dir
hdfs_dir = args.hdfs_dir

train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))

if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
11 changes: 6 additions & 5 deletions verl/workers/rollout/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ class AsyncRolloutRequest(BaseModel):
"assistant_prefix_msg": "\n<|im_start|>assistant\n",
"assistant_suffix_msg": "<|im_end|>",
"merge_tool_response": True,
"tool_prefix_msg": "\n<|im_start|>user",
"tool_prefix_msg": "\n<|im_start|>user\n",
"tool_suffix_msg": "<|im_end|>",
"tool_response_prefix_msg": "\n<tool_response>\n",
"tool_response_suffix_msg": "\n</tool_response>",
"user_prefix_msg": "\n<|im_start|>user",
"user_prefix_msg": "\n<|im_start|>user\n",
"user_suffix_msg": "<|im_end|>",
},
}
Expand All @@ -132,7 +132,8 @@ def add_user_message(
prefix_token_ids = tokenizer.encode(prefix_msg, add_special_tokens=False)
suffix_msg = self.format_config[format]["user_suffix_msg"]
suffix_token_ids = tokenizer.encode(suffix_msg, add_special_tokens=False)

assistant_prefix_msg = self.format_config[format]["assistant_prefix_msg"]
assistant_prefix_token_ids = tokenizer.encode(assistant_prefix_msg, add_special_tokens=False)
content_token_ids = tokenizer.encode(content, add_special_tokens=False)
if self.input_ids[-len(prefix_token_ids) :] == prefix_token_ids:
append_token_ids = content_token_ids
Expand All @@ -144,8 +145,8 @@ def add_user_message(
max_len = max(len(prefix_token_ids), len(suffix_token_ids))
raise ValueError(f"Unsupported end of message format: {tokenizer.decode(self.input_ids[-max_len:])}, {tokenizer.decode(self.input_ids)=}, {self.messages=}")
if not already_over_long:
append_token_ids += suffix_token_ids
_loss_mask += [0] * len(suffix_token_ids)
append_token_ids += (suffix_token_ids + assistant_prefix_token_ids)
_loss_mask += ([0] * len(suffix_token_ids) + [0] * len(assistant_prefix_token_ids))
self.input_ids += append_token_ids
_attention_mask = [1] * len(append_token_ids)
self.attention_mask += _attention_mask
Expand Down
6 changes: 6 additions & 0 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,11 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in
_tool_schemas = []
_tools_kwargs = {}

if self.interaction is not None:
_interaction_kwargs = prompts.non_tensor_batch["interaction_kwargs"][data_idx]
else:
_interaction_kwargs = None

req = AsyncRolloutRequest(
batch_data_id=data_idx,
rollout_offset=rollout_offset,
Expand All @@ -1046,6 +1051,7 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in
messages=[Message.model_validate(msg) for msg in raw_prompt],
tools=_tool_schemas,
tools_kwargs=_tools_kwargs,
interaction_kwargs=_interaction_kwargs,
input_ids=_input_ids,
prompt_ids=_input_ids,
response_ids=[],
Expand Down