diff --git a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h index 394f7fb7bfa..0978905b5e2 100644 --- a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h +++ b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h @@ -24,7 +24,6 @@ #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/modelConfig.h" -#include "tensorrt_llm/runtime/request.h" #include "tensorrt_llm/runtime/worldConfig.h" namespace tensorrt_llm::runtime @@ -88,37 +87,6 @@ class CreateNewDecoderRequests : Algorithm SizeType32 maxSequenceLength, OptionalRef medusaBuffers) const; private: - //! @brief Setups decoder internal tensors for new speculative decoding request - static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, - DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream, - CudaStream const& decoderStream, SpeculativeDecodingMode const& speculativeDecodingMode, - SizeType32 maxDecodingEngineTokens); - - //! @brief Setups decoder internal tensors for new request in Draft model Sps mode - static void newRequestDraftTokensExternal(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream); - - //! @brief Setups decoder internal tensors for new Medusa request - static void newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens); - - //! @brief Setups decoder internal tensors for new Lookahead request - static void newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream); - - //! @brief Setups decoder internal tensors for new Explicit draft tokens request - static void newRequestExplicitDraftTokens(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream); - - //! @brief Setups decoder internal tensors for new Eagle request - static void newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream); - - [[nodiscard]] std::shared_ptr retrieveDraftLogits(runtime::ModelConfig const& modelConfig, - runtime::WorldConfig const& worldConfig, std::shared_ptr const& tensor, - runtime::BufferManager const& bufferManager) const; - bool mSpeculativeDecodingFastLogits; bool mIsLeaderInOrchMode; bool mIsNormalizeLogProbs; diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index e4d13c9e17b..f069e3ac7f5 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -1110,7 +1110,7 @@ class GenericLlmRequest [[nodiscard]] SizeType32 getNumDraftTokens() const { - return mDraftTokens->size(); + return hasDraftTokens() ? 
mDraftTokens->size() : 0; } void discardDraftTokens(SizeType32 numTokensToDiscard) diff --git a/cpp/include/tensorrt_llm/runtime/decodingInput.h b/cpp/include/tensorrt_llm/runtime/decodingInput.h index deeb0fa0af4..4344f423ac1 100644 --- a/cpp/include/tensorrt_llm/runtime/decodingInput.h +++ b/cpp/include/tensorrt_llm/runtime/decodingInput.h @@ -102,11 +102,13 @@ class DecodingInput { public: TensorPtr draftLogits; + TensorPtr draftLogitsHost; TensorPtr draftProbs; TensorPtr targetProbs; TensorPtr numDraftTokens; TensorPtr numDraftTokensHost; TensorPtr draftTokenIds; + TensorPtr draftTokenIdsHost; TensorPtr useDraftLogits; TensorPtr useDraftLogitsHost; diff --git a/cpp/include/tensorrt_llm/runtime/request.h b/cpp/include/tensorrt_llm/runtime/request.h deleted file mode 100644 index e8f851b7d77..00000000000 --- a/cpp/include/tensorrt_llm/runtime/request.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "tensorrt_llm/executor/executor.h" -#include "tensorrt_llm/runtime/iTensor.h" - -#include - -namespace tensorrt_llm::runtime::decoder_batch -{ - -class Request -{ -public: - using TensorConstPtr = ITensor::SharedConstPtr; - using TensorPtr = ITensor::SharedPtr; - using BufferPtr = IBuffer::SharedPtr; - - explicit Request(SizeType32 inputLen) - : inputLen(inputLen) - { - } - - //! Mandatory parameters - SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps - - // optional parameters - SizeType32 generatedTokensPerEngineStep{1}; // - - //! 
Optional parameters for speculative decoding - BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu - std::optional draftLogits; // [generatedTokensPerEngineStep - 1, vocabSize] on gpu - TensorPtr medusaPaths; // [maxDecodingTokens, maxPathLen], on gpu - TensorPtr medusaTreeIds; // [maxDecodingTokens], on gpu - std::optional lookaheadRuntimeConfig; - std::optional eagleConfig; -}; - -} // namespace tensorrt_llm::runtime::decoder_batch diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp index 16771709bb4..3335d69a015 100644 --- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp +++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp @@ -20,11 +20,14 @@ #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/medusaBuffers.h" #include "tensorrt_llm/batch_manager/utils/logitsThread.h" +#include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/nvtxUtils.h" +#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/decoderState.h" #include "tensorrt_llm/runtime/decodingInput.h" #include "tensorrt_llm/runtime/decodingOutput.h" +#include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/speculativeDecodingMode.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" @@ -45,6 +48,8 @@ namespace tensorrt_llm::batch_manager using SizeType32 = CreateNewDecoderRequests::SizeType32; using TensorPtr = CreateNewDecoderRequests::TensorPtr; using SharedConstPtr = CreateNewDecoderRequests::SharedConstPtr; +template +using OptionalRef = tensorrt_llm::common::OptionalRef; namespace { @@ -320,149 +325,165 @@ void initializeOutputs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SizeT TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -} // namespace - -void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx, - runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig, - runtime::ModelConfig const& modelConfig, DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, - CudaStream const& runtimeStream, CudaStream const& decoderStream, - SpeculativeDecodingMode const& speculativeDecodingMode, SizeType32 maxDecodingEngineTokens) +void retrieveDraftLogits(TensorPtr& draftLogitsHost, std::shared_ptr const& reqDraftLogits, + ModelConfig const& modelConfig, WorldConfig const& worldConfig, bool speculativeDecodingFastLogits, + bool isLeaderInOrchMode, BufferManager const& bufferManager) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - if (speculativeDecodingMode.predictsDraftTokens()) + if (!speculativeDecodingFastLogits) { - auto const& stream = decoderStream; - BufferManager manager{std::make_shared(stream.get())}; + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + bufferManager.copy(*reqDraftLogits, *draftLogitsHost); + return; + } - auto& dJointOutput = jointDecodingOutput; + if (isLeaderInOrchMode) + { + // reqDraftLogits contains metadata for fast-logits path; validate size. 
+ auto constexpr fastLogitsInfoSize = sizeof(te::SpeculativeDecodingFastLogitsInfo); + TLLM_CHECK_WITH_INFO(reqDraftLogits->getSizeInBytes() >= fastLogitsInfoSize, + "Draft logits metadata buffer is too small to hold SpeculativeDecodingFastLogitsInfo."); + te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo{}; + std::memcpy(&fastLogitsInfo, reqDraftLogits->data(), fastLogitsInfoSize); + utils::targetModelReceiveLogits(draftLogitsHost, fastLogitsInfo, modelConfig.getLogitsDtype()); - TensorPtr nextDraftTokens - = ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokens, batchIdx, 1); - // FIXME: can we skip this? - manager.setZero(*nextDraftTokens); - if (speculativeDecodingMode.variableDraftLength()) + // Broadcast to other ranks if needed + if (worldConfig.isTensorParallel()) { - TensorPtr nextDraftTokensLen - = ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokensLen, batchIdx, 1); - manager.setZero(*nextDraftTokensLen); + auto const& commSession = COMM_SESSION; + auto shape = draftLogitsHost->getShape(); + commSession.bcastValue(shape.d[0], 0); + commSession.bcastValue(shape.d[1], 0); + commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0); } } - - if (speculativeDecodingMode.isDraftTokensExternal()) - { - newRequestDraftTokensExternal(batchIdx, request, samplingConfig, jointDecodingInput, decoderStream); - } - else if (speculativeDecodingMode.isMedusa()) - { - newRequestMedusa(batchIdx, request, jointDecodingInput, decoderStream, maxDecodingEngineTokens); - } - else if (speculativeDecodingMode.isLookaheadDecoding()) - { - newRequestLookahead(batchIdx, request, jointDecodingInput, jointDecodingOutput, runtimeStream); - } - else if (speculativeDecodingMode.isExplicitDraftTokens()) - { - newRequestExplicitDraftTokens(batchIdx, request, jointDecodingOutput, runtimeStream); - } - else if (speculativeDecodingMode.isEagle()) + else { - newRequestEagle(batchIdx, request, modelConfig, jointDecodingOutput, runtimeStream); + TLLM_CHECK_WITH_INFO(worldConfig.isTensorParallel(), + "Fast logits path requires tensor-parallel broadcast for non-leader ranks."); + + // Get logits from leader rank + auto const& commSession = COMM_SESSION; + int64_t dims[2]; + commSession.bcastValue(dims[0], 0); + commSession.bcastValue(dims[1], 0); + draftLogitsHost->reshape(ITensor::makeShape({dims[0], dims[1]})); + commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0); } + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); -} +}; -void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx, - runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig, - DecodingInput& jointDecodingInput, CudaStream const& decoderStream) +//! 
@brief Setups decoder internal tensors for new request in Draft model Sps mode +void newRequestDraftTokensExternal(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest const& llmReq, + SizeType32 numDecodingEngineTokens, runtime::ModelConfig const& modelConfig, WorldConfig const& worldConfig, + bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, CudaStream const& decoderStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - BufferManager manager{std::make_shared(decoderStream.get())}; + BufferManager decoderBufferManager{std::make_shared(decoderStream.get())}; - auto& dJointInput = jointDecodingInput; + TLLM_CHECK(jointDecodingInput.externalDraftTokensInputs); + auto& externalDraftTokensInputs = jointDecodingInput.externalDraftTokensInputs; - auto const numDraftTokens = request.generatedTokensPerEngineStep - 1; + auto const& draftTokens = llmReq.getDraftTokens(); + auto const numDraftTokens = numDecodingEngineTokens - 1; - auto const useDraftLogits = request.draftLogits.has_value(); - if (useDraftLogits) + auto numDraftTokensHostRange = runtime::BufferRange(*externalDraftTokensInputs->numDraftTokensHost); + numDraftTokensHostRange[batchIdx] = numDraftTokens; + auto numDraftTokensView = ITensor::slice(externalDraftTokensInputs->numDraftTokens, batchIdx, 1); + runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream); + + if (numDraftTokens > 0) { - TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value()); + TensorPtr draftTokenIdsHostSlice + = ITensor::slice(externalDraftTokensInputs->draftTokenIdsHost, {batchIdx, 0}, numDraftTokens); + // Copy to pinned host memory (don't care about stream of bufferManager) + decoderBufferManager.copy(draftTokens->data(), *draftTokenIdsHostSlice); - TensorPtr draftLogitsReqBatchSlice - = ITensor::slice(dJointInput.externalDraftTokensInputs->draftLogits, batchIdx, 1); - draftLogitsReqBatchSlice->squeeze(0); - TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens); - manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice); + TensorPtr draftTokenIdsSlice + = ITensor::slice(externalDraftTokensInputs->draftTokenIds, {batchIdx, 0}, numDraftTokens); + decoderBufferManager.copy(*draftTokenIdsHostSlice, *draftTokenIdsSlice); } - auto* useDraftLogitsHostPtr = runtime::bufferCast(*dJointInput.externalDraftTokensInputs->useDraftLogitsHost); - useDraftLogitsHostPtr[batchIdx] = useDraftLogits; - auto useDraftLogitsView = ITensor::slice(dJointInput.externalDraftTokensInputs->useDraftLogits, batchIdx, 1); + + auto const& draftLogits = llmReq.getDraftLogits(); + auto const useDraftLogits = draftLogits.has_value(); + + auto useDraftLogitsHostRange = runtime::BufferRange(*externalDraftTokensInputs->useDraftLogitsHost); + useDraftLogitsHostRange[batchIdx] = useDraftLogits; + auto useDraftLogitsView = ITensor::slice(externalDraftTokensInputs->useDraftLogits, batchIdx, 1); runtime::kernels::invokeFill(*useDraftLogitsView, useDraftLogits, decoderStream); - if (numDraftTokens > 0) + if (useDraftLogits) { - TensorPtr draftTokensReqBatchSlice - = ITensor::slice(dJointInput.externalDraftTokensInputs->draftTokenIds, batchIdx, 1); - draftTokensReqBatchSlice->squeeze(0); - TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens); - TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens})); - manager.copy(*draftTokensView, *draftTokensReqTokensSlice); + TensorPtr draftLogitsHostSlice + = 
ITensor::slice(externalDraftTokensInputs->draftLogitsHost, {batchIdx, 0}, numDraftTokens); + retrieveDraftLogits(draftLogitsHostSlice, draftLogits.value(), modelConfig, worldConfig, + speculativeDecodingFastLogits, isLeaderInOrchMode, decoderBufferManager); + + TensorPtr draftLogitsSlice + = ITensor::slice(externalDraftTokensInputs->draftLogits, {batchIdx, 0}, numDraftTokens); + decoderBufferManager.copy(*draftLogitsHostSlice, *draftLogitsSlice); } - auto* numDraftTokensHostPtr - = runtime::bufferCast(*dJointInput.externalDraftTokensInputs->numDraftTokensHost); - numDraftTokensHostPtr[batchIdx] = numDraftTokens; - auto numDraftTokensView = ITensor::slice(dJointInput.externalDraftTokensInputs->numDraftTokens, batchIdx, 1); - runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream); - + auto const& samplingConfig = llmReq.mSamplingConfig; bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value(); float const constantThreshold = useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0]; - dJointInput.externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold; - dJointInput.externalDraftTokensInputs->constantThreshold = constantThreshold; + externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold; + externalDraftTokensInputs->constantThreshold = constantThreshold; TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens) +//! @brief Setups decoder internal tensors for new Medusa request +void newRequestMedusa(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest& llmReq, + SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens, MedusaBuffers const& medusaBuffers, + CudaStream const& decoderStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + llmReq.mSamplingConfig.topKMedusaHeads = {medusaBuffers.mTopKs}; + // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest? + // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot. 
+ auto medusaPaths = ITensor::slice(medusaBuffers.medusaPathsDevice, 0, 1); + auto medusaTreeIds = ITensor::slice(medusaBuffers.medusaTreeIdsDevice, 0, 1); + BufferManager manager{std::make_shared(decoderStream.get())}; - auto& dJointInput = jointDecodingInput; + auto& medusaInputs = jointDecodingInput.medusaInputs; TensorPtr curTokensPerStepSlice - = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaCurTokensPerStep), batchIdx, 1); + = ITensor::slice(constPointerCast(medusaInputs->medusaCurTokensPerStep), batchIdx, 1); // Context phase Medusa processes 1 token only, new value from targetTokensPerStep will be filled at the end // of first decoder runtime::kernels::invokeFill(*curTokensPerStepSlice, 1, decoderStream); TensorPtr targetTokensPerStepSlice - = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTargetTokensPerStep), batchIdx, 1); - auto const generatedTokensPerEngineStep = request.generatedTokensPerEngineStep; - TLLM_CHECK_WITH_INFO(generatedTokensPerEngineStep <= maxDecodingEngineTokens, - "Tokens per step for (%d) is larger than maximum tokens per step (%d)", generatedTokensPerEngineStep, + = ITensor::slice(constPointerCast(medusaInputs->medusaTargetTokensPerStep), batchIdx, 1); + TLLM_CHECK_WITH_INFO(numDecodingEngineTokens <= maxDecodingEngineTokens, + "Tokens per step for (%d) is larger than maximum tokens per step (%d)", numDecodingEngineTokens, maxDecodingEngineTokens); - runtime::kernels::invokeFill(*targetTokensPerStepSlice, generatedTokensPerEngineStep, decoderStream); + runtime::kernels::invokeFill(*targetTokensPerStepSlice, numDecodingEngineTokens, decoderStream); - TensorPtr pathsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaPaths), batchIdx, 1); - manager.copy(*request.medusaPaths, *pathsSlice); + TensorPtr pathsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaPaths), batchIdx, 1); + manager.copy(*medusaPaths, *pathsSlice); - TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTreeIds), batchIdx, 1); - manager.copy(*request.medusaTreeIds, *treeIdsSlice); + TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaTreeIds), batchIdx, 1); + manager.copy(*medusaTreeIds, *treeIdsSlice); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream) +//! @brief Setups decoder internal tensors for new Lookahead request +void newRequestLookahead(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, + CudaStream const& runtimeStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(jointDecodingOutput.lookaheadOutputs); + TLLM_CHECK(jointDecodingInput.lookaheadInputs); // The first generation step only generate 1 token. TensorPtr curTokensPerStepSlice @@ -472,65 +493,72 @@ void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime: TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestExplicitDraftTokens(SizeType32 batchIdx, - runtime::decoder_batch::Request const& request, DecodingOutput& jointDecodingOutput, - CudaStream const& runtimeStream) +//! 
@brief Setups decoder internal tensors for new Explicit draft tokens request +void newRequestExplicitDraftTokens( + DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq, CudaStream const& runtimeStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(jointDecodingOutput.explicitDraftTokensBuffers); + auto const inputLen = llmReq.getPromptLen(); + TensorPtr positionIdsBaseSlice = ITensor::slice(jointDecodingOutput.explicitDraftTokensBuffers->positionIdsBase, batchIdx, 1); - runtime::kernels::invokeFill(*positionIdsBaseSlice, request.inputLen, runtimeStream); + runtime::kernels::invokeFill(*positionIdsBaseSlice, inputLen, runtimeStream); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream) +//! @brief Setups decoder internal tensors for new Eagle request +void newRequestEagle(DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq, + runtime::ModelConfig const& modelConfig, executor::DecodingConfig const& decodingConfig, + CudaStream const& runtimeStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(jointDecodingOutput.eagleBuffers); + auto& eagleBuffers = *jointDecodingOutput.eagleBuffers; + + auto const inputLen = llmReq.getPromptLen(); BufferManager manager{std::make_shared(runtimeStream.get())}; - TensorPtr eagleNetCtxRequestTypesHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxRequestTypesHost, batchIdx, 1); + TensorPtr eagleNetCtxRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetCtxRequestTypesHost, batchIdx, 1); TensorPtr eagleNetCtxContextLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxContextLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetCtxContextLengthsHost, batchIdx, 1); TensorPtr eagleNetCtxPastKeyValueLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1); runtime::bufferCast(*eagleNetCtxRequestTypesHostSlice)[0] = 0; - runtime::bufferCast(*eagleNetCtxContextLengthsHostSlice)[0] = request.inputLen; - runtime::bufferCast(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = request.inputLen; + runtime::bufferCast(*eagleNetCtxContextLengthsHostSlice)[0] = inputLen; + runtime::bufferCast(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = inputLen; - TensorPtr eagleNetGenRequestTypesHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenRequestTypesHost, batchIdx, 1); + TensorPtr eagleNetGenRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetGenRequestTypesHost, batchIdx, 1); TensorPtr eagleNetGenContextLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenContextLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetGenContextLengthsHost, batchIdx, 1); TensorPtr eagleNetGenPastKeyValueLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenPastKeyValueLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetGenPastKeyValueLengthsHost, batchIdx, 1); runtime::bufferCast(*eagleNetGenRequestTypesHostSlice)[0] = 1; - runtime::bufferCast(*eagleNetGenContextLengthsHostSlice)[0] = request.inputLen; - 
runtime::bufferCast(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = request.inputLen; + runtime::bufferCast(*eagleNetGenContextLengthsHostSlice)[0] = inputLen; + runtime::bufferCast(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = inputLen; auto const eagleModule = std::dynamic_pointer_cast( modelConfig.getSpeculativeDecodingModulePtr()); std::optional eagleChoicesOpt; - if (request.eagleConfig) + auto const& eagleConfig = llmReq.getEagleConfig() ? llmReq.getEagleConfig() : decodingConfig.getEagleConfig(); + + if (eagleConfig) { - eagleChoicesOpt = request.eagleConfig->getEagleChoices(); + eagleChoicesOpt = eagleConfig->getEagleChoices(); } - if (!request.eagleConfig || !request.eagleConfig->useDynamicTree()) + if (!eagleConfig || !eagleConfig->useDynamicTree()) { - TensorPtr draftPathsHostSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPathsHost, batchIdx, 1); - TensorPtr draftPathsSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPaths, batchIdx, 1); + TensorPtr draftPathsHostSlice = ITensor::slice(eagleBuffers.draftPathsHost, batchIdx, 1); + TensorPtr draftPathsSlice = ITensor::slice(eagleBuffers.draftPaths, batchIdx, 1); // eagleConfig is nullptr or Eagle-1 std::vector topKs; @@ -546,6 +574,61 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } +//! @brief Setups decoder internal tensors for new speculative decoding request +void newRequestSpeculativeDecoding(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, + SizeType32 batchIdx, LlmRequest& llmReq, SpeculativeDecodingMode const& speculativeDecodingMode, + SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens, + OptionalRef medusaBuffers, runtime::ModelConfig const& modelConfig, + WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig, bool speculativeDecodingFastLogits, + bool isLeaderInOrchMode, CudaStream const& runtimeStream, CudaStream const& decoderStream) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + if (speculativeDecodingMode.predictsDraftTokens()) + { + BufferManager manager{std::make_shared(decoderStream.get())}; + + TLLM_CHECK(jointDecodingOutput.speculativeDecodingOutputs); + auto& speculativeDecodingOutputs = *jointDecodingOutput.speculativeDecodingOutputs; + + TensorPtr nextDraftTokens = ITensor::slice(speculativeDecodingOutputs.nextDraftTokens, batchIdx, 1); + // FIXME: can we skip this? 
+ manager.setZero(*nextDraftTokens); + if (speculativeDecodingMode.variableDraftLength()) + { + TensorPtr nextDraftTokensLen = ITensor::slice(speculativeDecodingOutputs.nextDraftTokensLen, batchIdx, 1); + manager.setZero(*nextDraftTokensLen); + } + } + + if (speculativeDecodingMode.isDraftTokensExternal()) + { + newRequestDraftTokensExternal(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, modelConfig, + worldConfig, speculativeDecodingFastLogits, isLeaderInOrchMode, decoderStream); + } + else if (speculativeDecodingMode.isMedusa()) + { + TLLM_CHECK(medusaBuffers); + newRequestMedusa(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, maxDecodingEngineTokens, + medusaBuffers.value(), decoderStream); + } + else if (speculativeDecodingMode.isLookaheadDecoding()) + { + newRequestLookahead(jointDecodingInput, jointDecodingOutput, batchIdx, runtimeStream); + } + else if (speculativeDecodingMode.isExplicitDraftTokens()) + { + newRequestExplicitDraftTokens(jointDecodingOutput, batchIdx, llmReq, runtimeStream); + } + else if (speculativeDecodingMode.isEagle()) + { + newRequestEagle(jointDecodingOutput, batchIdx, llmReq, modelConfig, decodingConfig, runtimeStream); + } + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +} // namespace + std::tuple, std::vector> CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds, executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState, @@ -563,9 +646,6 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon } inputIds->resize(decoderInputSize); - std::vector decoderRequests; - decoderRequests.reserve(finishedContextRequests.size()); - std::vector lookaheadPrompt; std::vector lookaheadAlgoConfigs; if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) @@ -597,36 +677,18 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon auto const promptLen = llmReq->getPromptLen(); - auto decoderRequest = decoder_batch::Request{promptLen}; - + SizeType32 numDecodingEngineTokens{1}; if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal()) { - if (llmReq->hasDraftTokens()) - { - auto const& draftTokens = llmReq->getDraftTokens(); - // Copy to pinned host memory (don't care about stream of bufferManager) - decoderRequest.draftTokens = decoderBufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL); - auto const& draftLogits = llmReq->getDraftLogits(); - if (draftLogits.has_value()) - { - decoderRequest.draftLogits - = retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), decoderBufferManager); - } - decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1; - } - else - { - decoderRequest.generatedTokensPerEngineStep = 1; - } + numDecodingEngineTokens = llmReq->getNumDraftTokens() + 1; } else if (!modelConfig.getSpeculativeDecodingMode().isNone()) { - decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens(); + numDecodingEngineTokens = modelConfig.getMaxDecodingTokens(); } auto& dJointInput = decoderState.getJointDecodingInput(); - auto const numDecodingEngineTokens = decoderRequest.generatedTokensPerEngineStep; initializeInputLengths(dJointInput, batchSlot, promptLen, llmReq->mMaxNewTokens, numDecodingEngineTokens, maxSequenceLength, decoderBufferManager); decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens); @@ -667,16 +729,7 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector 
const& finishedCon { TLLM_CHECK(beamWidth == 1); - if (modelConfig.getSpeculativeDecodingMode().isMedusa()) - { - TLLM_CHECK(medusaBuffers); - llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs}; - // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest? - // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot. - decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1); - decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1); - } - else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) + if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) { lookaheadPrompt.emplace_back(requestIds); @@ -684,67 +737,17 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon = llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value()); lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig); } - else if (modelConfig.getSpeculativeDecodingMode().isEagle()) - { - decoderRequest.eagleConfig - = llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig(); - } - newRequestSpeculativeDecoding(batchSlot, decoderRequest, samplingConfig, modelConfig, - decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream, - decoderStream, decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens()); + newRequestSpeculativeDecoding(decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), + batchSlot, *llmReq, decoderState.getSpeculativeDecodingMode(), numDecodingEngineTokens, + decoderState.getMaxDecodingEngineTokens(), medusaBuffers, modelConfig, worldConfig, decodingConfig, + mSpeculativeDecodingFastLogits, mIsLeaderInOrchMode, runtimeStream, decoderStream); } - decoderRequests.push_back(decoderRequest); - inputOffset += promptLen; } return {std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)}; } -std::shared_ptr CreateNewDecoderRequests::retrieveDraftLogits(ModelConfig const& modelConfig, - WorldConfig const& worldConfig, std::shared_ptr const& tensor, - BufferManager const& bufferManager) const -{ - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - - if (!mSpeculativeDecodingFastLogits) - { - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return bufferManager.copyFrom(*tensor, MemoryType::kPINNEDPOOL); - } - - if (mIsLeaderInOrchMode) - { - te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo; - std::memcpy(&fastLogitsInfo, tensor->data(), sizeof(fastLogitsInfo)); - auto logits = utils::targetModelReceiveLogits(fastLogitsInfo, modelConfig).value(); - - // Broadcast to other ranks if needed - if (worldConfig.isTensorParallel()) - { - auto const& commSession = COMM_SESSION; - auto shape = logits->getShape(); - commSession.bcastValue(shape.d[0], 0); - commSession.bcastValue(shape.d[1], 0); - commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0); - } - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return logits; - } - - // Get logits from leader rank - auto const& commSession = COMM_SESSION; - int64_t dims[2]; - commSession.bcastValue(dims[0], 0); - commSession.bcastValue(dims[1], 0); - auto const logitsDtype = modelConfig.getLogitsDtype(); - auto logits = tensorrt_llm::runtime::BufferManager::pinnedPool(ITensor::makeShape({dims[0], dims[1]}), logitsDtype); - commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 
0); - - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return logits; -}; - } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp index 484cd7c3c7b..7234ca9ba57 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp +++ b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp @@ -121,8 +121,8 @@ void draftModelSendLogitsThread(int device, std::atomic* draftModelThreadS #endif // ENABLE_MULTI_DEVICE } -std::optional targetModelReceiveLogits( - executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig) +void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost, + executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype) { #if ENABLE_MULTI_DEVICE auto const& worldComm = tensorrt_llm::mpi::MpiComm::world(); @@ -151,10 +151,7 @@ std::optional targetModelReceiveLogits( int64_t dims[2]; MPICHECK(MPI_Mrecv(&dims, count, MPI_INT64_T, &msg, &status)); - auto const logitsDtype = modelConfig.getLogitsDtype(); - - auto tensor = tensorrt_llm::runtime::BufferManager::pinnedPool( - runtime::ITensor::makeShape({dims[0], dims[1]}), logitsDtype); + draftLogitsHost->reshape(runtime::ITensor::makeShape({dims[0], dims[1]})); worldComm.mprobe(fastLogitsInfo.draftParticipantId, mpi::MpiTag::kSpecDecLogitsData, &msg, &status); @@ -163,11 +160,7 @@ std::optional targetModelReceiveLogits( uint64_t const expectedSize = static_cast(dims[0]) * dims[1] * tc::getDTypeSize(logitsDtype); TLLM_CHECK((uint64_t) count == expectedSize); - MPICHECK(MPI_Mrecv(tensor->data(), count, MPI_UINT8_T, &msg, &status)); - - return tensor; -#else - return std::nullopt; + MPICHECK(MPI_Mrecv(draftLogitsHost->data(), count, MPI_UINT8_T, &msg, &status)); #endif // ENABLE_MULTI_DEVICE } diff --git a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h index 6d87ebee162..f19d5f5ef30 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h +++ b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h @@ -21,10 +21,8 @@ #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/iTensor.h" -#include "tensorrt_llm/runtime/modelConfig.h" #include -#include namespace tensorrt_llm::batch_manager { @@ -52,7 +50,7 @@ void draftModelSendLogitsThread(int device, std::atomic* draftModelThreadS std::shared_ptr const& crossKvCacheManager, std::shared_ptr const& peftCacheManager); -std::optional targetModelReceiveLogits( - executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig); +void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost, + executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype); } // namespace tensorrt_llm::batch_manager::utils diff --git a/cpp/tensorrt_llm/runtime/decoderState.cpp b/cpp/tensorrt_llm/runtime/decoderState.cpp index abccbe60a13..b5851dc1c2d 100644 --- a/cpp/tensorrt_llm/runtime/decoderState.cpp +++ b/cpp/tensorrt_llm/runtime/decoderState.cpp @@ -131,6 +131,7 @@ void DecoderState::setupSpeculativeDecodingBuffers( mSpeculativeDecodingMode = speculativeDecodingMode; + auto constexpr nvTokenIdType = TRTDataType::value; auto constexpr nvSizeType = TRTDataType::value; auto& dInput = mJointDecodingInput; @@ -179,6 +180,7 @@ void DecoderState::setupSpeculativeDecodingBuffers( 
DecodingInput::ExternalDraftTokensInputs externalDraftTokensInputs; externalDraftTokensInputs.draftLogits = bufferManager.emptyTensor(MemoryType::kGPU, dtype); + externalDraftTokensInputs.draftLogitsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, dtype); externalDraftTokensInputs.draftProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype); externalDraftTokensInputs.targetProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype); externalDraftTokensInputs.numDraftTokens = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); @@ -187,8 +189,8 @@ void DecoderState::setupSpeculativeDecodingBuffers( = bufferManager.emptyTensor(MemoryType::kGPU, TRTDataType::value); externalDraftTokensInputs.useDraftLogitsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType::value); - externalDraftTokensInputs.draftTokenIds - = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32); + externalDraftTokensInputs.draftTokenIds = bufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType); + externalDraftTokensInputs.draftTokenIdsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvTokenIdType); dInput->externalDraftTokensInputs = externalDraftTokensInputs; } @@ -366,10 +368,16 @@ void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode con {mMaxNumSequences, mMaxDecodingEngineTokens, mMaxBeamWidth, static_cast(vocabSizePadded)}); dInput.externalDraftTokensInputs->draftProbs->reshape(probsShape); dInput.externalDraftTokensInputs->targetProbs->reshape(probsShape); - dInput.externalDraftTokensInputs->draftLogits->reshape( - ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens, static_cast(vocabSizePadded)})); - dInput.externalDraftTokensInputs->draftTokenIds->reshape( - ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens})); + + auto const logitsShape = ITensor::makeShape( + {mMaxNumSequences, mMaxDecodingEngineTokens, static_cast(vocabSizePadded)}); + dInput.externalDraftTokensInputs->draftLogits->reshape(logitsShape); + dInput.externalDraftTokensInputs->draftLogitsHost->reshape(logitsShape); + + auto const tokenIdsShape = ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens}); + dInput.externalDraftTokensInputs->draftTokenIds->reshape(tokenIdsShape); + dInput.externalDraftTokensInputs->draftTokenIdsHost->reshape(tokenIdsShape); + dInput.externalDraftTokensInputs->numDraftTokens->reshape(maxNumSequencesShape); dInput.externalDraftTokensInputs->numDraftTokensHost->reshape(maxNumSequencesShape); dInput.externalDraftTokensInputs->useDraftLogits->reshape(maxNumSequencesShape);
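
Note (not part of the patch): the refactor replaces the old retrieveDraftLogits path, which allocated a fresh pinned-pool buffer per request, with persistent pinned host mirrors (draftTokenIdsHost, draftLogitsHost) that are sliced per batch slot and then copied to the device on the decoder stream. Below is a minimal illustrative sketch of that staging pattern using only calls that appear in the patch; the helper name and its parameters (in particular passing ExternalDraftTokensInputs by reference and the request tokens as an IBuffer) are assumptions for the sake of the example, not part of the change.

    // Sketch only: stage external draft token ids for one batch slot.
    // Assumes `inputs` is the ExternalDraftTokensInputs instance set up in decoderState.cpp,
    // with draftTokenIdsHost in pinned host memory and draftTokenIds resident on the GPU.
    void stageDraftTokenIds(DecodingInput::ExternalDraftTokensInputs& inputs, SizeType32 batchIdx,
        runtime::IBuffer const& reqDraftTokens, SizeType32 numDraftTokens,
        runtime::BufferManager const& decoderBufferManager)
    {
        // Copy the request's draft tokens into the pinned host slice for this slot ...
        TensorPtr hostSlice = ITensor::slice(inputs.draftTokenIdsHost, {batchIdx, 0}, numDraftTokens);
        decoderBufferManager.copy(reqDraftTokens, *hostSlice);
        // ... then enqueue the host-to-device copy on the decoder stream.
        TensorPtr deviceSlice = ITensor::slice(inputs.draftTokenIds, {batchIdx, 0}, numDraftTokens);
        decoderBufferManager.copy(*hostSlice, *deviceSlice);
    }

Staging through pinned host memory keeps the MPI/fast-logits reception (targetModelReceiveLogits now fills a caller-provided host tensor) off the GPU path and lets the device copies be issued asynchronously on the decoder stream, which appears to be the motivation for adding the *Host mirrors to DecodingInput.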