Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
fd7a569
[FEAT] Support async multi-turn rollout with simulation feedback
kinza99 May 21, 2025
1c13235
[DOC] Update sglang multi-turn rollout doc
kinza99 May 22, 2025
6c2cf1b
[Update] update user interaction design
kinza99 May 28, 2025
c0176c1
[Update] add testing and fix bugs
kinza99 May 29, 2025
d560cf3
[Fix] fix some problems
kinza99 May 30, 2025
5fbdd11
Fix unit-test and separate examples from previous tool
SwordFaith May 30, 2025
cb4baa7
Fix megatron workers and formatting
SwordFaith May 30, 2025
ed070cc
[Update] merge the latest main version
kinza99 Jun 4, 2025
fbfdcd0
Add training script
SwordFaith Jun 4, 2025
4ea2f1a
Fix assertion
SwordFaith Jun 4, 2025
4b18b69
Fix max_turns
SwordFaith Jun 4, 2025
878b1aa
Fix init interaction missing
SwordFaith Jun 4, 2025
8023dcb
Fix interface
SwordFaith Jun 4, 2025
dc3157e
Lower gpu mem foot print
SwordFaith Jun 4, 2025
3104159
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 4, 2025
cc31550
Fix init interaction missing issue
SwordFaith Jun 4, 2025
ae9217a
[Fix] fix problem with exceeding max_new_tokens
kinza99 Jun 5, 2025
cbddd2e
Update training data path
SwordFaith Jun 5, 2025
0c4338f
Fix prompt in preprocess interaction
SwordFaith Jun 6, 2025
eec48a9
Add 0.5b train script
SwordFaith Jun 7, 2025
564c832
Fix gsm8k reward in multi-turn scene
SwordFaith Jun 7, 2025
b75af4e
Try fix race condition in sampling params update
SwordFaith Jun 9, 2025
ca47f66
Fix sglang rollout sampling params
SwordFaith Jun 9, 2025
64a8ad7
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 10, 2025
34378ce
Fix bug and redundant error
SwordFaith Jun 10, 2025
ddb4881
Remove format config
SwordFaith Jun 10, 2025
9592979
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 17, 2025
92a47ba
Fix interaction config default value bug
SwordFaith Jun 17, 2025
8f45437
Fix default value judge logic
SwordFaith Jun 17, 2025
98adc05
Fix arg issue
SwordFaith Jun 17, 2025
2d43631
Fix format error
SwordFaith Jun 17, 2025
9a86042
Fix other format errors
SwordFaith Jun 17, 2025
5cb7a60
Clean training scripts
SwordFaith Jun 17, 2025
7e83b72
Merge remote-tracking branch 'upstream/main' into multi_turns_with_fe…
SwordFaith Jun 18, 2025
7f426f3
Merge branch 'main' into duhe/multi_turns_with_feedback
SwordFaith Jun 18, 2025
13a9615
Fix sf tool test
SwordFaith Jun 18, 2025
3c4a351
Fix ci error
SwordFaith Jun 18, 2025
ffc9366
Fix aglang tests
SwordFaith Jun 18, 2025
c0a035c
Merge upstream/main into multi_turns_with_feedback
SwordFaith Jun 20, 2025
21110df
Add test and doc for interaction
SwordFaith Jun 20, 2025
fd11f7b
Try fix mcp tool test
SwordFaith Jun 20, 2025
a1a021f
Fix sglang mcp tools test
SwordFaith Jun 20, 2025
df5de70
Fix pre-commit run issue
SwordFaith Jun 20, 2025
2abda2e
Fix doc and test
SwordFaith Jun 20, 2025
432eeff
Merge branch 'main' into duhe/multi_turns_with_feedback
zhaochenyang20 Jun 20, 2025
c5fe07a
Try fix ci issues
SwordFaith Jun 21, 2025
df42678
Fix chat completion arg bug
SwordFaith Jun 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[Update] update user interaction design
  • Loading branch information
kinza99 committed May 29, 2025
commit 6c2cf1b31f590bdcb70ecbe9cf653dae7f860ea4

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
interaction:
- class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction"
config: {}
78 changes: 0 additions & 78 deletions verl/feedbacks/base.py

This file was deleted.

File renamed without changes.
63 changes: 63 additions & 0 deletions verl/interactions/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Tuple


class BaseInteraction:
    """Abstract base class for multi-turn rollout interactions.

    Subclasses implement a session-oriented protocol:
    ``start_interaction`` opens a session and returns its ID,
    ``generate_response`` produces the simulated-user turn,
    ``calculate_score`` scores a turn, and ``finalize_interaction``
    releases session state.
    """

    def __init__(self, config: Dict[str, Any]):
        # Raw configuration mapping; subclasses read their own keys from it.
        self.config = config
        # Role name of the simulated agent; generic default.
        self.name: str = config.get("name", "interaction_agent")

    async def start_interaction(self, **kwargs) -> str:
        """Initialize a new interaction session and return its unique ID.

        Simulates: get id + state init. ``**kwargs`` lets subclasses accept
        extra session parameters (e.g. ``instance_id``, ``ground_truth``)
        while remaining signature-compatible with this base method.
        """
        # ...implement the logic to get ID and initialize state...
        interaction_id = "some_unique_id"
        return interaction_id

    async def generate_response(self, messages: Any, **kwargs) -> Tuple[bool, str, float, Dict[str, Any]]:
        """Generate the simulated-user response for the current turn.

        Args:
            messages: The conversation so far (format defined by the caller).
            **kwargs: Subclass-specific extras (e.g. ``instance_id``).

        Returns:
            A tuple of:
            - should_terminate_sequence (bool): True if the interaction
              sequence (rollout) should end.
            - response_content (str): The textual content of the response.
            - current_turn_score (float): The score for this specific turn.
            - additional_data (dict): Any extra information or metadata.
        """
        should_terminate_sequence: bool = False  # if True, end rollout
        response_content: str = "Your current result seems acceptable."
        current_turn_score: float = 0.8
        additional_data: Dict[str, Any] = {}
        return should_terminate_sequence, response_content, current_turn_score, additional_data

    async def calculate_score(self, **kwargs) -> float:
        """Calculate a turn-level score for the interaction.

        Should be invoked at turn level. Defaults to 0.0 (rather than
        raising NotImplementedError) so that interactions which do not
        score have no impact on the aggregated reward.
        """
        # ...implement the logic to calculate turn-level score...
        score = 0.0
        return score

    async def finalize_interaction(self, **kwargs) -> None:
        """Finalize the interaction session and release associated state.

        Simulates: release state.
        """
        # ...implement the logic to release state...
        pass
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,26 @@

