Skip to content

Commit 1852b82

Browse files
committed
[rollout] feat: Allow customization of async server class (verl-project#2326)
### What does this PR do? This PR contains two aspects: 1. Introduction of a new configuration option `actor_rollout_ref.rollout.custom_async_server` to allow users to customize the async server class. 2. Make `load_extern_type` more robust and support prefix like `pkg://` or `file://`, while non-breaking to any existing features and supported paths. Without this PR, it's impossible to use a customized version of AsyncvLLMServer in customized use case. We are currently using a set of ugly monkey patch to achieve this goal. Ultimately I believe `rollout.name` and `rollout.custom_async_server` can be combined. But `rollout.name` is currently referenced in too many places. It's quite difficult for me to handle all of them. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: [link](https://github.com/volcengine/verl/pulls?q=is%3Apr+is%3Aopen+async+server) - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test I have tested on our internal pipelines. The new patch works as expected and the old async servers still work as usual. ### API and Usage Example Our config is something like this: ```yaml hydra: searchpath: - pkg://verl/trainer/config defaults: - ppo_trainer - _self_ data: filter_overlong_prompts: false actor_rollout_ref: rollout: mode: async custom_async_server: path: pkg://mypackage.verl.async_server name: CustomizedvLLMServer ``` ### High-Level Design This PR is pretty straightforward. ### Specific Changes Update the docs. Update behavior in agent loop and async server manager. Update `load_extern_type` implementation. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: I think it's quite troublesome to add a CI for this feature. I can add one if you feel necessary. - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
1 parent 6f05ff0 commit 1852b82

File tree

6 files changed

+85
-26
lines changed

6 files changed

+85
-26
lines changed

docs/examples/config.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,11 @@ Actor/Rollout/Reference Policy
194194
n: 1
195195
do_sample: False # default eager for validation
196196
197+
agent:
198+
custom_async_server: # Use custom async server implementation for rollout
199+
path: null
200+
name: null
201+
197202
**Common config for actor, rollout and reference model**
198203

199204
- ``actor_rollout_ref.hybrid_engine``: Whether it's a hybrid engine,

verl/experimental/agent_loop/agent_loop.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,14 @@ def _initialize_llm_servers(self):
341341
self.async_llm_servers = [None] * self.rollout_dp_size
342342
self.server_addresses = [None] * self.rollout_dp_size
343343

344-
server_class = async_server_class(
345-
rollout_backend=self.config.actor_rollout_ref.rollout.name,
346-
)
344+
if self.config.actor_rollout_ref.rollout.agent.custom_async_server:
345+
server_class = async_server_class(
346+
rollout_backend=self.config.actor_rollout_ref.rollout.name,
347+
rollout_backend_module=self.config.actor_rollout_ref.rollout.agent.custom_async_server.path,
348+
rollout_backend_class=self.config.actor_rollout_ref.rollout.agent.custom_async_server.name,
349+
)
350+
else:
351+
server_class = async_server_class(rollout_backend=self.config.actor_rollout_ref.rollout.name)
347352

348353
# Start all server instances, restart if address already in use.
349354
unready_dp_ranks = set(range(self.rollout_dp_size))

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,13 +255,18 @@ actor_rollout_ref:
255255
# Number of agent loop workers
256256
num_workers: 8
257257

258+
custom_async_server:
259+
path: null
260+
name: null
261+
258262
# support logging rollout prob for debugging purpose
259263
calculate_log_probs: False
260264
# Nsight system profiler configs
261265
profiler:
262266
discrete: False
263267
all_ranks: False
264268
ranks: null
269+
265270
critic:
266271
rollout_n: ${actor_rollout_ref.rollout.n}
267272
strategy: ${actor_rollout_ref.actor.strategy}

verl/trainer/config/ppo_trainer.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,15 @@ actor_rollout_ref:
576576
# Number of agent loop workers
577577
num_workers: 8
578578

579+
# custom async server configs
580+
custom_async_server:
581+
582+
# Path to the custom async server implementation
583+
path: null
584+
585+
# Class name of the custom async server class (e.g. AsyncvLLMServer)
586+
name: null
587+
579588
# configs for the critic
580589
critic:
581590

verl/utils/import_utils.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,21 +79,35 @@ def import_external_libs(external_libs=None):
7979

8080
def load_extern_type(file_path: Optional[str], type_name: Optional[str]):
8181
"""Load a external data type based on the file path and type name"""
82+
import importlib
8283
import importlib.util
8384
import os
8485

8586
if not file_path:
8687
return None
8788

88-
if not os.path.exists(file_path):
89-
raise FileNotFoundError(f"Custom type file '{file_path}' not found.")
90-
91-
spec = importlib.util.spec_from_file_location("custom_module", file_path)
92-
module = importlib.util.module_from_spec(spec)
93-
try:
94-
spec.loader.exec_module(module)
95-
except Exception as e:
96-
raise RuntimeError(f"Error loading module from '{file_path}'") from e
89+
if file_path.startswith("pkg://"):
90+
# pkg://verl.utils.dataset.rl_dataset
91+
# pkg://verl/utils/dataset/rl_dataset
92+
module_name = file_path[6:].replace("/", ".")
93+
module = importlib.import_module(module_name)
94+
95+
else:
96+
# file://verl/utils/dataset/rl_dataset
97+
# file:///path/to/verl/utils/dataset/rl_dataset.py
98+
# or without file:// prefix
99+
if file_path.startswith("file://"):
100+
file_path = file_path[7:]
101+
102+
if not os.path.exists(file_path):
103+
raise FileNotFoundError(f"Custom type file '{file_path}' not found.")
104+
105+
spec = importlib.util.spec_from_file_location("custom_module", file_path)
106+
module = importlib.util.module_from_spec(spec)
107+
try:
108+
spec.loader.exec_module(module)
109+
except Exception as e:
110+
raise RuntimeError(f"Error loading module from '{file_path}'") from e
97111

98112
if not hasattr(module, type_name):
99113
raise AttributeError(f"Custom type '{type_name}' not found in '{file_path}'.")

verl/workers/rollout/async_server.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import threading
1919
from abc import ABC, abstractmethod
2020
from contextlib import asynccontextmanager
21-
from typing import Any, Dict, List, Tuple, Type
21+
from typing import Any, Dict, List, Optional, Tuple, Type
2222

2323
import fastapi
2424
import ray
@@ -135,9 +135,14 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup):
135135
self.async_llm_servers = [None] * self.rollout_dp_size
136136
self.server_addresses = [None] * self.rollout_dp_size
137137

