@@ -71,9 +71,7 @@ class AsyncRolloutRequest(BaseModel):
     request_id: str
     state: AsyncRolloutRequestStateEnum
     messages: List[Message]
-    messages_dumps: List[Dict[str, Any]]
     tool_schemas: Optional[List[OpenAIFunctionToolSchema]] = None
-    tools: Optional[list[dict]] = None
     tools_kwargs: Dict[str, Any] = {}
     input_ids: List[int]
     prompt_ids: List[int]
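
With `messages_dumps` and `tools` removed from the model, the dict forms are rebuilt on demand from the Pydantic objects wherever they are needed, as in the hunks below. A minimal sketch (`req` standing in for an AsyncRolloutRequest instance):

    tools = [tool.model_dump() for tool in req.tool_schemas] if req.tool_schemas else None
    message_dicts = [msg.model_dump() for msg in req.messages]
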
@@ -108,60 +106,85 @@ def initialize_request(cls, values):
             raise ValueError("tokenizer is required for AsyncRolloutRequest initialization")
 
         values["messages"] = [Message.model_validate(msg) for msg in messages]
-        values["messages_dumps"] = [msg.model_dump() for msg in values["messages"]]
-
-        if tool_schemas := values.get("tool_schemas"):
-            tools = values["tools"] = [tool.model_dump() for tool in tool_schemas]
-            tokens_without_prompt = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=False, tokenize=True)
-            tokenization_dict_with_prompt = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=True, tokenize=True, return_dict=True)
-            values["input_ids"] = values["prompt_ids"] = tokenization_dict_with_prompt["input_ids"]
-            values["attention_mask"] = values["prompt_attention_mask"] = tokenization_dict_with_prompt["attention_mask"]
+
+        tools = [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get("tool_schemas", [])) else None
+        tokens_without_prompt = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=False, tokenize=True)
+        if not values.get("input_ids") or not values.get("attention_mask"):
+            tokenization_dict_with_prompt = tokenizer.apply_chat_template(messages, tools=[tool.model_dump() for tool in tool_schemas], add_generation_prompt=True, tokenize=True, return_dict=True)
+            values["input_ids"], values["attention_mask"] = tokenization_dict_with_prompt["input_ids"], tokenization_dict_with_prompt["attention_mask"]
             if len(values["input_ids"]) > max_prompt_len:
                 # Only log the warning to avoid truncating in the middle of generation prompt. Consider raising an error for this case in the future.
                 logger.warning(f"Prompt {values['batch_data_id']} length {len(values['input_ids'])} greater than max_prompt_len {max_prompt_len} after applied chat template with tools.")
-        elif not values.get("input_ids") or not values.get("attention_mask"):
-            raise ValueError("input_ids and attention_mask is required for requests without tools")
-        else:
-            tokens_without_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=True)
-            values["prompt_ids"], values["prompt_attention_mask"] = values["input_ids"], values["attention_mask"]
 
+        values["prompt_ids"], values["prompt_attention_mask"] = values["input_ids"], values["attention_mask"]
         values["position_ids"] = values["prompt_position_ids"] = compute_position_id_with_mask(torch.tensor(values["attention_mask"])).tolist()
         values["loss_mask"] = values["prompt_loss_mask"] = [0] * len(values["input_ids"])
         values["generation_prompt_ids"] = values["input_ids"][len(tokens_without_prompt) :]
         return values
 
     def _update_input_ids(self, new_input_ids: List[int], attention_mask: bool, loss_mask: bool, full_tokens: bool = False) -> None:
+        """
+        Update the input_ids, attention_mask, position_ids, and loss_mask of the request.
+        When full_tokens is True, it replaces the input_ids with new_input_ids and updates the attention_mask, position_ids, and loss_mask accordingly.
+        When full_tokens is False, it appends new_input_ids to the input_ids and updates the attention_mask, position_ids, and loss_mask accordingly.
+        """
         message_len_delta = (len(new_input_ids) - len(self.input_ids)) if full_tokens else len(new_input_ids)
         self.input_ids = new_input_ids if full_tokens else (self.input_ids + new_input_ids)
         attention_mask = [int(attention_mask)] * message_len_delta
         self.attention_mask += attention_mask
-        _delta_position_ids = compute_position_id_with_mask(torch.tensor(attention_mask)).tolist()
-        last_position_id = self.position_ids[-1]
-        _position_ids = [pos_id + last_position_id for pos_id in _delta_position_ids]
         self.loss_mask += [int(loss_mask)] * message_len_delta
-        self.position_ids += _position_ids
+        self.position_ids += (compute_position_id_with_mask(torch.tensor(attention_mask)) + (self.position_ids[-1] + 1)).tolist()
 
         assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=},
             {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}"""
 
-    def _append_messages(self, messages: list[Message]) -> None:
-        self.messages.extend(messages)
-        self.messages_dumps.extend([msg.model_dump() for msg in messages])
-
-    def _tokenize_all_messages(self, tokenizer: PreTrainedTokenizer, delta_input_ids_to_check: Optional[list[int]], add_generation_prompt: bool = False) -> None:
-        full_input_ids = tokenizer.apply_chat_template(self.messages_dumps, tools=self.tools, add_generation_prompt=add_generation_prompt, tokenize=True)
-        if self.tokenization_mode == "sanity_check" and delta_input_ids_to_check is not None:
-            assert full_input_ids == self.input_ids + delta_input_ids_to_check, (
-                f"Sanity check failed.\nFull tokenization result:\n{tokenizer.decode(full_input_ids, skip_special_tokens=False)}\nFast tokenization result:\n{tokenizer.decode(self.input_ids + delta_input_ids_to_check, skip_special_tokens=False)}"
-            )
-        self._update_input_ids(full_input_ids, attention_mask=True, loss_mask=False, full_tokens=True)
+    def _fast_tokenize(self, tokenizer: PreTrainedTokenizer, num_messages: int, add_generation_prompt: bool, delta_tokens: Optional[List[int]] = None) -> list[int]:
+        """Fast tokenization: tokenize only the new messages and return the delta tokens to append to the existing input_ids."""
+
+        # Handles cases where tool calls are incorrectly embedded, such as: I'll call the tool: <tool_call>{"name": ...}</tool_call>. Does this make sense?
+        # The code below restructures the text and tool calls parsed by the SGLang tool parser using the chat template.
+        # The outcome depends on the SGLang tool parser; for instance, with Qwen, any text after the first tool call is ignored.
+        # TODO: Reconsider this approach for RL scenarios: 1. Try to parse as much valid response as possible; 2. Surface the error to the model for learning.
+        if num_messages and (not delta_tokens or self.messages[-1].tool_calls):
+            tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None
+            content_start_pos = len(tokenizer.apply_chat_template([msg.model_dump() for msg in self.messages[:-num_messages]], tools=tools, add_generation_prompt=add_generation_prompt, tokenize=False))
+            content = tokenizer.apply_chat_template([msg.model_dump() for msg in self.messages], tools=tools, add_generation_prompt=False, tokenize=False)[content_start_pos:]
+            delta_tokens = tokenizer.encode(content, add_special_tokens=False)
+        return delta_tokens
+
+    def _full_tokenize(self, tokenizer: PreTrainedTokenizer, add_generation_prompt: bool) -> list[int]:
+        """Full tokenization tokenizes the entire message history and returns the full tokenization result."""
+        tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None
+        return tokenizer.apply_chat_template([msg.model_dump() for msg in self.messages], tools=tools, add_generation_prompt=add_generation_prompt, tokenize=True)
+
+    def _tokenize_messages(self, tokenizer: PreTrainedTokenizer, num_messages: int, loss_mask: bool, add_generation_prompt: bool, delta_tokens: Optional[List[int]] = None) -> None:
+        """
+        Tokenizes messages and updates `input_ids`, `attention_mask`, `position_ids`, and `loss_mask` based on the selected tokenization mode.
 
-    def get_prompt_ids(self, tokenizer: PreTrainedTokenizer) -> list[int]:
+        :param num_messages: (Used in "fast" and "sanity_check" modes) The number of most recent messages to tokenize.
+        :param add_generation_prompt: (Used in "full" and "sanity_check" modes) Whether to include a generation prompt in the tokenized output.
+        :param delta_tokens: (Used in "fast" and "sanity_check" modes) Tokens to append to `input_ids`. If None, the method tokenizes the last `num_messages` messages.
+        """
+        match self.tokenization_mode:
+            case "fast":
+                # Only when tokenizing assistant messages do we set loss_mask to True and exclude the generation prompt from the new token ids.
+                # Therefore, only when loss_mask == True do we include the generation prompt when computing the start position of the new message tokens (loss_mask is passed as add_generation_prompt below).
+                self._update_input_ids(self._fast_tokenize(tokenizer, num_messages, loss_mask, delta_tokens), attention_mask=True, loss_mask=loss_mask)
+            case "full":
+                self._update_input_ids(self._full_tokenize(tokenizer, add_generation_prompt), attention_mask=True, loss_mask=loss_mask, full_tokens=True)
+            case "sanity_check":
+                full_tokens = self._full_tokenize(tokenizer, add_generation_prompt)
+                delta_tokens = self._fast_tokenize(tokenizer, num_messages, loss_mask, delta_tokens)
+                assert full_tokens == self.input_ids + delta_tokens, f"Sanity check failed.\nFull tokenization result:\n{tokenizer.decode(full_tokens, skip_special_tokens=False)}\nFast tokenization result:\n{tokenizer.decode(self.input_ids + delta_tokens, skip_special_tokens=False)}"
+                self._update_input_ids(full_tokens, attention_mask=True, loss_mask=loss_mask, full_tokens=True)
+            case _:
+                raise ValueError(f"Unsupported tokenization mode: {self.tokenization_mode}. Supported modes are 'fast', 'full', and 'sanity_check'.")
+
+    def get_generation_prompt_ids(self, tokenizer: PreTrainedTokenizer) -> list[int]:
         generation_prompt_ids = [] if self.input_ids[-len(self.generation_prompt_ids) :] == self.generation_prompt_ids else self.generation_prompt_ids
-        if self.tokenization_mode == "fast":
-            self._update_input_ids(generation_prompt_ids, attention_mask=True, loss_mask=False)
-        else:
-            self._tokenize_all_messages(tokenizer, generation_prompt_ids, add_generation_prompt=True)
+        if not generation_prompt_ids:
+            return self.input_ids
+        self._tokenize_messages(tokenizer, num_messages=0, loss_mask=False, add_generation_prompt=True, delta_tokens=generation_prompt_ids)
         return self.input_ids
 
     def add_assistant_message(
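
As a side note on the position-id update in `_update_input_ids` above, here is a small illustrative check. `compute_position_id_with_mask` is assumed to return cumulative positions starting at 0 for the given mask, as its use in the diff suggests; the stand-in below is hypothetical:

    import torch

    def compute_position_id_with_mask(mask):  # hypothetical stand-in with the behavior the code above relies on
        return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0)

    position_ids = [0, 1, 2, 3]           # existing positions of the request
    delta_mask = torch.tensor([1, 1, 1])  # three appended tokens
    position_ids += (compute_position_id_with_mask(delta_mask) + (position_ids[-1] + 1)).tolist()
    assert position_ids == [0, 1, 2, 3, 4, 5, 6]  # positions continue contiguously
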
@@ -171,40 +194,14 @@ def add_assistant_message(
         content_ids: Optional[List[int]] = None,
         tool_calls: Optional[List[OpenAIFunctionToolCall]] = None,
     ) -> None:
-        self._append_messages([Message(role="assistant", content=content, tool_calls=tool_calls)])
-
-        if self.tokenization_mode != "full":
-            if tool_calls or not content_ids:
-                # Handles cases where tool calls are incorrectly embedded, such as: I'll call the tool: <tool_call>{"name": ...}</tool_call>. Does this make sense?
-                # The code below restructures the text and tool calls parsed by the SGLang tool parser using the chat template.
-                # The outcome depends on the SGLang tool parser; for instance, with Qwen, any text after the first tool call is ignored.
-                # TODO: Reconsider this approach for RL scenarios: 1. Try to parse as much valid response as possible; 2. Surface the error to the model for learning.
-                content_start_pos = len(tokenizer.apply_chat_template(self.messages_dumps[:-1], tools=self.tools, add_generation_prompt=True, tokenize=False))
-                content = tokenizer.apply_chat_template(self.messages_dumps, tools=self.tools, add_generation_prompt=False, tokenize=False)[content_start_pos:]
-                content_ids = tokenizer.encode(content, add_special_tokens=False)
-
-            if self.tokenization_mode == "fast":
-                self._update_input_ids(content_ids, attention_mask=True, loss_mask=True)
-                return
-
-        self._tokenize_all_messages(tokenizer, content_ids)
+        self.messages.append(Message(role="assistant", content=content, tool_calls=tool_calls))
+        self._tokenize_messages(tokenizer, num_messages=1, loss_mask=True, add_generation_prompt=False, delta_tokens=content_ids)
 
     def add_tool_response_messages(self, tokenizer: PreTrainedTokenizer, contents: list[str]) -> None:
         if not contents:
             return
-
-        self._append_messages([Message(role="tool", content=content) for content in contents])
-        response_token_ids = None
-        if self.tokenization_mode != "full":
-            response_start_pos = len(tokenizer.apply_chat_template(self.messages_dumps[: -len(contents)], tools=self.tools, add_generation_prompt=False, tokenize=False))
-            response_tokens = tokenizer.apply_chat_template(self.messages_dumps, tools=self.tools, add_generation_prompt=False, tokenize=False)[response_start_pos:]
-            response_token_ids = tokenizer.encode(response_tokens, add_special_tokens=False)
-
-        if self.tokenization_mode == "fast":
-            self._update_input_ids(response_token_ids, attention_mask=True, loss_mask=False)
-            return
-
-        self._tokenize_all_messages(tokenizer, response_token_ids)
+        self.messages.extend([Message(role="tool", content=content) for content in contents])
+        self._tokenize_messages(tokenizer, num_messages=len(contents), loss_mask=False, add_generation_prompt=False)
 
     def update_metrics(self, metrics: Any, tool_id: str) -> None:
         """
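
Taken together, the reworked methods support a per-turn loop along these lines. This is a hedged sketch, not code from the commit: `req` is an initialized AsyncRolloutRequest, `tokenizer` the matching PreTrainedTokenizer, and `generate`/`run_tools` are hypothetical stand-ins for the rollout engine and tool executor; the leading `(tokenizer, content)` parameters of add_assistant_message are assumed unchanged by this diff:

    prompt_ids = req.get_generation_prompt_ids(tokenizer)  # appends the generation prompt if it is not already the tail of input_ids
    text, token_ids, tool_calls = generate(prompt_ids)      # hypothetical engine call returning the assistant turn
    req.add_assistant_message(tokenizer, text, content_ids=token_ids, tool_calls=tool_calls)
    if tool_calls:
        req.add_tool_response_messages(tokenizer, run_tools(tool_calls))  # run_tools: hypothetical, returns list[str]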