Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
eb3f5db
feat: merge the logic of _update_per_bucket_p2p into _update_per_bucket
abcdea Sep 23, 2025
2898e16
fix: refine rank group handling in update
abcdea Sep 28, 2025
0065a87
feat: GPU-RDMA-device-topology-unawared receiver assignment implemented
abcdea Sep 28, 2025
4ada8b6
feat: GPU-RDMA-device-topology-awared receiver assignment implemented
abcdea Sep 30, 2025
c840c34
style: ruff formatting
abcdea Oct 9, 2025
59b8b38
fix: resolve PR comments
specture724 Oct 14, 2025
745e4c2
Merge branch 'main' into feat/optimize_p2p
specture724 Oct 14, 2025
e2e98d5
fix: assert p2p_store_addr is not None in loading checkpoint
specture724 Oct 14, 2025
d8dc4be
misc: logging removed
specture724 Oct 14, 2025
17271e5
Merge branch 'main' into feat/optimize_p2p
specture724 Oct 14, 2025
c6ec7dd
fix: return logic in update
specture724 Oct 14, 2025
5916eb9
Merge branch 'feat/optimize_p2p' of github.com:specture724/checkpoint…
specture724 Oct 14, 2025
bf73f42
fix: resolve pr comment issues
specture724 Oct 14, 2025
e79ef26
misc: format commit message
specture724 Oct 14, 2025
e126a89
feat: test_assign_receiver_ranks.py: add unit tests for _assign_recei…
specture724 Oct 14, 2025
8440ffa
refactor: refactor the test
specture724 Oct 17, 2025
5fd38b1
misc: resolve pr issues
specture724 Oct 17, 2025
0d74d7f
fix: handle corner case when senders' buckets lays unbanlancedly
specture724 Oct 20, 2025
f06b0bf
fix: debug test and add more cases
specture724 Oct 20, 2025
32e687d
Merge branch 'MoonshotAI:main' into feat/optimize_p2p
specture724 Oct 20, 2025
12455f5
misc: fix pr issues
specture724 Oct 21, 2025
e4c253d
doc: fix benchmark results in README
specture724 Oct 23, 2025
4bd5da3
fix: gather meta to generate local topo fixed
specture724 Oct 25, 2025
86437ad
misc: fix pr issues
specture724 Oct 25, 2025
415c320
doc
specture724 Oct 28, 2025
ed6c8a0
docs: add numa binding
weixiao-huang Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: GPU-RDMA-device-topology-awared receiver assignment implemented
  • Loading branch information
abcdea committed Oct 9, 2025
commit 4ada8b6df8aa70b6c5782bd962efe1e1598e25bc
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
.ruff_cache
.DS_Store
.idea
.vscode/
build/
dist/
_version.py
*_perf/
106 changes: 79 additions & 27 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class MemoryBufferMetaList(BaseModel):
class DataToGather(MemoryBufferMetaList):
host_ip: str
device_uuid: str
rdma_device: str
Comment thread
specture724 marked this conversation as resolved.
Outdated