from verl.utils.reward_score import gsm8k

from .base import BaseFeedback
from .base import BaseInteraction

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class Gsm8kFeedback(BaseFeedback):
"""A demo feedback for calculating the reward of gsm8k.
class Gsm8kInteraction(BaseInteraction):
"""A demo interaction for calculating the reward of gsm8k.

- `create`: create a feedback instance for a trajectory.
- `get_feedback`: get the feedback of the user.
- `release`: release the feedback instance.
- `start_interaction`: start an interaction instance for a trajectory.
- `generate_response`: generate the response of the user.
- `calculate_score`: calculate the score of the interaction.
- `finalize_interaction`: finalize the interaction instance.
"""

def __init__(self, config: dict):
super().__init__(config)
self._instance_dict = {}

async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str:
async def start_interaction(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str:
if instance_id is None:
instance_id = str(uuid4())
self._instance_dict[instance_id] = {
Expand All @@ -48,30 +49,30 @@ async def create(self, instance_id: Optional[str] = None, ground_truth: Optional
}
return instance_id

async def get_feedback(self, instance_id: str, messages: List[Dict[str, Any]], **kwargs) -> Tuple[str, float, dict]:
content = ''
async def generate_response(self, instance_id: str, messages: List[Dict[str, Any]], **kwargs) -> Tuple[str, float, dict]:
content = ""
for i in range(len(messages) - 1, -1, -1):
item = messages[i]
if item.get('role') == 'user':
content = item.get('content')
if item.get("role") == "user":
content = item.get("content")
break

if content.startswith("#### "):
self._instance_dict[instance_id]["response"] = content
else:
self._instance_dict[instance_id]["response"] = "#### " + content

reward = await self.calc_reward(instance_id)
reward = await self.calculate_score(instance_id)
if reward == 1.0:
feedback = "Your response is correct!"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feedback -> response

go_on = False
should_terminate_sequence = True
else:
feedback = "Your response is incorrect! You need to reflect on your answer and try again."
go_on = True
should_terminate_sequence = False

return f"{feedback=} {reward=}", go_on, {}
return should_terminate_sequence, f"{feedback=}", reward, {}

async def calc_reward(self, instance_id: str, **kwargs) -> float:
async def calculate_score(self, instance_id: str, **kwargs) -> float:
return gsm8k.compute_score(
self._instance_dict[instance_id]["response"],
self._instance_dict[instance_id]["ground_truth"],
Expand All @@ -80,5 +81,5 @@ async def calc_reward(self, instance_id: str, **kwargs) -> float:
score=1.0,
)

async def release(self, instance_id: str, **kwargs) -> None:
async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
del self._instance_dict[instance_id]
6 changes: 3 additions & 3 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ actor_rollout_ref:
do_sample: False # default eager for validation
multi_turn:
enable: False # should set rollout.name to sglang_async if True
max_turns: null # null for no limit (default max_length // 3)
max_assistant_turns: null # null for no limit (default max_length // 3)
tool_config_path: null # null for no tool
feedback_config_path: null # null for no feedback
interaction_config_path: null # null for no interaction
format: chatml # chatml, more formats will be supported in the future
user_max_turns: 1
max_user_turns: null

critic:
rollout_n: ${actor_rollout_ref.rollout.n}
Expand Down
9 changes: 5 additions & 4 deletions verl/workers/rollout/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class AsyncRolloutRequest(BaseModel):
loss_mask: List[int]
prompt_loss_mask: List[int]
response_loss_mask: List[int]
reward_scores: Dict[str, float]
reward_scores: Dict[str, List[float]]
max_response_len: int = 8192
max_model_len: int = 32768
metrics: Dict[str, List[Any]] = {}
Expand All @@ -108,7 +108,7 @@ class AsyncRolloutRequest(BaseModel):
},
"user_prefix_msg": "\n<|im_start|>user",
"user_suffix_msg": "<|im_end|>",
}
},
}

def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> list[int]:
Expand All @@ -118,6 +118,7 @@ def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> list[int]:
add_generation_prompt=True,
tokenize=True,
)

def add_user_message(
self,
tokenizer: PreTrainedTokenizer,
Expand Down Expand Up @@ -158,7 +159,7 @@ def add_user_message(
raise ValueError(f"Unsupported format: {format}")
assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=},
{len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}"""

def add_assistant_message(
self,
tokenizer: PreTrainedTokenizer,
Expand Down Expand Up @@ -263,7 +264,7 @@ def update_metrics(self, metrics: Any, tool_id: str) -> None:
def finalize(
self,
tokenizer: PreTrainedTokenizer,
reward_scores: Dict[str, float],
reward_scores: Dict[str, List[float]],
finish_reason_type: FinishReasonTypeEnum = FinishReasonTypeEnum.STOP,
) -> None:
self.state = AsyncRolloutRequestStateEnum.COMPLETED
Expand Down
Loading