From 3f62182d099f9fec1f0b2a19636b801345aa6651 Mon Sep 17 00:00:00 2001
From: yexin <469221983@qq.com>
Date: Mon, 2 Feb 2026 17:58:39 +0800
Subject: [PATCH 1/2] [ckpt] feat: add mooncake backend

---
 .../test_correctness_on_npu.py                |  48 ++++
 verl/checkpoint_engine/__init__.py            |   7 +
 .../mooncake_checkpoint_engine.py             | 232 ++++++++++++++++++
 3 files changed, 287 insertions(+)
 create mode 100644 verl/checkpoint_engine/mooncake_checkpoint_engine.py

diff --git a/tests/checkpoint_engine/test_correctness_on_npu.py b/tests/checkpoint_engine/test_correctness_on_npu.py
index b99fcc771be..d4cddd3b8c8 100644
--- a/tests/checkpoint_engine/test_correctness_on_npu.py
+++ b/tests/checkpoint_engine/test_correctness_on_npu.py
@@ -74,6 +74,54 @@ async def test_hccl_checkpoint_engine(
     ray.shutdown()
 
 
+@pytest.mark.skip(reason="temporarily skip since our CI environment is not ready")
+@pytest.mark.asyncio
+@pytest.mark.parametrize("device", ["npu"])
+@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
+async def test_mooncake_checkpoint_engine(
+    rebuild_group,
+    num_trainer,
+    num_rollout,
+    device,
+    num_nodes=1,
+    num_gpus_per_node=8,
+    check_allclose=True,
+    model_path="~/models/Qwen/Qwen3-8B-Base",
+):
+    model_path = os.path.expanduser(model_path)
+    ray.init(
+        runtime_env={
+            "env_vars": {
+                "ASCEND_USE_SHORT_CONNECTION": "1",
+                "VERL_LOGGING_LEVEL": "DEBUG",
+            }
+        }
+    )
+
+    # initialize config
+    checkpoint_engine_config = CheckpointEngineConfig(
+        backend="mooncake", engine_kwargs={"mooncake": {"device": device, "rebuild_group": rebuild_group}}
+    )
+    model_config = HFModelConfig(path=model_path, use_remove_padding=True)
+    rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)
+
+    # create trainer and rollout worker group
+    resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
+    resource_pool.get_placement_groups(device_name=get_device_name())
+    trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
+    trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
+    trainer.reset()
+    rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)
+
+    # create checkpoint engine manager
+    checkpoint_manager = CheckpointEngineManager(backend="mooncake", trainer=trainer, replicas=replicas)
+    for _ in range(3):
+        await checkpoint_manager.update_weights()
+        rollout.check_weights()
+
+    ray.shutdown()
+
+
 if __name__ == "__main__":
     test_hccl_checkpoint_engine(
         rebuild_group=False,
diff --git a/verl/checkpoint_engine/__init__.py b/verl/checkpoint_engine/__init__.py
index 4409369e8e8..c4afdb17645 100644
--- a/verl/checkpoint_engine/__init__.py
+++ b/verl/checkpoint_engine/__init__.py
@@ -51,3 +51,10 @@
     __all__ += ["NIXLCheckpointEngine"]
 except ImportError:
     NIXLCheckpointEngine = None
+
+try:
+    from .mooncake_checkpoint_engine import MooncakeCheckpointEngine
+
+    __all__ += ["MooncakeCheckpointEngine"]
+except ImportError:
+    MooncakeCheckpointEngine = None
diff --git a/verl/checkpoint_engine/mooncake_checkpoint_engine.py b/verl/checkpoint_engine/mooncake_checkpoint_engine.py
new file mode 100644
index 00000000000..b446e2e7342
--- /dev/null
+++ b/verl/checkpoint_engine/mooncake_checkpoint_engine.py
@@ -0,0 +1,232 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import os +import time +import gc +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Generator + +import ray +import torch +from vllm.distributed.utils import StatelessProcessGroup +from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address +from verl.utils.device import get_torch_device + +from mooncake.engine import TransferEngine + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + + +@CheckpointEngineRegistry.register("mooncake") +class MooncakeCheckpointEngine(CheckpointEngine): + """Mooncake checkpoint engine with p2p communication using TransferEngine + + Args: + bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. + device (str): The device to use for the checkpoint engine, "cpu" or "cuda". + rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. + device_name (str): Mooncake device name filter. + """ + + def __init__( + self, + bucket_size: int, + device: str = "cuda", + rollout_dtype: torch.dtype = torch.bfloat16, + device_name: str = "", + is_master: bool = False, + rebuild_group: bool = False, + ): + self.bucket_size = bucket_size + self.device = device + self.rollout_dtype = rollout_dtype + self.is_master = is_master + self.rebuild_group = rebuild_group + + rank = int(os.environ["RANK"]) + device_count = get_torch_device().device_count() + local_rank = rank % device_count + get_torch_device().set_device(local_rank) + + self.engine = TransferEngine() + hostname = ray.util.get_node_ip_address().strip("[]") + ret = self.engine.initialize( + hostname, + "P2PHANDSHAKE", + "ascend_direct" if self.device == "npu" else "rdma", + device_name, + ) + assert ret == 0, f"TransferEngine initialize failed ret={ret}" + + rpc_port = self.engine.get_rpc_port() + self.session_id = f"{hostname}:{rpc_port}" + self.hostname = hostname + + self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device) + assert self.engine.register_memory(self.buf.data_ptr(), self.bucket_size) == 0, "register_memory failed" + + def prepare(self) -> dict[str, Any]: + """Prepare send and recv buckets""" + # self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device) + # self.engine.register_memory(self.buf.data_ptr(), self.bucket_size) + port, _ = get_free_port(self.hostname) + return {"addr": self.hostname, "port": port} + + @classmethod + def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadatas: list[dict]): + trainer_kwargs = { + "rank": [0] + [-1] * (trainer_world_size - 1), + "world_size": [rollout_world_size + 1] * trainer_world_size, + "metadata": [metadatas[0]] * trainer_world_size, + } + rollout_kwargs = { + "rank": list(range(1, rollout_world_size + 1)), + "world_size": [rollout_world_size + 1] * rollout_world_size, + "metadata": [metadatas[0]] * rollout_world_size, + } + return trainer_kwargs, rollout_kwargs + + def 
init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any]): + self.rank = rank + self.world_size = world_size + if rank < 0: + return + + self.store = StatelessProcessGroup.create( + host=metadata["addr"], + port=metadata["port"], + rank=rank, + world_size=world_size, + ) + + if self.is_master: + buffer_info = { + "session_id": self.session_id, + "ptr": self.buf.data_ptr(), + "len": self.bucket_size, + } + self.store.broadcast_obj(obj=buffer_info, src=0) + else: + self.buffer_info = self.store.broadcast_obj(obj=None, src=0) + + + def finalize(self): + """Cleanup communication and deregister memory""" + self.store = None + get_torch_device().empty_cache() + gc.collect() + + async def wait_for_complete(self): + magic = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device) + target = magic.repeat(self.world_size - 1) + while True: + if torch.equal(self.buf[4:4 * self.world_size], target): + break + await asyncio.sleep(0) + + @torch.no_grad() + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send weights using Mooncake TransferEngine""" + start_time = time.time() + bucket_meta: dict[str, TensorMeta] = {} + offset = 0 + + for name, weight in weights: + if self.rank != 0: + continue + weight = weight.to(self.rollout_dtype) + + if offset + weight.nbytes > self.bucket_size: + get_torch_device().synchronize + info = { + "bucket_meta": bucket_meta, + "len": offset, + "is_last": False, + } + self.store.broadcast_obj(obj=info, src=0) + await self.wait_for_complete() + bucket_meta = {} + offset = 0 + + assert offset + weight.nbytes <= self.bucket_size, ( + f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." + ) + + bucket_meta[name] = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": offset, + } + self.buf[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) + offset += weight.nbytes + + if self.rank != 0: + return + + get_torch_device().synchronize() + info = { + "bucket_meta": bucket_meta, + "len": offset, + "is_last": True, + } + self.store.broadcast_obj(obj=info, src=0) + await self.wait_for_complete() + logger.info(f"send weights done, time cost: {time.time() - start_time:.2f}s") + + @torch.no_grad() + async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + """Receive weights using Mooncake TransferEngine""" + start_time = time.time() + total_bytes = 0 + while True: + info = self.store.broadcast_obj(obj=None, src=0) + ret = self.engine.transfer_sync_read( + self.buffer_info["session_id"], + self.buf.data_ptr(), + self.buffer_info["ptr"], + info["len"], + ) + assert ret == 0, f"transfer_sync_read failed {ret}" + total_bytes += info["len"] + for name, meta in info["bucket_meta"].items(): + dtype, shape = meta["dtype"], meta["shape"] + size = dtype.itemsize * shape.numel() + tensor = self.buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) + yield name, tensor + + self.buf[:4] = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device) + + offset = self.buffer_info["ptr"] + self.rank * 4 + ret = self.engine.transfer_sync_write( + self.buffer_info["session_id"], + self.buf.data_ptr(), + offset, + 4, + ) + assert ret == 0, f"transfer_sync_write failed {ret}" + if info["is_last"]: + break + + time_cost = time.time() - start_time + bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) + logger.info( + f"Rank {self.rank} receive weights done, 
" + f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s" + ) From 533aa5e5d3c62cf6b11589d62de28fc3abc45261 Mon Sep 17 00:00:00 2001 From: yexin <469221983@qq.com> Date: Tue, 3 Feb 2026 21:14:12 +0800 Subject: [PATCH 2/2] implement ring algo --- .../mooncake_checkpoint_engine.py | 122 +++++++++++++----- 1 file changed, 87 insertions(+), 35 deletions(-) diff --git a/verl/checkpoint_engine/mooncake_checkpoint_engine.py b/verl/checkpoint_engine/mooncake_checkpoint_engine.py index b446e2e7342..9b09690e215 100644 --- a/verl/checkpoint_engine/mooncake_checkpoint_engine.py +++ b/verl/checkpoint_engine/mooncake_checkpoint_engine.py @@ -78,13 +78,20 @@ def __init__( self.session_id = f"{hostname}:{rpc_port}" self.hostname = hostname - self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device) - assert self.engine.register_memory(self.buf.data_ptr(), self.bucket_size) == 0, "register_memory failed" + self.buf = torch.empty(2 * self.bucket_size, dtype=torch.uint8, device=self.device) + self.magic_buf = torch.empty(4 * 1024, dtype=torch.uint8, device=self.device) + ret = self.engine.batch_register_memory( + [self.buf.data_ptr(), self.magic_buf.data_ptr()], + [2 * self.bucket_size, 4 * 1024], + ) + assert ret == 0, f"batch_register_memory failed ret={ret}" + logger.info(f"__init__ session_id={self.session_id}") def prepare(self) -> dict[str, Any]: """Prepare send and recv buckets""" # self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device) # self.engine.register_memory(self.buf.data_ptr(), self.bucket_size) + logger.info(f"prepare ptr={self.buf.data_ptr():#x} len={2*self.bucket_size} magic_buf_ptr={self.magic_buf.data_ptr():#x}") port, _ = get_free_port(self.hostname) return {"addr": self.hostname, "port": port} @@ -106,6 +113,7 @@ def init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any self.rank = rank self.world_size = world_size if rank < 0: + logger.info(f"init_process_group rank={rank}") return self.store = StatelessProcessGroup.create( @@ -115,55 +123,74 @@ def init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any world_size=world_size, ) - if self.is_master: - buffer_info = { - "session_id": self.session_id, - "ptr": self.buf.data_ptr(), - "len": self.bucket_size, - } - self.store.broadcast_obj(obj=buffer_info, src=0) - else: - self.buffer_info = self.store.broadcast_obj(obj=None, src=0) + info = { + "session_id": self.session_id, + "ptr": self.buf.data_ptr(), + } + info_list = self.store.all_gather_obj(info) + self.buffer_info = None if rank == 0 else info_list[rank - 1] + + logger.info( + f"init_process_group rank={rank} world_size={world_size} buffer_info={self.buffer_info}" + ) def finalize(self): """Cleanup communication and deregister memory""" self.store = None get_torch_device().empty_cache() gc.collect() + logger.info(f"finalize rank={self.rank}") - async def wait_for_complete(self): + async def wait_for_complete(self, buf: torch.Tensor): magic = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device) - target = magic.repeat(self.world_size - 1) while True: - if torch.equal(self.buf[4:4 * self.world_size], target): + if torch.equal(buf[:4], magic): break await asyncio.sleep(0) @torch.no_grad() async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): """Send weights using Mooncake TransferEngine""" + if self.rank < 0: + for name, weight in weights: + pass + logger.info(f"send_weights rank={self.rank}") + return + + total_bytes = 0 
start_time = time.time() bucket_meta: dict[str, TensorMeta] = {} offset = 0 + should_wait = False + bufs = [self.buf[:self.bucket_size], self.buf[self.bucket_size:]] + idx = 0 + current = bufs[idx] for name, weight in weights: - if self.rank != 0: - continue weight = weight.to(self.rollout_dtype) if offset + weight.nbytes > self.bucket_size: - get_torch_device().synchronize + total_bytes += offset + get_torch_device().synchronize() info = { "bucket_meta": bucket_meta, + "ptr": current.data_ptr(), "len": offset, "is_last": False, } - self.store.broadcast_obj(obj=info, src=0) - await self.wait_for_complete() + # send to rank 1 + self.store.send_obj(info, 1) + + idx ^= 1 + current = bufs[idx] bucket_meta = {} offset = 0 + if should_wait: + await self.wait_for_complete(current) + should_wait = True + assert offset + weight.nbytes <= self.bucket_size, ( f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." ) @@ -174,53 +201,78 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, "dtype": weight.dtype, "offset": offset, } - self.buf[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) + current[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) offset += weight.nbytes - if self.rank != 0: - return - get_torch_device().synchronize() info = { "bucket_meta": bucket_meta, + "ptr": current.data_ptr(), "len": offset, "is_last": True, } - self.store.broadcast_obj(obj=info, src=0) - await self.wait_for_complete() - logger.info(f"send weights done, time cost: {time.time() - start_time:.2f}s") + self.store.send_obj(info, 1) + await self.wait_for_complete(current) + + time_cost = time.time() - start_time + bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) + logger.info( + f"Rank {self.rank} send weights done, " + f"total bytes: {total_bytes} time cost: {time_cost:.2f}s bandwidth: {bandwidth:.2f} GB/s" + ) @torch.no_grad() async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: """Receive weights using Mooncake TransferEngine""" start_time = time.time() total_bytes = 0 + bufs = [self.buf[:self.bucket_size], self.buf[self.bucket_size:]] + idx = 0 + current = bufs[idx] + self.magic_buf[:4] = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device) + while True: - info = self.store.broadcast_obj(obj=None, src=0) + # 1 receive info from previous rank + info = self.store.recv_obj(self.rank - 1) + if idx >= 2 and self.rank < self.world_size - 1: + await self.wait_for_complete(current) + + ptr = info["ptr"] ret = self.engine.transfer_sync_read( self.buffer_info["session_id"], - self.buf.data_ptr(), - self.buffer_info["ptr"], + current.data_ptr(), + ptr, info["len"], ) assert ret == 0, f"transfer_sync_read failed {ret}" total_bytes += info["len"] + + # 2 send info to next rank + info["ptr"] = current.data_ptr() + if self.rank < self.world_size - 1: + self.store.send_obj(info, self.rank + 1) + + # 3 yield tensor from current buffer for name, meta in info["bucket_meta"].items(): dtype, shape = meta["dtype"], meta["shape"] size = dtype.itemsize * shape.numel() - tensor = self.buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) + tensor = current[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) yield name, tensor - self.buf[:4] = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device) - - offset = self.buffer_info["ptr"] + self.rank * 4 + # 4 write magic data 
to previous rank ret = self.engine.transfer_sync_write( self.buffer_info["session_id"], - self.buf.data_ptr(), - offset, + self.magic_buf.data_ptr(), + ptr, 4, ) assert ret == 0, f"transfer_sync_write failed {ret}" + + # 5 swap buffer + idx += 1 + current = bufs[idx % 2] + get_torch_device().synchronize() + if info["is_last"]: break
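
A note on the ring introduced in PATCH 2/2: rank 0 (the trainer master) packs converted weights into one of two registered bucket buffers and publishes the bucket metadata (tensor names, shapes, dtypes, offsets, and the buffer pointer) to rank 1 through the StatelessProcessGroup store. Each rollout rank then pulls the bucket out of its predecessor's buffer with transfer_sync_read, forwards the metadata to its successor, and acknowledges the predecessor by writing 4 magic bytes back into that buffer so it can be refilled; the two bucket slices let a rank fill or receive one bucket while the other is still draining down the ring. Below is a minimal, illustrative asyncio sketch of that hand-off ordering only, under the assumption that in-process queues and events stand in for the TransferEngine reads/writes and for the store; the names sender, receiver, and BUCKET are invented for the sketch and are not part of the patch.

import asyncio

BUCKET = 3  # weights per bucket in this toy model


async def sender(weights, inboxes, acks):
    """Rank 0: fill a bucket, hand its metadata to rank 1, wait for rank 1's ack."""
    bucket = []
    for w in weights:
        bucket.append(w)
        if len(bucket) == BUCKET:
            await inboxes[1].put({"bucket": list(bucket), "is_last": False})
            await acks[0].wait()  # rank 1 has pulled the bucket ("magic bytes" written back)
            acks[0].clear()
            bucket.clear()
    await inboxes[1].put({"bucket": bucket, "is_last": True})
    await acks[0].wait()


async def receiver(rank, world_size, inboxes, acks):
    """Rank >= 1: pull the predecessor's bucket, forward the metadata, ack the predecessor."""
    out = []
    while True:
        info = await inboxes[rank].get()       # 1. metadata from rank - 1
        out.extend(info["bucket"])             # 2. stands in for transfer_sync_read
        if rank + 1 < world_size:
            await inboxes[rank + 1].put(info)  # 3. pass the ring token to rank + 1
        acks[rank - 1].set()                   # 4. stands in for the 4-byte magic write
        if info["is_last"]:
            return out


async def main():
    world_size = 4  # rank 0 sender plus three rollout ranks
    inboxes = [asyncio.Queue(maxsize=1) for _ in range(world_size)]
    acks = [asyncio.Event() for _ in range(world_size)]
    weights = list(range(10))
    results = await asyncio.gather(
        sender(weights, inboxes, acks),
        *(receiver(r, world_size, inboxes, acks) for r in range(1, world_size)),
    )
    # every rollout rank should have seen the full weight stream, in order
    assert all(received == weights for received in results[1:])


asyncio.run(main())

Compared with the first commit, where every rollout rank read from rank 0's buffer and wrote its magic word at a per-rank offset, each buffer in the ring has exactly one reader, so a single 4-byte write to the predecessor is enough to release it.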