138-
server_class = async_server_class(
139-
rollout_backend=self.config.rollout.name,
140-
)
138+
if self.config.rollout.agent.custom_async_server:
139+
server_class = async_server_class(
140+
rollout_backend=self.config.rollout.name,
141+
rollout_backend_module=self.config.rollout.agent.custom_async_server.path,
142+
rollout_backend_class=self.config.rollout.agent.custom_async_server.name,
143+
)
144+
else:
145+
server_class = async_server_class(rollout_backend=self.config.rollout.name)
141146

142147
# Start all server instances, restart if address already in use.
143148
unready_dp_ranks = set(range(self.rollout_dp_size))
@@ -233,22 +238,38 @@ def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto
233238
return future.result()
234239

235240

236-
def async_server_class(rollout_backend: str) -> Type[AsyncServerBase]:
241+
def async_server_class(
242+
rollout_backend: str, rollout_backend_module: Optional[str] = None, rollout_backend_class: Optional[str] = None
243+
) -> Type[AsyncServerBase]:
237244
"""Get async server class.
238245
239246
Args:
240-
rollout_backend: str, rollout backend, should be "vllm" or "sglang".
247+
rollout_backend: str, rollout backend type (alias), should be "vllm" or "sglang".
248+
rollout_backend_module: Optional[str], import path of the rollout backend.
249+
rollout_backend_class: Optional[str], class name of the rollout backend.
241250
242251
Returns:
243252
Type[AsyncServerBase]: async server class.
244253
"""
245-
if rollout_backend == "vllm":
246-
from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer
254+
if rollout_backend_class is None and rollout_backend_module is None:
255+
# If both are None, use the default backend class
256+
# Do not change the original import behavior
257+
# importlib.import_module and from ... import ... have subtle differences in ray
247258

248-
return AsyncvLLMServer
249-
elif rollout_backend == "sglang":
250-
from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer
259+
if rollout_backend == "vllm":
260+
from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer
251261

252-
return AsyncSglangServer
253-
else:
254-
raise NotImplementedError
262+
return AsyncvLLMServer
263+
elif rollout_backend == "sglang":
264+
from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer
265+
266+
return AsyncSglangServer
267+
else:
268+
raise NotImplementedError(f"rollout backend {rollout_backend} is not supported")
269+
270+
if rollout_backend_module is None or rollout_backend_class is None:
271+
raise ValueError("rollout_backend_module and rollout_backend_class must be both provided for customization")
272+
273+
from verl.utils.import_utils import load_extern_type
274+
275+
return load_extern_type(rollout_backend_module, rollout_backend_class)

0 commit comments

Comments
 (0)