Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
implement ring algo
  • Loading branch information
yexin committed Feb 5, 2026
commit 533aa5e5d3c62cf6b11589d62de28fc3abc45261
122 changes: 87 additions & 35 deletions verl/checkpoint_engine/mooncake_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,20 @@ def __init__(
self.session_id = f"{hostname}:{rpc_port}"
self.hostname = hostname

self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device)
assert self.engine.register_memory(self.buf.data_ptr(), self.bucket_size) == 0, "register_memory failed"
self.buf = torch.empty(2 * self.bucket_size, dtype=torch.uint8, device=self.device)
self.magic_buf = torch.empty(4 * 1024, dtype=torch.uint8, device=self.device)
ret = self.engine.batch_register_memory(
[self.buf.data_ptr(), self.magic_buf.data_ptr()],
[2 * self.bucket_size, 4 * 1024],
)
assert ret == 0, f"batch_register_memory failed ret={ret}"
logger.info(f"__init__ session_id={self.session_id}")

def prepare(self) -> dict[str, Any]:
"""Prepare send and recv buckets"""
# self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device)
# self.engine.register_memory(self.buf.data_ptr(), self.bucket_size)
logger.info(f"prepare ptr={self.buf.data_ptr():#x} len={2*self.bucket_size} magic_buf_ptr={self.magic_buf.data_ptr():#x}")
port, _ = get_free_port(self.hostname)
return {"addr": self.hostname, "port": port}

Expand All @@ -106,6 +113,7 @@ def init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any
self.rank = rank
self.world_size = world_size
if rank < 0:
logger.info(f"init_process_group rank={rank}")
return

self.store = StatelessProcessGroup.create(
Expand All @@ -115,55 +123,74 @@ def init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any
world_size=world_size,
)

if self.is_master:
buffer_info = {
"session_id": self.session_id,
"ptr": self.buf.data_ptr(),
"len": self.bucket_size,
}
self.store.broadcast_obj(obj=buffer_info, src=0)
else:
self.buffer_info = self.store.broadcast_obj(obj=None, src=0)
info = {
"session_id": self.session_id,
"ptr": self.buf.data_ptr(),
}

info_list = self.store.all_gather_obj(info)
self.buffer_info = None if rank == 0 else info_list[rank - 1]

logger.info(
f"init_process_group rank={rank} world_size={world_size} buffer_info={self.buffer_info}"
)

def finalize(self):
"""Cleanup communication and deregister memory"""
self.store = None
get_torch_device().empty_cache()
gc.collect()
logger.info(f"finalize rank={self.rank}")

async def wait_for_complete(self):
async def wait_for_complete(self, buf: torch.Tensor):
magic = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device)
target = magic.repeat(self.world_size - 1)
while True:
if torch.equal(self.buf[4:4 * self.world_size], target):
if torch.equal(buf[:4], magic):
break
await asyncio.sleep(0)

@torch.no_grad()
async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
"""Send weights using Mooncake TransferEngine"""
if self.rank < 0:
for name, weight in weights:
pass
logger.info(f"send_weights rank={self.rank}")
return

total_bytes = 0
start_time = time.time()
bucket_meta: dict[str, TensorMeta] = {}
offset = 0
should_wait = False
bufs = [self.buf[:self.bucket_size], self.buf[self.bucket_size:]]
idx = 0
current = bufs[idx]

for name, weight in weights:
if self.rank != 0:
continue
weight = weight.to(self.rollout_dtype)

if offset + weight.nbytes > self.bucket_size:
get_torch_device().synchronize
total_bytes += offset
get_torch_device().synchronize()
info = {
"bucket_meta": bucket_meta,
"ptr": current.data_ptr(),
"len": offset,
"is_last": False,
}
self.store.broadcast_obj(obj=info, src=0)
await self.wait_for_complete()
# send to rank 1
self.store.send_obj(info, 1)

idx ^= 1
current = bufs[idx]
bucket_meta = {}
offset = 0

if should_wait:
await self.wait_for_complete(current)
should_wait = True

assert offset + weight.nbytes <= self.bucket_size, (
f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket."
)
Expand All @@ -174,53 +201,78 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None,
"dtype": weight.dtype,
"offset": offset,
}
self.buf[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
current[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
offset += weight.nbytes

if self.rank != 0:
return

get_torch_device().synchronize()
info = {
"bucket_meta": bucket_meta,
"ptr": current.data_ptr(),
"len": offset,
"is_last": True,
}
self.store.broadcast_obj(obj=info, src=0)
await self.wait_for_complete()
logger.info(f"send weights done, time cost: {time.time() - start_time:.2f}s")
self.store.send_obj(info, 1)
await self.wait_for_complete(current)

time_cost = time.time() - start_time
bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024)
logger.info(
f"Rank {self.rank} send weights done, "
f"total bytes: {total_bytes} time cost: {time_cost:.2f}s bandwidth: {bandwidth:.2f} GB/s"
)

@torch.no_grad()
async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
"""Receive weights using Mooncake TransferEngine"""
start_time = time.time()
total_bytes = 0
bufs = [self.buf[:self.bucket_size], self.buf[self.bucket_size:]]
idx = 0
current = bufs[idx]
self.magic_buf[:4] = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device)

while True:
info = self.store.broadcast_obj(obj=None, src=0)
# 1 receive info from previous rank
info = self.store.recv_obj(self.rank - 1)
if idx >= 2 and self.rank < self.world_size - 1:
await self.wait_for_complete(current)

ptr = info["ptr"]
ret = self.engine.transfer_sync_read(
self.buffer_info["session_id"],
self.buf.data_ptr(),
self.buffer_info["ptr"],
current.data_ptr(),
ptr,
info["len"],
)
assert ret == 0, f"transfer_sync_read failed {ret}"
total_bytes += info["len"]

# 2 send info to next rank
info["ptr"] = current.data_ptr()
if self.rank < self.world_size - 1:
self.store.send_obj(info, self.rank + 1)

# 3 yield tensor from current buffer
for name, meta in info["bucket_meta"].items():
dtype, shape = meta["dtype"], meta["shape"]
size = dtype.itemsize * shape.numel()
tensor = self.buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape)
tensor = current[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape)
yield name, tensor

self.buf[:4] = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device)

offset = self.buffer_info["ptr"] + self.rank * 4
# 4 write magic data to previous rank
ret = self.engine.transfer_sync_write(
self.buffer_info["session_id"],
self.buf.data_ptr(),
offset,
self.magic_buf.data_ptr(),
ptr,
4,
)
assert ret == 0, f"transfer_sync_write failed {ret}"

# 5 swap buffer
idx += 1
current = bufs[idx % 2]
get_torch_device().synchronize()

if info["is_last"]:
break

Expand Down