Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
eb3f5db
feat: merge the logic of _update_per_bucket_p2p into _update_per_bucket
abcdea Sep 23, 2025
2898e16
fix: refine rank group handling in update
abcdea Sep 28, 2025
0065a87
feat: GPU-RDMA-device-topology-unawared receiver assignment implemented
abcdea Sep 28, 2025
4ada8b6
feat: GPU-RDMA-device-topology-awared receiver assignment implemented
abcdea Sep 30, 2025
c840c34
style: ruff formatting
abcdea Oct 9, 2025
59b8b38
fix: resolve PR comments
specture724 Oct 14, 2025
745e4c2
Merge branch 'main' into feat/optimize_p2p
specture724 Oct 14, 2025
e2e98d5
fix: assert p2p_store_addr is not None in loading checkpoint
specture724 Oct 14, 2025
d8dc4be
misc: logging removed
specture724 Oct 14, 2025
17271e5
Merge branch 'main' into feat/optimize_p2p
specture724 Oct 14, 2025
c6ec7dd
fix: return logic in update
specture724 Oct 14, 2025
5916eb9
Merge branch 'feat/optimize_p2p' of github.com:specture724/checkpoint…
specture724 Oct 14, 2025
bf73f42
fix: resolve pr comment issues
specture724 Oct 14, 2025
e79ef26
misc: format commit message
specture724 Oct 14, 2025
e126a89
feat: test_assign_receiver_ranks.py: add unit tests for _assign_recei…
specture724 Oct 14, 2025
8440ffa
refactor: refactor the test
specture724 Oct 17, 2025
5fd38b1
misc: resolve pr issues
specture724 Oct 17, 2025
0d74d7f
fix: handle corner case when senders' buckets lays unbanlancedly
specture724 Oct 20, 2025
f06b0bf
fix: debug test and add more cases
specture724 Oct 20, 2025
32e687d
Merge branch 'MoonshotAI:main' into feat/optimize_p2p
specture724 Oct 20, 2025
12455f5
misc: fix pr issues
specture724 Oct 21, 2025
e4c253d
doc: fix benchmark results in README
specture724 Oct 23, 2025
4bd5da3
fix: gather meta to generate local topo fixed
specture724 Oct 25, 2025
86437ad
misc: fix pr issues
specture724 Oct 25, 2025
415c320
doc
specture724 Oct 28, 2025
ed6c8a0
docs: add numa binding
weixiao-huang Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: GPU-RDMA-device-topology-unawared receiver assignment implemented
  • Loading branch information
