-
Notifications
You must be signed in to change notification settings - Fork 3.2k
[sglang] feat: Support async multi-turn rollout with simulation feedback in sglang #1630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
fd7a569
1c13235
6c2cf1b
c0176c1
d560cf3
5fbdd11
cb4baa7
ed070cc
fbfdcd0
4ea2f1a
4b18b69
878b1aa
8023dcb
dc3157e
3104159
cc31550
ae9217a
cbddd2e
0c4338f
eec48a9
564c832
b75af4e
ca47f66
64a8ad7
34378ce
ddb4881
9592979
92a47ba
8f45437
98adc05
2d43631
9a86042
5cb7a60
7e83b72
7f426f3
13a9615
3c4a351
ffc9366
c0a035c
21110df
fd11f7b
a1a021f
df5de70
2abda2e
432eeff
c5fe07a
df42678
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| interaction: | ||
| - class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction" | ||
| config: {} |
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # Copyright 2023-2024 SGLang Team | ||
| # Copyright 2025 ModelBest Inc. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from typing import Any, Dict, Tuple | ||
|
|
||
|
|
||
| class BaseInteraction: | ||
| def __init__(self, config: Dict[str, Any]): | ||
| self.config = config | ||
| self.name: str = config.get("name", "interaction_agent") # More general agent default role name | ||
|
|
||
| async def start_interaction(self) -> str: # More clear interaction start method | ||
| """ | ||
| Initializes a new interaction session and returns its unique ID. | ||
| Simulates: get id + state init | ||
| """ | ||
| # ...implement the logic to get ID and initialize state... | ||
| interaction_id = "some_unique_id" | ||
| return interaction_id | ||
|
|
||
| async def generate_response(self, messages: Any) -> Tuple[bool, str, float, Dict[str, Any]]: # More clear response generation method | ||
|
||
| """ | ||
| Generates a response for the current turn of interaction. | ||
| Returns a tuple containing: | ||
| - should_terminate_sequence (bool): True if the interaction sequence should end. | ||
| - response_content (str): The textual content of the response. | ||
| - current_turn_score (float): The score for this specific turn/response. | ||
| - additional_data (dict): Any extra information or metadata. | ||
| """ | ||
| should_terminate_sequence: bool = False # if True, end rollout | ||
| response_content: str = "Your current result seems acceptable." | ||
| current_turn_score: float = 0.8 | ||
| additional_data: Dict[str, Any] = {} | ||
| return should_terminate_sequence, response_content, current_turn_score, additional_data | ||
|
|
||
| async def calculate_score(self) -> float: # More clear score calculation method | ||
| """ | ||
| Calculates a score for the interaction, | ||
| potentially considering aspects like partial exposure & in-context task switching. | ||
| Should be invoked at the turn level. | ||
| """ | ||
| # ...implement the logic to calculate turn-level score... | ||
| score = 0.0 | ||
| return score | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How about
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The |
||
|
|
||
| async def finalize_interaction(self) -> None: # More clear interaction end and resource release method | ||
| """ | ||
| Finalizes the interaction session and releases any associated state or resources. | ||
| Simulates: release state | ||
| """ | ||
| # ...implement the logic to release state... | ||
| pass | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,25 +20,26 @@ | |
|
|
||
| from verl.utils.reward_score import gsm8k | ||
|
|
||
| from .base import BaseFeedback | ||
| from .base import BaseInteraction | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) | ||
|
|
||
|
|
||
| class Gsm8kFeedback(BaseFeedback): | ||
| """A demo feedback for calculating the reward of gsm8k. | ||
| class Gsm8kInteraction(BaseInteraction): | ||
| """A demo interaction for calculating the reward of gsm8k. | ||
|
|
||
| - `create`: create a feedback instance for a trajectory. | ||
| - `get_feedback`: get the feedback of the user. | ||
| - `release`: release the feedback instance. | ||
| - `start_interaction`: start an interaction instance for a trajectory. | ||
| - `generate_response`: generate the response of the user. | ||
| - `calculate_score`: calculate the score of the interaction. | ||
| - `finalize_interaction`: finalize the interaction instance. | ||
| """ | ||
|
|
||
| def __init__(self, config: dict): | ||
| super().__init__(config) | ||
| self._instance_dict = {} | ||
|
|
||
| async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: | ||
| async def start_interaction(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: | ||
| if instance_id is None: | ||
| instance_id = str(uuid4()) | ||
| self._instance_dict[instance_id] = { | ||
|
|
@@ -48,30 +49,30 @@ async def create(self, instance_id: Optional[str] = None, ground_truth: Optional | |
| } | ||
| return instance_id | ||
|
|
||
| async def get_feedback(self, instance_id: str, messages: List[Dict[str, Any]], **kwargs) -> Tuple[str, float, dict]: | ||
| content = '' | ||
| async def generate_response(self, instance_id: str, messages: List[Dict[str, Any]], **kwargs) -> Tuple[str, float, dict]: | ||
| content = "" | ||
| for i in range(len(messages) - 1, -1, -1): | ||
| item = messages[i] | ||
| if item.get('role') == 'user': | ||
| content = item.get('content') | ||
| if item.get("role") == "user": | ||
| content = item.get("content") | ||
| break | ||
|
|
||
| if content.startswith("#### "): | ||
| self._instance_dict[instance_id]["response"] = content | ||
| else: | ||
| self._instance_dict[instance_id]["response"] = "#### " + content | ||
|
|
||
| reward = await self.calc_reward(instance_id) | ||
| reward = await self.calculate_score(instance_id) | ||
| if reward == 1.0: | ||
| feedback = "Your response is correct!" | ||
|
||
| go_on = False | ||
| should_terminate_sequence = True | ||
| else: | ||
| feedback = "Your response is incorrect! You need to reflect on your answer and try again." | ||
| go_on = True | ||
| should_terminate_sequence = False | ||
|
|
||
| return f"{feedback=} {reward=}", go_on, {} | ||
| return should_terminate_sequence, f"{feedback=}", reward, {} | ||
|
|
||
| async def calc_reward(self, instance_id: str, **kwargs) -> float: | ||
| async def calculate_score(self, instance_id: str, **kwargs) -> float: | ||
| return gsm8k.compute_score( | ||
| self._instance_dict[instance_id]["response"], | ||
| self._instance_dict[instance_id]["ground_truth"], | ||
|
|
@@ -80,5 +81,5 @@ async def calc_reward(self, instance_id: str, **kwargs) -> float: | |
| score=1.0, | ||
| ) | ||
|
|
||
| async def release(self, instance_id: str, **kwargs) -> None: | ||
| async def finalize_interaction(self, instance_id: str, **kwargs) -> None: | ||
| del self._instance_dict[instance_id] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be better to support uuid by default, like https://github.com/volcengine/verl/blob/main/verl/tools/base_tool.py#L50.