# 256 bytes alignment when flatten torch tensors to uint8 buffer
Expand Down Expand Up @@ -493,7 +494,11 @@ def request_inference_to_update(


def _gen_h2d_buckets(
global_metas: dict[int, MemoryBufferMetaList], bucket_size: int, ranks: list[int] | None = None
global_metas: dict[int, MemoryBufferMetaList],
bucket_size: int,
local_topo: dict[str, set[int]],
remote_topo: dict[str, set[int]],
ranks: list[int] | None = None,
) -> list[tuple[int, int, H2DBucket]]:
buckets: list[tuple[int, H2DBucket]] = []
Comment thread
specture724 marked this conversation as resolved.

Expand All @@ -517,32 +522,61 @@ def _gen_h2d_buckets(
assert buckets[-1][1].size > 0, (
f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
)
buckets_with_receiver = _assign_receiver_ranks(buckets, ranks or list(range(len(global_metas))))
actual_local_topo = (
{k: v & set(ranks) for k, v in local_topo.items() if v & set(ranks)}
Comment thread
specture724 marked this conversation as resolved.
Outdated
if ranks
else local_topo
)
buckets_with_receiver = _assign_receiver_ranks(buckets, actual_local_topo, remote_topo, ranks)
return buckets_with_receiver


def _assign_receiver_ranks(
buckets: list[tuple[int, H2DBucket]], ranks: list[int]
buckets: list[tuple[int, H2DBucket]],
local_topo: dict[str, set[int]],
remote_topo: dict[str, set[int]],
ranks: list[int] | None = None,
) -> list[tuple[int, int, H2DBucket]]:
"""
(owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)

Assign receiver ranks to buckets. If ranks is empty, assign the owner_rank as receiver_rank.
GPU-NIC topology will be considered to make full use of the bandwidth in the future.
Now, if owner_rank is not in ranks, assign the bucket to the first rank in ranks.
Assign owner_rank as receiver_rank if ranks is empty or contains only the owner_rank, ignoring the topology.
GPU-rdma_device topology will be considered to make full use of the bandwidth.
"""
buckets_with_receiver: list[tuple[int, int, H2DBucket]] = []
# TODO: this is a simple implementation, we simply assign the bucket to the first rank in ranks
# which may cause imbalance if ranks is not balanced. We can improve this by detecting topology.
for i, (owner_rank, bucket) in enumerate(buckets):
if ranks and owner_rank not in ranks:
buckets_with_receiver.append((ranks[0], owner_rank, bucket))
logger.warning(
f"[rank{dist.get_rank()}] bucket {i} owner_rank {owner_rank} not in ranks {ranks}, assign to {buckets_with_receiver[-1][0]}"
)
else:
buckets_with_receiver.append((owner_rank, owner_rank, bucket))
# if ranks is empty, assign the owner_rank as receiver_rank, this is used for colocate architecture
if not ranks:
Comment thread
specture724 marked this conversation as resolved.
Outdated
return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
rank_to_rdma_device = {
rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
}

# group buckets by owner RDMA devices
buckets_by_rdma_device = defaultdict(list)
for owner_rank, bucket in buckets:
owner_rdma_device = rank_to_rdma_device[owner_rank]
buckets_by_rdma_device[owner_rdma_device].append((owner_rank, bucket))

buckets_matrix = list(buckets_by_rdma_device.values())

# select receiver ranks
num_receivers = min(len(local_topo), len(buckets_by_rdma_device))
receiver_list = [min(ranks) for ranks in list(local_topo.values())[:num_receivers]]
Comment thread
specture724 marked this conversation as resolved.

flattened_buckets = [
buckets_matrix[row][col]
for col in range(max(len(col) for col in buckets_matrix) if buckets_matrix else 0)
Comment thread
specture724 marked this conversation as resolved.
Outdated
for row in range(len(buckets_matrix))
if col < len(buckets_matrix[row])
]

buckets_with_receiver = []
for i, (owner_rank, bucket) in enumerate(flattened_buckets):
receiver_rank = receiver_list[i % len(receiver_list)]
buckets_with_receiver.append((receiver_rank, owner_rank, bucket))
logger.debug(
f"Assigned bucket with owner_rank {owner_rank} to receiver_rank {receiver_rank}"
)

return buckets_with_receiver


Expand All @@ -561,14 +595,14 @@ def __init__(self):
self.rank = int(os.getenv("RANK"))
gpu_count = torch.cuda.device_count()
local_rank = self.rank % gpu_count
device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
self.ip = _get_ip()

# we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
retry_count = 8
for i in range(retry_count):
self.engine = TransferEngine()
ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", device)
ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", self.device)
if ret == 0:
break
# sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
Expand All @@ -582,7 +616,7 @@ def __init__(self):
self.port = self.engine.get_rpc_port()
self.named_tensors: dict[str, torch.Tensor] = {}
logger.info(
f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
)

@property
Expand Down Expand Up @@ -637,6 +671,8 @@ def __init__(
self._auto_pg = auto_pg
self._all_hosts = []
self._global_device_uuids: list[str] = []
self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)

assert self._rank is not None and self._rank >= 0, self._rank
assert self._world_size and self._world_size > 0, self._world_size
Expand All @@ -656,6 +692,7 @@ def __init__(
device_index = self._local_rank
torch.cuda.set_device(device_index)
self._device_uuid = _get_physical_gpu_id(device_index)
self.rdma_device = None if self._p2p_store is None else self._p2p_store.device
Comment thread
specture724 marked this conversation as resolved.
Outdated

def _logger_rank0(self, msg: str):
if self._local_rank == 0:
Expand All @@ -666,6 +703,10 @@ def get_metas(self) -> dict[int, MemoryBufferMetaList]:

def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
self._current_global_parameter_metas = metas
for i, meta in self._current_global_parameter_metas.items():
self._remote_rdma_devices[
Comment thread
specture724 marked this conversation as resolved.
meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
Comment thread
specture724 marked this conversation as resolved.
].add(i)

def register_checkpoint(
self,
Expand Down Expand Up @@ -739,11 +780,11 @@ def gather_metas(self, checkpoint_name: str):
p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
host_ip=_get_ip(),
device_uuid=self._device_uuid,
rdma_device=self.rdma_device or "",
)

dist.all_gather_object(metas_lst, metas)

self._current_global_parameter_metas = {}
num_parameters = 0
all_hosts: list[str] = []
global_device_uuids: list[str] = []
Expand All @@ -756,10 +797,14 @@ def gather_metas(self, checkpoint_name: str):
if metas_buckets.memory_buffer_metas_list:
self._current_global_parameter_metas[i] = metas_buckets
num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
self._local_rdma_devices[metas_buckets.rdma_device + "@" + metas.host_ip].add(i)
if not self._all_hosts:
self._all_hosts = all_hosts
if not self._global_device_uuids:
self._global_device_uuids = global_device_uuids
# Sender node and Receiver node have the same GPU-rdma_device topology is considered as default.
# Rewrite the sender's topology (_remote_rdma_devices) by calling load_metas.
self._remote_rdma_devices = self._local_rdma_devices.copy()
logger.info(
f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
)
Expand Down Expand Up @@ -1000,11 +1045,13 @@ def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
)

def _get_bcast_rank_map(self, ranks: list[int]) -> dict[int, int]:
Comment thread
specture724 marked this conversation as resolved.
Outdated
# map rank to the rank which is in the same machine and has local_rank 0
bcast_rank_map = {}
"""
map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
which are generated in self.init_process_group_for_ranks
"""
bcast_rank_map: dict[int, int] = {}
if not ranks:
for r in range(self._world_size):
bcast_rank_map[r] = r
bcast_rank_map = {r: r for r in range(self._world_size)}
else:
for i, r in enumerate(ranks):
bcast_rank_map[r] = i
Expand Down Expand Up @@ -1048,7 +1095,13 @@ def _update_per_bucket(
dist.barrier()

bucket_size, disable_h2d_buffer = self._detect_bucket_size()
buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size, ranks)
buckets = _gen_h2d_buckets(
self._current_global_parameter_metas,
bucket_size,
self._local_rdma_devices,
self._remote_rdma_devices,
ranks,
)

h2d_buffer: torch.Tensor | None = (
None
Expand Down Expand Up @@ -1098,7 +1151,6 @@ def _update_per_bucket(
h2d_buffer,
receiver_rank_buckets[i][0],
)

for receiver_rank, _buckets in buckets_by_receiver_rank.items():
if i >= len(_buckets):
continue
Expand Down
3 changes: 1 addition & 2 deletions tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from checkpoint_engine.ps import ParameterServer, _get_physical_gpu_id
from checkpoint_engine.worker import update_weights_from_ipc
from loguru import logger


def gen_test_tensors(rank: int) -> list[tuple[str, torch.Tensor]]:
tensors = []
Expand Down Expand Up @@ -78,7 +78,6 @@ def run():
ps.gather_metas(checkpoint_name)
ranks_list = [[], list(range(world_size // 2)), [], list(range(world_size))]
for ranks in ranks_list:
logger.warning(f"Update with ranks: {ranks}")
ps.update(checkpoint_name, queue.put, ranks=ranks)
# sleep 3s to wait process group is destroyed
time.sleep(3)
Expand Down
Loading