Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
eb3f5db
feat: merge the logic of _update_per_bucket_p2p into _update_per_bucket
abcdea Sep 23, 2025
2898e16
fix: refine rank group handling in update
abcdea Sep 28, 2025
0065a87
feat: GPU-RDMA-device-topology-unaware receiver assignment implemented
abcdea Sep 28, 2025
4ada8b6
feat: GPU-RDMA-device-topology-aware receiver assignment implemented
abcdea Sep 30, 2025
c840c34
style: ruff formatting
abcdea Oct 9, 2025
59b8b38
fix: resolve PR comments
specture724 Oct 14, 2025
745e4c2
Merge branch 'main' into feat/optimize_p2p
specture724 Oct 14, 2025
e2e98d5
fix: assert p2p_store_addr is not None in loading checkpoint
specture724 Oct 14, 2025
d8dc4be
misc: logging removed
specture724 Oct 14, 2025
17271e5
Merge branch 'main' into feat/optimize_p2p
specture724 Oct 14, 2025
c6ec7dd
fix: return logic in update
specture724 Oct 14, 2025
5916eb9
Merge branch 'feat/optimize_p2p' of github.com:specture724/checkpoint…
specture724 Oct 14, 2025
bf73f42
fix: resolve pr comment issues
specture724 Oct 14, 2025
e79ef26
misc: format commit message
specture724 Oct 14, 2025
e126a89
feat: test_assign_receiver_ranks.py: add unit tests for _assign_recei…
specture724 Oct 14, 2025
8440ffa
refactor: refactor the test
specture724 Oct 17, 2025
5fd38b1
misc: resolve pr issues
specture724 Oct 17, 2025
0d74d7f
fix: handle corner case when senders' buckets are distributed unevenly
specture724 Oct 20, 2025
f06b0bf
fix: debug test and add more cases
specture724 Oct 20, 2025
32e687d
Merge branch 'MoonshotAI:main' into feat/optimize_p2p
specture724 Oct 20, 2025
12455f5
misc: fix pr issues
specture724 Oct 21, 2025
e4c253d
doc: fix benchmark results in README
specture724 Oct 23, 2025
4bd5da3
fix: gather meta to generate local topo fixed
specture724 Oct 25, 2025
86437ad
misc: fix pr issues
specture724 Oct 25, 2025
415c320
doc
specture724 Oct 28, 2025
ed6c8a0
docs: add numa binding
weixiao-huang Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: resolve PR comments
  • Loading branch information
specture724 committed Oct 14, 2025
commit 59b8b38a925edf0ce93c5d3e1f079d937d3f08d5
38 changes: 23 additions & 15 deletions checkpoint_engine/ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ class MemoryBuffer(BaseModel):
class MemoryBufferMetaList(BaseModel):
p2p_store_addr: str | None
memory_buffer_metas_list: list[MemoryBufferMetas]
rdma_device: str


class DataToGather(MemoryBufferMetaList):
host_ip: str
device_uuid: str
rdma_device: str


# 256 bytes alignment when flatten torch tensors to uint8 buffer
Expand Down Expand Up @@ -527,25 +527,24 @@ def _gen_h2d_buckets(
if ranks
else local_topo
)
buckets_with_receiver = _assign_receiver_ranks(buckets, actual_local_topo, remote_topo, ranks)
return buckets_with_receiver
# If ranks is empty, assign the owner_rank as receiver_rank; this is used for the colocate architecture.
if not ranks:
return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
else:
return _assign_receiver_ranks(buckets, actual_local_topo, remote_topo)


def _assign_receiver_ranks(
buckets: list[tuple[int, H2DBucket]],
local_topo: dict[str, set[int]],
remote_topo: dict[str, set[int]],
ranks: list[int] | None = None,
) -> list[tuple[int, int, H2DBucket]]:
"""
(owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)

Assign receiver ranks to buckets. If ranks is empty, assign the owner_rank as receiver_rank.
GPU-rdma_device topology will be considered to make full use of the bandwidth.
"""
# if ranks is empty, assign the owner_rank as receiver_rank, this is used for colocate architecture
if not ranks:
return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
rank_to_rdma_device = {
rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
}
Expand Down Expand Up @@ -573,7 +572,7 @@ def _assign_receiver_ranks(
for i, (owner_rank, bucket) in enumerate(flattened_buckets):
receiver_rank = receiver_list[i % len(receiver_list)]
buckets_with_receiver.append((receiver_rank, owner_rank, bucket))
logger.debug(
logger.info(
f"Assigned bucket with owner_rank {owner_rank} to receiver_rank {receiver_rank}"
)

Expand Down Expand Up @@ -692,7 +691,7 @@ def __init__(
device_index = self._local_rank
torch.cuda.set_device(device_index)
self._device_uuid = _get_physical_gpu_id(device_index)
self.rdma_device = None if self._p2p_store is None else self._p2p_store.device
self._rdma_device = None if self._p2p_store is None else self._p2p_store.device

def _logger_rank0(self, msg: str):
if self._local_rank == 0:
Expand All @@ -703,10 +702,15 @@ def get_metas(self) -> dict[int, MemoryBufferMetaList]:

def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
self._current_global_parameter_metas = metas
for i, meta in self._current_global_parameter_metas.items():
self._remote_rdma_devices[
meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
].add(i)
self._remote_rdma_devices = defaultdict(set)
try:
for i, meta in self._current_global_parameter_metas.items():
self._remote_rdma_devices[
meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
].add(i)
except AttributeError as e:
Comment thread
specture724 marked this conversation as resolved.
Outdated
self._remote_rdma_devices = self._local_rdma_devices.copy()
logger.warning(f"[rank{self._rank}] encountered {e}, use local rdma devices as remote")

def register_checkpoint(
self,
Expand Down Expand Up @@ -780,7 +784,7 @@ def gather_metas(self, checkpoint_name: str):
p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
host_ip=_get_ip(),
device_uuid=self._device_uuid,
rdma_device=self.rdma_device or "",
rdma_device=self._rdma_device or "",
)

dist.all_gather_object(metas_lst, metas)
Expand All @@ -795,7 +799,11 @@ def gather_metas(self, checkpoint_name: str):
if not self._global_device_uuids:
global_device_uuids.append(metas_buckets.device_uuid)
if metas_buckets.memory_buffer_metas_list:
self._current_global_parameter_metas[i] = metas_buckets
self._current_global_parameter_metas[i] = MemoryBufferMetaList(
memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list,
p2p_store_addr=metas_buckets.p2p_store_addr,
rdma_device=metas_buckets.rdma_device,
)
num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
self._local_rdma_devices[metas_buckets.rdma_device + "@" + metas.host_ip].add(i)
if not self._all_hosts:
Expand Down