abcdea committed Oct 9, 2025
commit 0065a87106a3db1ea99155a2b09e39077a10f215
63 changes: 51 additions & 12 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,16 +517,33 @@ def _gen_h2d_buckets(
assert buckets[-1][1].size > 0, (
f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
)
new_buckets: list[tuple[int, int, H2DBucket]] = []
# (owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)
buckets_with_receiver = _assign_receiver_ranks(buckets, ranks or list(range(len(global_metas))))
return buckets_with_receiver


def _assign_receiver_ranks(
buckets: list[tuple[int, H2DBucket]], ranks: list[int]
) -> list[tuple[int, int, H2DBucket]]:
"""
(owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)

Assign receiver ranks to buckets. If ranks is empty, assign the owner_rank as receiver_rank.
GPU-NIC topology will be considered to make full use of the bandwidth in the future.
Now, if owner_rank is not in ranks, assign the bucket to the first rank in ranks.
Assign owner_rank as receiver_rank if ranks is empty or contains only the owner_rank, ignoring the topology.
"""
buckets_with_receiver: list[tuple[int, int, H2DBucket]] = []
# TODO: this is a simple implementation, we simply assign the bucket to the first rank in ranks
# which may cause imbalance if ranks is not balanced. We can improve this by detecting topology.
for i, (owner_rank, bucket) in enumerate(buckets):
if ranks and owner_rank not in ranks:
new_buckets.append((ranks[0], owner_rank, bucket))
logger.warning(f"[rank{dist.get_rank()}] bucket {i} owner_rank {owner_rank} not in ranks {ranks}, assign to {new_buckets[-1][0]}")
buckets_with_receiver.append((ranks[0], owner_rank, bucket))
logger.warning(
f"[rank{dist.get_rank()}] bucket {i} owner_rank {owner_rank} not in ranks {ranks}, assign to {buckets_with_receiver[-1][0]}"
)
else:
new_buckets.append((owner_rank, owner_rank, bucket))

return new_buckets
buckets_with_receiver.append((owner_rank, owner_rank, bucket))
return buckets_with_receiver


def _get_master_port(master_port: int | None = None) -> int:
Expand Down Expand Up @@ -982,6 +999,17 @@ def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
[f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
)

def _get_bcast_rank_map(self, ranks: list[int]) -> dict[int, int]:
Comment thread
specture724 marked this conversation as resolved.
Outdated
# map rank to the rank which is in the same machine and has local_rank 0
bcast_rank_map = {}
if not ranks:
for r in range(self._world_size):
bcast_rank_map[r] = r
else:
for i, r in enumerate(ranks):
bcast_rank_map[r] = i
return bcast_rank_map

def _update_per_bucket(
self,
checkpoint_name: str,
Expand All @@ -991,7 +1019,7 @@ def _update_per_bucket(
logger.warning(f"[rank{self._rank}] Using _update_per_bucket, which is an experimental feature.")
assert req_func is not None
Comment thread
specture724 marked this conversation as resolved.
Outdated
# if ranks is None or [], a full broadcast is used to update all ranks
if not ranks:
if not ranks:
if len(self._current_global_parameter_metas) == 0:
raise ValueError("parameter metas is empty")

Expand Down Expand Up @@ -1030,7 +1058,9 @@ def _update_per_bucket(
# the p2p store needs h2d_buffer registered so that other ranks can read it
if ranks:
h2d_buffer_name = "__h2d_buffer__"
self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer}) if h2d_buffer is not None else None
self._p2p_store.register_named_tensors(
{h2d_buffer_name: h2d_buffer}
) if h2d_buffer is not None else None
Comment thread
specture724 marked this conversation as resolved.
Outdated

receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
for receiver_rank, owner_rank, bucket in buckets:
Expand All @@ -1043,7 +1073,7 @@ def _update_per_bucket(

buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
max_len = 0
for receiver_rank, owner_rank, bucket in buckets:
for receiver_rank, _, bucket in buckets:
buckets_by_receiver_rank[receiver_rank].append(bucket)
if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
max_len = len(buckets_by_receiver_rank[receiver_rank])
Expand All @@ -1062,7 +1092,12 @@ def _update_per_bucket(
if not ranks:
self._copy_to_buffer(checkpoint_name, receiver_rank_buckets[i][1], h2d_buffer)
else:
self._copy_to_buffer(checkpoint_name, receiver_rank_buckets[i][1], h2d_buffer, receiver_rank_buckets[i][0])
self._copy_to_buffer(
checkpoint_name,
receiver_rank_buckets[i][1],
h2d_buffer,
receiver_rank_buckets[i][0],
)

for receiver_rank, _buckets in buckets_by_receiver_rank.items():
if i >= len(_buckets):
Expand All @@ -1084,7 +1119,8 @@ def _update_per_bucket(
self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
else:
buffer_b.data.copy_(h2d_buffer[: bucket.size])
dist.broadcast(buffer_b, src=receiver_rank)
brank = self._get_bcast_rank_map(ranks)[receiver_rank]
Comment thread
specture724 marked this conversation as resolved.
Outdated
dist.broadcast(buffer_b, src=brank)
socket.recv()
dist.barrier()
socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
Expand All @@ -1096,6 +1132,9 @@ def _update_per_bucket(
req_thread.join()
dist.barrier()
socket.close()
if ranks and h2d_buffer is not None:
self._p2p_store.unregister_named_tensors([h2d_buffer_name])

torch.cuda.empty_cache()


Expand Down