Commit e48c64c

Address comments

committed · 1 parent 4059a4f · commit e48c64c

4 files changed: +64 −69 lines changed

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 0 additions & 1 deletion
@@ -171,7 +171,6 @@ actor_rollout_ref:
       enable: False # should set rollout.name to sglang_async if True
       max_turns: null # null for no limit (default max_length // 3)
       tool_config_path: null # null for no tool
-      format: chatml # deprecated, no effect anymore
       # choose from: fast, full, sanity_check
       # fast: only tokenizes the new messages in each turn
       # full: tokenize the whole conversation in each turn. This could be significantly slower for long conversations but guarantees tokenization consistency across turns

verl/trainer/config/ppo_trainer.yaml

Lines changed: 0 additions & 1 deletion
@@ -143,7 +143,6 @@ actor_rollout_ref:
       enable: False # should set rollout.name to sglang_async if True
       max_turns: null # null for no limit (default max_length // 3)
       tool_config_path: null # null for no tool
-      format: chatml # deprecated, no effect anymore
       # choose from: fast, full, sanity_check
       # fast: only tokenizes the new messages in each turn
       # full: tokenize the whole conversation in each turn. This could be significantly slower for long conversations but guarantees tokenization consistency across turns
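As a reference for the three modes described in these config comments, below is a minimal standalone sketch of what "fast" and "full" tokenization do and of the equivalence that "sanity_check" asserts. The tokenizer name and example messages are illustrative assumptions, not part of this commit; only standard Hugging Face apply_chat_template and encode calls are used.

# Illustrative sketch only: the model name and messages below are assumptions,
# not taken from this commit. It mirrors the fast-vs-full consistency check
# that "sanity_check" mode performs in verl/workers/rollout/schemas.py.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

history = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = 4."},
]
new_messages = [{"role": "user", "content": "And what is 3 + 3?"}]

# "full" mode: re-tokenize the whole conversation every turn.
full_ids = tokenizer.apply_chat_template(history + new_messages, add_generation_prompt=False, tokenize=True)

# "fast" mode: render both prefixes as text, keep only the newly appended suffix,
# and encode just that suffix.
prefix_ids = tokenizer.apply_chat_template(history, add_generation_prompt=False, tokenize=True)
prefix_text = tokenizer.apply_chat_template(history, add_generation_prompt=False, tokenize=False)
all_text = tokenizer.apply_chat_template(history + new_messages, add_generation_prompt=False, tokenize=False)
delta_ids = tokenizer.encode(all_text[len(prefix_text):], add_special_tokens=False)

# "sanity_check" mode: the incremental result must equal the full re-tokenization;
# a divergence here is exactly what that mode is meant to surface.
assert full_ids == prefix_ids + delta_ids, "fast and full tokenization diverged"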

verl/workers/rollout/schemas.py

Lines changed: 62 additions & 65 deletions
@@ -71,9 +71,7 @@ class AsyncRolloutRequest(BaseModel):
     request_id: str
     state: AsyncRolloutRequestStateEnum
     messages: List[Message]
-    messages_dumps: List[Dict[str, Any]]
     tool_schemas: Optional[List[OpenAIFunctionToolSchema]] = None
-    tools: Optional[list[dict]] = None
     tools_kwargs: Dict[str, Any] = {}
     input_ids: List[int]
     prompt_ids: List[int]
@@ -108,60 +106,85 @@ def initialize_request(cls, values):
             raise ValueError("tokenizer is required for AsyncRolloutRequest initialization")
 
         values["messages"] = [Message.model_validate(msg) for msg in messages]
-        values["messages_dumps"] = [msg.model_dump() for msg in values["messages"]]
-
-        if tool_schemas := values.get("tool_schemas"):
-            tools = values["tools"] = [tool.model_dump() for tool in tool_schemas]
-            tokens_without_prompt = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=False, tokenize=True)
-            tokenization_dict_with_prompt = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=True, tokenize=True, return_dict=True)
-            values["input_ids"] = values["prompt_ids"] = tokenization_dict_with_prompt["input_ids"]
-            values["attention_mask"] = values["prompt_attention_mask"] = tokenization_dict_with_prompt["attention_mask"]
+
+        tools = [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get("tool_schemas", [])) else None
+        tokens_without_prompt = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=False, tokenize=True)
+        if not values.get("input_ids") or not values.get("attention_mask"):
+            tokenization_dict_with_prompt = tokenizer.apply_chat_template(messages, tools=[tool.model_dump() for tool in tool_schemas], add_generation_prompt=True, tokenize=True, return_dict=True)
+            values["input_ids"], values["attention_mask"] = tokenization_dict_with_prompt["input_ids"], tokenization_dict_with_prompt["attention_mask"]
         if len(values["input_ids"]) > max_prompt_len:
             # Only log the warning to avoid truncating in the middle of generation prompt. Consider raising an error for this case in the future.
             logger.warning(f"Prompt {values['batch_data_id']} length {len(values['input_ids'])} greater than max_prompt_len {max_prompt_len} after applied chat template with tools.")
-        elif not values.get("input_ids") or not values.get("attention_mask"):
-            raise ValueError("input_ids and attention_mask is required for requests without tools")
-        else:
-            tokens_without_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=True)
-            values["prompt_ids"], values["prompt_attention_mask"] = values["input_ids"], values["attention_mask"]
 
+        values["prompt_ids"], values["prompt_attention_mask"] = values["input_ids"], values["attention_mask"]
         values["position_ids"] = values["prompt_position_ids"] = compute_position_id_with_mask(torch.tensor(values["attention_mask"])).tolist()
         values["loss_mask"] = values["prompt_loss_mask"] = [0] * len(values["input_ids"])
         values["generation_prompt_ids"] = values["input_ids"][len(tokens_without_prompt) :]
         return values
 
     def _update_input_ids(self, new_input_ids: List[int], attention_mask: bool, loss_mask: bool, full_tokens: bool = False) -> None:
+        """
+        Update the input_ids, attention_mask, position_ids, and loss_mask of the request.
+        When full_tokens is True, it replaces the input_ids with new_input_ids and updates the attention_mask, position_ids, and loss_mask accordingly.
+        When full_tokens is False, it appends new_input_ids to the input_ids and updates the attention_mask, position_ids, and loss_mask accordingly.
+        """
         message_len_delta = (len(new_input_ids) - len(self.input_ids)) if full_tokens else len(new_input_ids)
         self.input_ids = new_input_ids if full_tokens else (self.input_ids + new_input_ids)
         attention_mask = [int(attention_mask)] * message_len_delta
         self.attention_mask += attention_mask
-        _delta_position_ids = compute_position_id_with_mask(torch.tensor(attention_mask)).tolist()
-        last_position_id = self.position_ids[-1]
-        _position_ids = [pos_id + last_position_id for pos_id in _delta_position_ids]
         self.loss_mask += [int(loss_mask)] * message_len_delta
-        self.position_ids += _position_ids
+        self.position_ids += (compute_position_id_with_mask(torch.tensor(attention_mask)) + (self.position_ids[-1] + 1)).tolist()
 
         assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=},
             {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}"""
 
-    def _append_messages(self, messages: list[Message]) -> None:
-        self.messages.extend(messages)
-        self.messages_dumps.extend([msg.model_dump() for msg in messages])
-
-    def _tokenize_all_messages(self, tokenizer: PreTrainedTokenizer, delta_input_ids_to_check: Optional[list[int]], add_generation_prompt: bool = False) -> None:
-        full_input_ids = tokenizer.apply_chat_template(self.messages_dumps, tools=self.tools, add_generation_prompt=add_generation_prompt, tokenize=True)
-        if self.tokenization_mode == "sanity_check" and delta_input_ids_to_check is not None:
-            assert full_input_ids == self.input_ids + delta_input_ids_to_check, (
-                f"Sanity check failed.\nFull tokenization result:\n{tokenizer.decode(full_input_ids, skip_special_tokens=False)}\nFast tokenization result:\n{tokenizer.decode(self.input_ids + delta_input_ids_to_check, skip_special_tokens=False)}"
-            )
-        self._update_input_ids(full_input_ids, attention_mask=True, loss_mask=False, full_tokens=True)
+    def _fast_tokenize(self, tokenizer: PreTrainedTokenizer, num_messages: int, add_generation_prompt: bool, delta_tokens: Optional[List[int]] = None) -> list[int]:
+        """Fast tokenization tokenize the new messages only and append the tokens to the existing input_ids."""
+
+        # Handles cases where tool calls are incorrectly embedded, such as: I'll call the tool: <tool_call>{"name": ...}</tool_call>. Does this make sense?
+        # The code below restructures the text and tool calls parsed by the SGLang tool parser using the chat template.
+        # The outcome depends on the SGLang tool parser; for instance, with Qwen, any text after the first tool call is ignored.
+        # TODO: Reconsider this approach for RL scenarios: 1. Try to parse as much valid response as possible; 2. Surface the error to the model for learning.
+        if num_messages and (not delta_tokens or self.messages[-1].tool_calls):
+            tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None
+            content_start_pos = len(tokenizer.apply_chat_template([msg.model_dump() for msg in self.messages[:-num_messages]], tools=tools, add_generation_prompt=add_generation_prompt, tokenize=False))
+            content = tokenizer.apply_chat_template([msg.model_dump() for msg in self.messages], tools=tools, add_generation_prompt=False, tokenize=False)[content_start_pos:]
+            delta_tokens = tokenizer.encode(content, add_special_tokens=False)
+        return delta_tokens
+
+    def _full_tokenize(self, tokenizer: PreTrainedTokenizer, add_generation_prompt: bool) -> list[int]:
+        """Full tokenization tokenizes the entire message history and returns the full tokenization result."""
+        tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None
+        return tokenizer.apply_chat_template([msg.model_dump() for msg in self.messages], tools=tools, add_generation_prompt=add_generation_prompt, tokenize=True)
+
+    def _tokenize_messages(self, tokenizer: PreTrainedTokenizer, num_messages: int, loss_mask: bool, add_generation_prompt: bool, delta_tokens: Optional[List[int]] = None) -> None:
+        """
+        Tokenizes messages and updates `input_ids`, `attention_mask`, `position_ids`, and `loss_mask` based on the selected tokenization mode.
 
-    def get_prompt_ids(self, tokenizer: PreTrainedTokenizer) -> list[int]:
+        :param num_messages: (Only used in "fast" mode) Specifies the number of most recent messages to tokenize.
+        :param add_generation_prompt: (Only used in "full" mode) Indicates whether to include a generation prompt in the tokenized output.
+        :param delta_tokens: (Only used in "fast" mode) Tokens to append to `input_ids`. If None, the method tokenizes the last `num_messages` messages.
+        """
+        match self.tokenization_mode:
+            case "fast":
+                # Only when tokenizing assistant messages do we set loss_mask to True and exclude the generation prompt from token ids.
+                # Therefore, only when loss_mask==True, we include the generation prompt in the calculation of the start position of new message tokens
+                self._update_input_ids(self._fast_tokenize(tokenizer, num_messages, loss_mask, delta_tokens), attention_mask=True, loss_mask=loss_mask)
+            case "full":
+                self._update_input_ids(self._full_tokenize(tokenizer, add_generation_prompt), attention_mask=True, loss_mask=loss_mask, full_tokens=True)
+            case "sanity_check":
+                full_tokens = self._full_tokenize(tokenizer, add_generation_prompt)
+                delta_tokens = self._fast_tokenize(tokenizer, num_messages, loss_mask, delta_tokens)
+                assert full_tokens == self.input_ids + delta_tokens, f"Sanity check failed.\nFull tokenization result:\n{tokenizer.decode(full_tokens, skip_special_tokens=False)}\nFast tokenization result:\n{tokenizer.decode(self.input_ids + delta_tokens, skip_special_tokens=False)}"
+                self._update_input_ids(full_tokens, attention_mask=True, loss_mask=loss_mask, full_tokens=True)
+            case _:
+                raise ValueError(f"Unsupported tokenization mode: {self.tokenization_mode}. Supported modes are 'fast', 'full', and 'sanity_check'.")
+
+    def get_generation_prompt_ids(self, tokenizer: PreTrainedTokenizer) -> list[int]:
         generation_prompt_ids = [] if self.input_ids[-len(self.generation_prompt_ids) :] == self.generation_prompt_ids else self.generation_prompt_ids
-        if self.tokenization_mode == "fast":
-            self._update_input_ids(generation_prompt_ids, attention_mask=True, loss_mask=False)
-        else:
-            self._tokenize_all_messages(tokenizer, generation_prompt_ids, add_generation_prompt=True)
+        if not generation_prompt_ids:
+            return self.input_ids
+        self._tokenize_messages(tokenizer, num_messages=0, loss_mask=False, add_generation_prompt=True, delta_tokens=generation_prompt_ids)
         return self.input_ids
 
     def add_assistant_message(
@@ -171,40 +194,14 @@ def add_assistant_message(
         content_ids: Optional[List[int]] = None,
         tool_calls: Optional[List[OpenAIFunctionToolCall]] = None,
     ) -> None:
-        self._append_messages([Message(role="assistant", content=content, tool_calls=tool_calls)])
-
-        if self.tokenization_mode != "full":
-            if tool_calls or not content_ids:
-                # Handles cases where tool calls are incorrectly embedded, such as: I'll call the tool: <tool_call>{"name": ...}</tool_call>. Does this make sense?
-                # The code below restructures the text and tool calls parsed by the SGLang tool parser using the chat template.
-                # The outcome depends on the SGLang tool parser; for instance, with Qwen, any text after the first tool call is ignored.
-                # TODO: Reconsider this approach for RL scenarios: 1. Try to parse as much valid response as possible; 2. Surface the error to the model for learning.
-                content_start_pos = len(tokenizer.apply_chat_template(self.messages_dumps[:-1], tools=self.tools, add_generation_prompt=True, tokenize=False))
-                content = tokenizer.apply_chat_template(self.messages_dumps, tools=self.tools, add_generation_prompt=False, tokenize=False)[content_start_pos:]
-                content_ids = tokenizer.encode(content, add_special_tokens=False)
-
-            if self.tokenization_mode == "fast":
-                self._update_input_ids(content_ids, attention_mask=True, loss_mask=True)
-                return
-
-        self._tokenize_all_messages(tokenizer, content_ids)
+        self.messages.append(Message(role="assistant", content=content, tool_calls=tool_calls))
+        self._tokenize_messages(tokenizer, num_messages=1, loss_mask=True, add_generation_prompt=False, delta_tokens=content_ids)
 
     def add_tool_response_messages(self, tokenizer: PreTrainedTokenizer, contents: list[str]) -> None:
         if not contents:
             return
-
-        self._append_messages([Message(role="tool", content=content) for content in contents])
-        response_token_ids = None
-        if self.tokenization_mode != "full":
-            response_start_pos = len(tokenizer.apply_chat_template(self.messages_dumps[: -len(contents)], tools=self.tools, add_generation_prompt=False, tokenize=False))
-            response_tokens = tokenizer.apply_chat_template(self.messages_dumps, tools=self.tools, add_generation_prompt=False, tokenize=False)[response_start_pos:]
-            response_token_ids = tokenizer.encode(response_tokens, add_special_tokens=False)
-
-        if self.tokenization_mode == "fast":
-            self._update_input_ids(response_token_ids, attention_mask=True, loss_mask=False)
-            return
-
-        self._tokenize_all_messages(tokenizer, response_token_ids)
+        self.messages.extend([Message(role="tool", content=content) for content in contents])
+        self._tokenize_messages(tokenizer, num_messages=len(contents), loss_mask=False, add_generation_prompt=False)
 
     def update_metrics(self, metrics: Any, tool_id: str) -> None:
        """

verl/workers/rollout/sglang_rollout/async_sglang_rollout.py

Lines changed: 2 additions & 2 deletions
@@ -160,8 +160,8 @@ def _verify_config(self, model_hf_config):
         if self.config.multi_turn.max_turns is None:
             self.config.multi_turn.max_turns = self.config.max_model_len // 3
 
-        assert self.config.multi_turn.fast_tokenization in {"enable", "disable", "sanity_check"}, f"fast_tokenization should be one of [enable, disable, sanity_check], but got {self.config.multi_turn.fast_tokenization}"
-        self.fast_tokenization_enabled = self.config.multi_turn.fast_tokenization == "enable"
+        assert self.config.multi_turn.tokenization_mode in {"fast", "full", "sanity_check"}, f"tokenization_mode should be one of [fast, full, sanity_check], but got {self.config.multi_turn.tokenization_mode}"
+        self.tokenization_mode = self.config.multi_turn.tokenization_mode
 
     def _init_inference_engine(self, trust_remote_code, actor_module, port):
         # initialize the inference engine