Skip to content

Commit 3f4a4f5

Browse files
chenhaiqhuangjunyi.0
authored and committed
[rollout] fix: error in sglang async mode (verl-project#2098)
Fixed regressions from verl-project#1668 and verl-project#1933. Added e2e tests for both sglang and vllm async modes.
1 parent ec9a643 commit 3f4a4f5

File tree

4 files changed

+24
-4
lines changed

4 files changed

+24
-4
lines changed

.github/workflows/e2e_ppo_trainer.yml

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -251,6 +251,16 @@ jobs:
251251
run: |
252252
ray stop --force
253253
ENGINE=sglang bash tests/special_e2e/ppo_trainer/run_function_reward.sh
254+
- name: Running GSM8K E2E training tests on sglang async
255+
run: |
256+
ray stop --force
257+
ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
258+
- name: Running GSM8K E2E training tests on vllm async
259+
run: |
260+
ray stop --force
261+
export VLLM_USE_V1=1
262+
ray start --head
263+
ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
254264
255265
e2e_ppo_trainer_sglang_multiturn_with_tool:
256266
runs-on: [L20x8]

tests/special_e2e/ppo_trainer/run_function_reward.sh

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,12 @@ MAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512}
1414

1515
ENGINE=${ENGINE:-vllm}
1616
ROLLOUT_MODE=${ROLLOUT_MODE:-sync}
17+
18+
RETURN_RAW_CHAT="False"
19+
if [ "$ROLLOUT_MODE" = "async" ]; then
20+
RETURN_RAW_CHAT="True"
21+
fi
22+
1723
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.8}
1824
ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False}
1925
ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False}
@@ -84,6 +90,7 @@ python3 -m verl.trainer.main_ppo \
8490
data.train_batch_size="${train_prompt_bsz}" \
8591
data.max_prompt_length="${MAX_PROMPT_LEN}" \
8692
data.max_response_length="${MAX_RESPONSE_LEN}" \
93+
data.return_raw_chat=${RETURN_RAW_CHAT} \
8794
actor_rollout_ref.model.path="${MODEL_PATH}" \
8895
actor_rollout_ref.model.use_shm=${USE_SHM} \
8996
actor_rollout_ref.model.lora_rank=${LORA_RANK} \

verl/workers/rollout/sglang_rollout/async_sglang_server.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,8 @@
2929
class AsyncSglangServer(AsyncServerBase):
3030
def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: str):
3131
super().__init__()
32-
self.config = config
33-
rollout_config = config.get("rollout", {})
34-
self._tp_size = rollout_config.get("tensor_model_parallel_size", 1)
32+
self.config = config.actor_rollout_ref
33+
self._tp_size = self.config.rollout.get("tensor_model_parallel_size", 1)
3534
self._dp_size = dp_size
3635
self._dp_rank = dp_rank
3736
self.wg_prefix = wg_prefix

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1042,7 +1042,7 @@ async def chat_completion(self, json_request):
10421042
request_id=str(uuid4()),
10431043
state=AsyncRolloutRequestStateEnum.PENDING,
10441044
messages=[Message.model_validate(msg) for msg in json_request["messages"]],
1045-
tools=_tool_schemas,
1045+
tool_schemas=_tool_schemas,
10461046
tools_kwargs=_tools_kwargs,
10471047
input_ids=_input_ids,
10481048
prompt_ids=_input_ids,
@@ -1057,8 +1057,12 @@ async def chat_completion(self, json_request):
10571057
prompt_loss_mask=[0] * len(_input_ids),
10581058
response_loss_mask=[],
10591059
reward_scores={},
1060+
max_prompt_len=self.config.prompt_length,
10601061
max_response_len=self.config.response_length,
10611062
max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),
1063+
use_inference_chat_template=self.config.multi_turn.use_inference_chat_template,
1064+
enable_tokenization_sanity_check=self.config.multi_turn.enable_tokenization_sanity_check,
1065+
tokenizer=self.tokenizer,
10621066
)
10631067

10641068
# json_request already contains sampling_params

0 commit comments

Comments
 (0)