From e020f21b03a8284528264b820b76b975afaa9426 Mon Sep 17 00:00:00 2001 From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> Date: Thu, 7 Aug 2025 11:48:14 +0000 Subject: [PATCH] revert kvcache transfer Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com> --- .../batch_manager/cacheFormatter.cpp | 7 +++--- .../batch_manager/mlaCacheFormatter.cpp | 22 ++++++------------- .../batch_manager/cacheTransceiverTest.cpp | 8 +++---- 3 files changed, 13 insertions(+), 24 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index d95ca1b412b..2edfd5f77a3 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -75,6 +75,7 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques bool CacheFormatter::needSendCache( CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx) { + // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); if (targetInfo.mDupHeadFactor <= 1) { @@ -89,9 +90,8 @@ bool CacheFormatter::needSendCache( = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize; selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup; } - int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0; - return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor); + return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0; } void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig, @@ -128,12 +128,11 @@ std::vector CacheFormatter::pickRecvConnections( return ret; } TLLM_CHECK(numConnections == targetInfo.mIRanks.size()); - int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0; std::vector ret; for (int i = 0; i < targetInfo.mDomainTPSize; i++) { - if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor)) + if (i % targetInfo.mPeerDupHeadFactor == 0) { for (int j = 0; j < targetInfo.mDomainPPSize; j++) { diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 824a31129f8..810edd6f451 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -45,12 +45,10 @@ std::vector MLACacheFormatter::pickRecvConnections( auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); TLLM_CHECK(numConnections == targetInfo.mIRanks.size()); std::vector ret; - // targetInfo , mRanks [tpranks, ppranks] - int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0; - + // targetInfo , mRanks [tpranks, dpranks] for (int i = 0; i < targetInfo.mDomainPPSize; i++) { - ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize); + ret.push_back(i); } return ret; } @@ -60,24 +58,19 @@ bool MLACacheFormatter::needSendCache( { int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; - int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP - ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize - : destConfig.getParallelConfig().mTensorParallelism; - int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0; - if (selfConfig.getParallelConfig().mEnableAttentionDP) { int selfTPNumInDPGroup = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize; - + int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP + ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize + : destConfig.getParallelConfig().mTensorParallelism; int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup; if (selfTPNumInDPGroup <= destTPNumInDPGroup) { return true; } - - int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup; - return selfTPrankINDPGroup % dupHeadFactor == destDPRank; + return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0; } int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP @@ -88,8 +81,7 @@ bool MLACacheFormatter::needSendCache( { return true; } - int dupHeadFactor = selfTPNum / destTPNum; - return selfTpRank % dupHeadFactor == destDPRank; + return selfTpRank % (selfTPNum / destTPNum) == 0; } void MLACacheFormatter::format(TransferSession& session) diff --git a/cpp/tests/batch_manager/cacheTransceiverTest.cpp b/cpp/tests/batch_manager/cacheTransceiverTest.cpp index af916359d0d..99c40f810f6 100644 --- a/cpp/tests/batch_manager/cacheTransceiverTest.cpp +++ b/cpp/tests/batch_manager/cacheTransceiverTest.cpp @@ -1457,15 +1457,12 @@ TEST(targetTest, CacheStateNODP) verifyContext( /*contextRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true); - verifyContext( /*contextRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false); - verifyContext( /*contextRank*/ 2, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true); verifyContext( /*contextRank*/ 3, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false); - verifyContext( /*contextRank*/ 4, /*expectRanks*/ {2}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true); verifyContext( @@ -1477,6 +1474,7 @@ TEST(targetTest, CacheStateNODP) contextTP = 2; genTP = 4; + verifyContext( /*contextRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectNeedSend*/ true); verifyContext(/*contextRank*/ 1, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, @@ -1566,13 +1564,13 @@ TEST(targetTest, CacheStateContextDP) /*expectNeedSend*/ true); verifyContext( /*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, - /*expectNeedSend*/ false); + /*expectNeedSend*/ true); verifyContext( /*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false); verifyContext( /*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, - /*expectNeedSend*/ true); + /*expectNeedSend*/ false); verifyContext( /*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);