Skip to content
Merged
23 changes: 20 additions & 3 deletions verl/workers/sharding_manager/megatron_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@

from .base import BaseShardingManager

try:
# python >= 3.13
from itertools import batched
except ImportError:
from itertools import islice

def batched(iterable, n):
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(islice(it, n)):
yield batch


logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))

Expand Down Expand Up @@ -130,21 +145,23 @@ async def update_weights(self, params):
# named_tensors = [(k, v) for k, v in params.items()]
named_tensors = params
load_format = None
for tensor_index, (name, tensor) in enumerate(named_tensors):
fetch_bs = 128
for batch in batched(named_tensors, fetch_bs):
if self.device_mesh["tp"].get_local_rank() == 0:
await self.inference_engine.update_weights_from_tensor(
named_tensors=[
(
name,
tensor.detach(),
)
for name, tensor in batch
],
load_format=load_format,
flush_cache=False,
)

if self.device_mesh["tp"].get_local_rank() == 0:
await self.inference_engine.flush_cache()
if self.device_mesh["tp"].get_local_rank() == 0:
await self.inference_engine.flush_cache()

async def release_memory(self):
if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
Expand Down