Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/examples/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ Actor/Rollout/Reference Policy
n: 1
do_sample: False # default eager for validation

agent:
custom_async_server: # Use custom async server implementation for rollout
path: null
name: null

**Common config for actor, rollout and reference model**

- ``actor_rollout_ref.hybrid_engine``: Whether it's a hybrid engine,
Expand Down
11 changes: 8 additions & 3 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,14 @@ def _initialize_llm_servers(self):
self.async_llm_servers = [None] * self.rollout_dp_size
self.server_addresses = [None] * self.rollout_dp_size

server_class = async_server_class(
rollout_backend=self.config.actor_rollout_ref.rollout.name,
)
if self.config.actor_rollout_ref.rollout.agent.custom_async_server:
server_class = async_server_class(
rollout_backend=self.config.actor_rollout_ref.rollout.name,
rollout_backend_module=self.config.actor_rollout_ref.rollout.agent.custom_async_server.path,
rollout_backend_class=self.config.actor_rollout_ref.rollout.agent.custom_async_server.name,
)
else:
server_class = async_server_class(rollout_backend=self.config.actor_rollout_ref.rollout.name)

# Start all server instances, restart if address already in use.
unready_dp_ranks = set(range(self.rollout_dp_size))
Expand Down
5 changes: 5 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,18 @@ actor_rollout_ref:
# Number of agent loop workers
num_workers: 8

custom_async_server:
path: null
name: null

# Support logging rollout log-probabilities for debugging purposes
calculate_log_probs: False
# Nsight system profiler configs
profiler:
discrete: False
all_ranks: False
ranks: null

critic:
rollout_n: ${actor_rollout_ref.rollout.n}
strategy: ${actor_rollout_ref.actor.strategy}
Expand Down
9 changes: 9 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,15 @@ actor_rollout_ref:
# Number of agent loop workers
num_workers: 8

# custom async server configs
custom_async_server:

# Path to the custom async server implementation
path: null

# Class name of the custom async server class (e.g. AsyncvLLMServer)
name: null

# configs for the critic
critic:

Expand Down
32 changes: 23 additions & 9 deletions verl/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,35 @@ def import_external_libs(external_libs=None):

def load_extern_type(file_path: Optional[str], type_name: Optional[str]):
"""Load a external data type based on the file path and type name"""
import importlib
import importlib.util
import os

if not file_path:
return None

if not os.path.exists(file_path):
raise FileNotFoundError(f"Custom type file '{file_path}' not found.")

spec = importlib.util.spec_from_file_location("custom_module", file_path)
module = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Error loading module from '{file_path}'") from e
if file_path.startswith("pkg://"):
# pkg://verl.utils.dataset.rl_dataset
# pkg://verl/utils/dataset/rl_dataset
module_name = file_path[6:].replace("/", ".")
module = importlib.import_module(module_name)

else:
# file://verl/utils/dataset/rl_dataset
# file:///path/to/verl/utils/dataset/rl_dataset.py
# or without file:// prefix
if file_path.startswith("file://"):
file_path = file_path[7:]

if not os.path.exists(file_path):
raise FileNotFoundError(f"Custom type file '{file_path}' not found.")

spec = importlib.util.spec_from_file_location("custom_module", file_path)
module = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Error loading module from '{file_path}'") from e

if not hasattr(module, type_name):
raise AttributeError(f"Custom type '{type_name}' not found in '{file_path}'.")
Expand Down
49 changes: 35 additions & 14 deletions verl/workers/rollout/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import threading
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type

import fastapi
import ray
Expand Down Expand Up @@ -135,9 +135,14 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup):
self.async_llm_servers = [None] * self.rollout_dp_size
self.server_addresses = [None] * self.rollout_dp_size

server_class = async_server_class(
rollout_backend=self.config.rollout.name,
)
if self.config.rollout.agent.custom_async_server:
server_class = async_server_class(
rollout_backend=self.config.rollout.name,
rollout_backend_module=self.config.rollout.agent.custom_async_server.path,
rollout_backend_class=self.config.rollout.agent.custom_async_server.name,
)
else:
server_class = async_server_class(rollout_backend=self.config.rollout.name)

# Start all server instances, restart if address already in use.
unready_dp_ranks = set(range(self.rollout_dp_size))
Expand Down Expand Up @@ -233,22 +238,38 @@ def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto
return future.result()


def async_server_class(
    rollout_backend: str, rollout_backend_module: Optional[str] = None, rollout_backend_class: Optional[str] = None
) -> Type[AsyncServerBase]:
    """Get async server class.

    When neither ``rollout_backend_module`` nor ``rollout_backend_class`` is
    given, the built-in server class for ``rollout_backend`` is returned.
    Otherwise both must be given, and they are forwarded to
    ``verl.utils.import_utils.load_extern_type`` to load a custom class.

    Args:
        rollout_backend: str, rollout backend type (alias), should be "vllm" or "sglang".
            Only consulted when no custom module/class is supplied.
        rollout_backend_module: Optional[str], import path of the rollout backend.
        rollout_backend_class: Optional[str], class name of the rollout backend.

    Returns:
        Type[AsyncServerBase]: async server class.

    Raises:
        NotImplementedError: No custom class was requested and ``rollout_backend``
            is neither "vllm" nor "sglang".
        ValueError: Exactly one of ``rollout_backend_module`` and
            ``rollout_backend_class`` was provided.
    """
    if rollout_backend_class is None and rollout_backend_module is None:
        # If both are None, use the default backend class.
        # Do not change the original import behavior:
        # importlib.import_module and from ... import ... have subtle differences in ray.
        if rollout_backend == "vllm":
            from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer

            return AsyncvLLMServer
        elif rollout_backend == "sglang":
            from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer

            return AsyncSglangServer
        else:
            raise NotImplementedError(f"rollout backend {rollout_backend} is not supported")

    # Reaching here means at least one customization argument was set;
    # a half-specified customization is rejected rather than silently ignored.
    if rollout_backend_module is None or rollout_backend_class is None:
        raise ValueError("rollout_backend_module and rollout_backend_class must be both provided for customization")

    # Imported lazily, matching the function-scope imports of the default backends above.
    from verl.utils.import_utils import load_extern_type

    return load_extern_type(rollout_backend_module, rollout_backend_class)
Loading