feat: enable dynamic LoRA refresh for SGLang rollout
JohnConnor123 committed Feb 6, 2026
commit c93593b6b4eb00403e22a16d5b9894c52418da00
18 changes: 15 additions & 3 deletions tests/workers/rollout/rollout_sglang/test_http_server_engine.py
@@ -366,6 +366,7 @@ def test_make_request_success(self, mock_launch_server_process, basic_adapter_kw
mock_post.assert_called_with(
"http://localhost:8000/test_endpoint",
json={"param": "value"},
headers={"Content-Type": "application/json; charset=utf-8"},
timeout=adapter.timeout,
)

@@ -382,7 +383,11 @@ def test_make_request_get_method(self, mock_launch_server_process, basic_adapter
result = adapter._make_request("test_endpoint", method="GET")

assert result == {"data": "test"}
mock_get.assert_called_with("http://localhost:8000/test_endpoint", timeout=adapter.timeout)
mock_get.assert_called_with(
"http://localhost:8000/test_endpoint",
headers={"Content-Type": "application/json; charset=utf-8"},
timeout=adapter.timeout,
)

def test_make_request_non_master(self, mock_launch_server_process):
"""Test request from non-master node returns empty dict."""
@@ -757,7 +762,10 @@ async def test_make_async_request_success(self, mock_launch_server_process, basi

# Verify post was called
mock_session.post.assert_called_once_with(
"http://localhost:8000/test_endpoint", json={"param": "value"}, timeout=adapter.timeout
"http://localhost:8000/test_endpoint",
json={"param": "value"},
timeout=adapter.timeout,
headers={"Content-Type": "application/json; charset=utf-8"},
)

@pytest.mark.asyncio
@@ -787,7 +795,11 @@ async def test_make_async_request_get_method(self, mock_launch_server_process, b

# Validate
assert result == {"data": "test"}
mock_session.get.assert_called_once_with("http://localhost:8000/test_endpoint", timeout=adapter.timeout)
mock_session.get.assert_called_once_with(
"http://localhost:8000/test_endpoint",
timeout=adapter.timeout,
headers={"Content-Type": "application/json; charset=utf-8"},
)

@pytest.mark.asyncio
async def test_make_async_request_non_master(self, mock_launch_server_process):
5 changes: 5 additions & 0 deletions verl/workers/fsdp_workers.py
@@ -687,6 +687,11 @@ async def rollout_mode(self):
peft_model = getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
if hasattr(peft_model, "peft_config"): # LoRA
peft_config = peft_model.peft_config.get("default", None)
# SGLang servers load base weights from `model_path` at launch. When using LoRA training,
# the base model weights are typically frozen, so we don't need to sync base weights via
# the rollout weight-update path. Mark base as "synced" to only collect LoRA params.
if self.config.rollout.get("name", None) == "sglang" and not self.base_sync_done:
self.base_sync_done = True
params = collect_lora_params(
module=self.actor_module_fsdp,
layered_summon=self.config.rollout.get("layered_summon", False),
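For context, the comment in this hunk boils down to a simple rule. A minimal sketch of that rule follows; the names rollout_name and is_lora are hypothetical and this is not verl code:

# Hypothetical sketch, illustration only: when the rollout backend is SGLang and the
# run uses LoRA, the frozen base checkpoint is already loaded from `model_path` at
# server launch, so only the trainable LoRA tensors need the weight-update path.
def should_sync_base_weights(rollout_name: str, is_lora: bool, base_sync_done: bool) -> bool:
    if rollout_name == "sglang" and is_lora:
        return False
    # Other backends (or full fine-tuning) still sync the base weights once.
    return not base_sync_done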
4 changes: 4 additions & 0 deletions verl/workers/rollout/sglang_rollout/async_sglang_server.py
@@ -358,6 +358,10 @@ async def generate(
# video_data=video_data,
}

engine_kwargs = (self.config.get("engine_kwargs", {}) or {}).get("sglang", {}) or {}
if engine_kwargs.get("enable_lora", False):
request["lora_path"] = f"verl_policy_{self.replica_rank}_{self.node_rank}"

if self.config.enable_rollout_routing_replay:
request.update({"return_routed_experts": True})

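The adapter name attached to the request here must match the name under which the adapter is loaded on the SGLang server (see the load_lora_adapter additions in http_server_engine.py below). A hedged illustration with hypothetical values; request fields other than lora_path are elided:

# Illustration only: the same name ties both sides together.
name = "verl_policy_0_0"                      # i.e. f"verl_policy_{replica_rank}_{node_rank}"
engine.load_lora_adapter(name, "/tmp/lora_step_100")   # engine: the HTTP adapter below; path is hypothetical
request = {"text": "Hello", "lora_path": name}          # /generate then routes through that adapter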
89 changes: 83 additions & 6 deletions verl/workers/rollout/sglang_rollout/http_server_engine.py
@@ -73,6 +73,18 @@
DEFAULT_MAX_CONNECTIONS = 2000
DEFAULT_MAX_WAIT_TIME = 300.0

ADMIN_OPTIONAL_ENDPOINTS = {
"abort_request",
"flush_cache",
"load_lora_adapter",
"release_memory_occupation",
"resume_memory_occupation",
"unload_lora_adapter",
"update_weights_from_tensor",
"update_weights_from_distributed",
"update_weights_from_ipc",
}


def _read_response(response: requests.Response):
if response.status_code == 204 or not response.content:
@@ -287,6 +299,16 @@ def _register_with_router(self) -> None:
logger.error(f"Failed to register with router: {e}")
# Don't raise here - server can still work without router

def _get_request_headers(self, endpoint: str) -> dict[str, str]:
headers = {"Content-Type": "application/json; charset=utf-8"}
if endpoint in ADMIN_OPTIONAL_ENDPOINTS:
token = getattr(self.server_args, "admin_api_key", None) or getattr(self.server_args, "api_key", None)
else:
token = getattr(self.server_args, "api_key", None)
if token:
headers["Authorization"] = f"Bearer {token}"
return headers
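To make the token selection concrete, assume a server launched with api_key="user-key" and admin_api_key="admin-key" (both values hypothetical). The helper above then yields:

# Hypothetical values, illustration only.
# Admin-capable endpoint: admin_api_key wins, falling back to api_key if unset.
#   _get_request_headers("load_lora_adapter")
#   -> {"Content-Type": "application/json; charset=utf-8", "Authorization": "Bearer admin-key"}
# Regular endpoint: only api_key is consulted.
#   _get_request_headers("generate")
#   -> {"Content-Type": "application/json; charset=utf-8", "Authorization": "Bearer user-key"}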

def _make_request(
self,
endpoint: str,
@@ -319,13 +341,14 @@ def _make_request(
return {}

url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}"
headers = self._get_request_headers(endpoint)

for attempt in range(self.max_attempts):
try:
if method.upper() == "GET":
response = requests.get(url, timeout=self.timeout)
response = requests.get(url, headers=headers, timeout=self.timeout)
else:
response = requests.post(url, json=payload or {}, timeout=self.timeout)
response = requests.post(url, json=payload or {}, headers=headers, timeout=self.timeout)

response.raise_for_status()
return _read_response(response)
@@ -335,6 +358,12 @@
except requests.exceptions.ConnectionError:
logger.warning(f"Connection error for {endpoint} (attempt {attempt + 1})")
except requests.exceptions.HTTPError as e:
status_code = getattr(getattr(e, "response", None), "status_code", None)
# SGLang returns 400 when unloading a non-existent adapter. This is expected in
# "refresh" flows where we optimistically unload before loading a new adapter.
if endpoint == "unload_lora_adapter" and status_code == 400:
logger.warning(f"HTTP 400 for unload_lora_adapter (likely not loaded yet): {e}")
return {}
logger.error(f"HTTP error for {endpoint}: {e}")
raise
except Exception as e:
@@ -389,6 +418,24 @@ def update_weights_from_tensor(self, req: UpdateWeightsFromTensorReqInput) -> di
},
)

def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False) -> dict[str, Any]:
return self._make_request(
"load_lora_adapter",
{
"lora_name": lora_name,
"lora_path": lora_path,
"pinned": pinned,
},
)

def unload_lora_adapter(self, lora_name: str) -> dict[str, Any]:
return self._make_request(
"unload_lora_adapter",
{
"lora_name": lora_name,
},
)
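
A hedged usage sketch of the refresh flow these two helpers enable; the adapter name and checkpoint directory are hypothetical, and engine stands for an instance of the adapter class in this file:

# Hypothetical refresh between rollout steps: unload the stale adapter, then load the
# freshly saved one under the same name that /generate requests reference.
adapter_name = "verl_policy_0_0"
new_ckpt = "/tmp/lora_step_100"

engine.unload_lora_adapter(adapter_name)   # a 400 ("not loaded yet") is tolerated in _make_request above
engine.load_lora_adapter(adapter_name, new_ckpt, pinned=False)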

def shutdown(self) -> None:
"""Shutdown the HTTP server and clean up resources.

@@ -519,8 +566,11 @@ def flush_cache(self) -> dict[str, Any]:
# Use retry logic with limited attempts to avoid infinite loops
for attempt in range(self.max_attempts * 2): # Allow more retries for cache flush
try:
headers = self._get_request_headers("flush_cache")
response = requests.get(
f"http://{self.server_args.host}:{self.server_args.port}/flush_cache", timeout=self.timeout
f"http://{self.server_args.host}:{self.server_args.port}/flush_cache",
headers=headers,
timeout=self.timeout,
)
if response.status_code == 200:
return _read_response(response)
@@ -687,16 +737,19 @@ async def _make_async_request(
return {}

url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}"
headers = self._get_request_headers(endpoint)

for attempt in range(self.max_attempts):
try:
async with self._get_session() as session:
if method.upper() == "GET":
async with session.get(url, timeout=timeout) as response:
async with session.get(url, timeout=timeout, headers=headers) as response:
response.raise_for_status()
return await _read_async_response(response)
else:
async with session.post(url, json=payload or {}, timeout=timeout) as response:
async with session.post(
url, json=payload or {}, timeout=timeout, headers=headers
) as response:
response.raise_for_status()
return await _read_async_response(response)

@@ -705,6 +758,11 @@
except aiohttp.ClientConnectorError:
logger.warning(f"Connection error for {endpoint} (attempt {attempt + 1})")
except aiohttp.ClientResponseError as e:
# SGLang returns 400 when unloading a non-existent adapter. This is expected in
# "refresh" flows where we optimistically unload before loading a new adapter.
if endpoint == "unload_lora_adapter" and getattr(e, "status", None) == 400:
logger.warning(f"HTTP 400 for unload_lora_adapter (likely not loaded yet): {e}")
return {}
logger.error(f"HTTP error for {endpoint}: {e}")
raise
except Exception as e:
@@ -776,6 +834,24 @@ async def update_weights_from_tensor(
},
)

async def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False) -> dict[str, Any]:
return await self._make_async_request(
"load_lora_adapter",
{
"lora_name": lora_name,
"lora_path": lora_path,
"pinned": pinned,
},
)

async def unload_lora_adapter(self, lora_name: str) -> dict[str, Any]:
return await self._make_async_request(
"unload_lora_adapter",
{
"lora_name": lora_name,
},
)
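
The async variants support the same refresh pattern; a minimal hedged sketch with hypothetical names and paths:

# Hypothetical async refresh; the 400-tolerant unload-then-load pattern mirrors the
# synchronous adapter.
await engine.unload_lora_adapter("verl_policy_0_0")
await engine.load_lora_adapter("verl_policy_0_0", "/tmp/lora_step_100")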

async def flush_cache(self) -> dict[str, Any]:
"""Flush the cache of the server asynchronously.

@@ -796,9 +872,10 @@ async def flush_cache(self) -> dict[str, Any]:
# Use retry logic with limited attempts to avoid infinite loops
for attempt in range(self.max_attempts * 4): # Allow more retries for cache flush
try:
headers = self._get_request_headers("flush_cache")
async with self._get_session() as session:
url = f"http://{self.server_args.host}:{self.server_args.port}/flush_cache"
async with session.get(url) as response:
async with session.get(url, headers=headers) as response:
if response.status == 200:
return await _read_async_response(response)
except Exception as e: