Changes from 1 commit
Commits
55 commits
5fa1914
[None][chore] Bump version to 1.1.0rc0 (#6651)
yiqingy0 Aug 7, 2025
85af621
[TRTLLM-6683][feat] Support LoRA reload CPU cache evicted adapter (#6…
amitz-nv Aug 7, 2025
6c1f7d8
[None][test] correct test-db context for perf yaml file (#6686)
ruodil Aug 7, 2025
8207d5f
[None] [feat] Add model gpt-oss (#6645)
hlu1 Aug 7, 2025
0a467b0
[https://nvbugs/5409414][fix] fix Not registered specs (#6660)
xinhe-nv Aug 7, 2025
8ec3b1d
[None][feat] : Add FP8 context MLA support for SM120 (#6059)
peaceh-nv Aug 7, 2025
c23e8e7
[TRTLLM-6092][doc] Add LoRA feature usage doc (#6603)
shaharmor98 Aug 7, 2025
1b9781e
[TRTLLM-6409][feat] Enable guided decoding with speculative decoding …
syuoni Aug 7, 2025
453a06e
[TRTLLM-6881][feat] Include attention dp rank info with KV cache even…
pcastonguay Aug 7, 2025
3c44b44
[None][infra] Fix guardwords (#6711)
EmmaQiaoCh Aug 7, 2025
46357e7
[None][package] Pin cuda-python version to >=12,<13 (#6702)
yiqingy0 Aug 7, 2025
0223de0
[None][doc] Add deployment guide section for VDR task (#6669)
nv-guomingz Aug 7, 2025
4055b76
[None][fix] disagg ctx pp4 + gen pp4 integ test (#6489)
raayandhar Aug 7, 2025
e968f98
[None][feat] Clean up ngram auto mode, add max_concurrency to configs…
mikeiovine Aug 7, 2025
3b2dd40
[None][chore] Remove py_executor from disagg gh team (#6716)
pcastonguay Aug 7, 2025
4ecda91
[https://nvbugs/5423962][fix] Address broken links (#6531)
chenopis Aug 7, 2025
db8dc97
[None][fix] Migrate to new cuda binding package name (#6700)
tongyuantongyu Aug 7, 2025
980929e
[https://nvbugs/5410687][fix] Hopper w4a8 groupwise MoE interleave (#…
symphonylyh Aug 7, 2025
8227616
[None][feat] Add NCCL Symmetric Integration for All Reduce (#4500)
Tabrizian Aug 8, 2025
efca359
[TRTLLM-6785][feat] BREAKING CHANGE Enable TRTLLM sampler by default …
dcampora Aug 8, 2025
88ced50
[TRTQA-2920][fix] Add failed cases into waives.txt (#6719)
xinhe-nv Aug 8, 2025
22f45a0
[TRTLLM-5252][test] add for mistral_small_3.1_24b perf test (#6685)
ruodil Aug 8, 2025
2f2f5cc
[TRTLLM-6744][feat] Remove input_sf swizzle for module WideEPMoE (#6231)
StudyingShao Aug 8, 2025
1cf6694
[None][fix] Fix unnecessary GPU synchronization in torch sampler caus…
zhanghaotong Aug 8, 2025
aee828d
[TRTLLM-6854][feat] Enable guided decoding with disagg serving (#6704)
syuoni Aug 8, 2025
064eb7a
[TRTLLM-5252][fix] Propagate mapping to intermediate layers (#6611)
2ez4bz Aug 8, 2025
b15d6fb
[None][test] fix yml condition error under qa folder (#6734)
ruodil Aug 8, 2025
9687bb4
[None][doc] Add doc for multimodal feature support matrix (#6619)
chang-l Aug 8, 2025
d913955
[TRTLLM-6898][feat] make fused_moe_cute_dsl work on blackwell (#6616)
limin2021 Aug 8, 2025
294e0d3
[https://nvbugs/5436461][infra] Adjust free_gpu_memory_fraction of te…
leslie-fang25 Aug 8, 2025
9ff4e75
[None][refactor] Combine resmooth_to_fp8_e8m0 and transform_sf_into_r…
yuxianq Aug 8, 2025
5f45227
[https://nvbugs/5437106][fix] Fix llama4 scout TRTLLM attn_backend (#…
JunyiXu-nv Aug 8, 2025
32ad7f3
[None][fix] Remove lock related typo in py_executor (#6653)
lancelly Aug 8, 2025
ebdc43e
[None][feat] move kv cache measure into transfer session (#6633)
zhengd-nv Aug 8, 2025
e251f7c
[None][fix]revert kvcache transfer (#6709)
chuangz0 Aug 8, 2025
b8f036f
[TRTLLM-6650][fix] Enhance CUDA graph + Beam search to correctly hand…
stnie Aug 8, 2025
d45236b
[TRTLLM-6308][feat] Support Aggregate mode for phi4-mm (#6184)
Wanli-Jiang Aug 8, 2025
90145cf
[None][feat] Optimize CUDA graph memory usage for spec decode cases (…
mikeiovine Aug 8, 2025
efcb8f7
[TRTLLM-7025] [infra] Reorganize CODEOWNERS to rectify `examples` map…
venkywonka Aug 8, 2025
cc0f4c8
[None][doc] Move AutoDeploy README.md to torch docs (#6528)
Fridah-nv Aug 8, 2025
d066750
[None][fix] WAR GPT OSS on H20 with Triton MOE (#6721)
dongfengy Aug 8, 2025
9778788
[TRTLLM-6420][feat] add support for Eclairv2 model - cherry-pick chan…
yibinl-nvidia Aug 9, 2025
bcf5ec0
[None][feat] Core Metrics Implementation (#5785)
hcyezhang Aug 9, 2025
d643aef
[Perf] Improve Llama4 performance for small max_seqlen cases (#6306)
nv-yilinf Aug 9, 2025
de47282
[TRTLLM-6637][feat] Resolve KV cache divergence issue (#6628)
ziyixiong-nv Aug 9, 2025
ee19ca5
[None][infra] Waive test main 0808 (#6751)
EmmaQiaoCh Aug 10, 2025
3c5aec1
[#5048][enhance] AutoDeploy: Optimize prepare_inputs (#6634)
galagam Aug 10, 2025
199f306
[None][chore][kv cache manager] Dead code elimination, we no longer r…
eopXD Aug 10, 2025
14b36e0
[TRTLLM-6174][feat] Enable FP32 mamba ssm cache (#6574)
shaharmor98 Aug 10, 2025
4142320
[https://nvbugs/5444937][fix] Fixing kv_cache_event unit test (#6753)
pcastonguay Aug 10, 2025
b6baa9e
[TRTLLM-6823][doc] Add checkpoint refactor docs (#6592)
shaharmor98 Aug 10, 2025
60073a7
[None][feat] Support SharedTensor on MultimodalParams (#6254)
yechank-nvidia Aug 11, 2025
4b4b91a
[None][feat] improve dataloading for benchmark_dataset by using batch…
zerollzeng Aug 11, 2025
767879e
[https://nvbugs/5431127][fix] Run test_disaggregated_deepseek_v3_lite…
bo-nv Aug 11, 2025
2cf31b5
relax tensor device type check to fix wideEP loading and fix argument
dongxuy04 Aug 11, 2025
[None][feat] : Add FP8 context MLA support for SM120 (#6059)
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
peaceh-nv authored Aug 7, 2025
commit 8ec3b1de105aab755ce0eb930b435f7aa9bc9029
73 changes: 62 additions & 11 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -278,7 +278,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr;
xqaParams.total_num_input_tokens = mCpSize > 1 ? generationsParams.num_requests : generationsParams.num_tokens;
xqaParams.is_fp8_output = mFP8ContextFMHA;
xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr);
xqaParams.fp8_out_scale
= ((mFP8ContextFMHA || mFP8ContextMLA) ? generationsParams.attention_output_orig_quant : nullptr);
// Parameters required for FP4 output.
xqaParams.output_sf = generationsParams.context_buf_sf;
xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
@@ -736,10 +737,29 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo;
size_t const qk_buf_float_size
= mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
size_t const fp8_qkv_buffer_size
= mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
int const dim_v_per_head = (mMLAParams.v_head_dim);

// Total dimension per token across all heads for Q, K, and V components respectively
int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
int const total_k_dim_all_heads
= mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
int const total_v_dim_all_heads
= mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout

int const num_total_qkv_elements
= max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);

size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
: 0;
if (mFP8ContextMLA)
{
fp8_qkv_buffer_size
= mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
}

size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
// Each token holds (batch_idx, token_idx_in_seq) int2.
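
To make the workspace arithmetic above concrete, here is a minimal standalone sketch of the FP8 QKV buffer sizing used on the MLA context path. The head dimensions and token count are assumptions chosen for illustration (DeepSeek-V2-style values), not values taken from this commit:

// Sketch only, not part of the commit: reproduces the fp8_qkv_buffer_size computation
// for the FP8 context MLA path under assumed MLA head dimensions.
#include <cstddef>
#include <cstdio>

int main()
{
    // Assumed values for illustration.
    int const qk_nope_head_dim = 128;
    int const qk_rope_head_dim = 64;
    int const v_head_dim = 128;
    int const num_attn_heads = 128;
    size_t const max_num_tokens = 8192;

    int const dim_q_per_head = qk_nope_head_dim + qk_rope_head_dim; // 192
    int const dim_k_per_head = qk_nope_head_dim + qk_rope_head_dim; // 192
    int const dim_v_per_head = v_head_dim;                          // 128

    // One FP8 (1-byte) element per dimension, per head, per token;
    // effective num_kv_heads is assumed equal to head_num, as in the code above.
    size_t const elems_per_token
        = static_cast<size_t>(num_attn_heads) * (dim_q_per_head + dim_k_per_head + dim_v_per_head); // 65536
    size_t const fp8_qkv_buffer_size = max_num_tokens * elems_per_token; // bytes, since FP8 is 1 byte each

    std::printf("elements per token: %zu, fp8_qkv_buffer_size: %zu bytes\n", elems_per_token, fp8_qkv_buffer_size);
    return 0;
}
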
@@ -1349,19 +1369,35 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
size_t const qk_buf_float_size = mEnableContextFMHA
? 0
: sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length;
size_t const fp8_qkv_buffer_size
= mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
int const dim_v_per_head = (mMLAParams.v_head_dim);

// Total dimension per token across all heads for Q, K, and V components respectively
int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
int const total_k_dim_all_heads
= mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
int const total_v_dim_all_heads
= mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
int const num_total_qkv_elements
= params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
: 0;
if (mFP8ContextMLA)
{
fp8_qkv_buffer_size
= mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
}
size_t const padding_offset_size
= mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
size_t const encoder_padding_offset_size
= mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.cross_kv_length;
// Each token holds (batch_idx, token_idx_in_seq) int2.
size_t const tokens_info_size = sizeof(int2) * params.num_tokens;
size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0;
size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0;

// cp workspace size upper bound
size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1);
@@ -1608,6 +1644,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
params.mla_param->cache_type = cache_type;
params.mla_param->cu_q_seqlens = cu_q_seqlens;
params.mla_param->quant_scale_kv = params.kv_scale_orig_quant;
// Set BMM scales for FP8 context computation
params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr;
params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr;
params.mla_param->host_bmm1_scale = decoder_params.fmhaHostBmm1Scale;
params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr;
// Set additional scales for context phase
params.mla_param->quant_scale_o = params.attention_output_orig_quant;
params.mla_param->dequant_scale_q = params.kv_scale_quant_orig;
params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig;
if (mPagedContextFMHA && mPagedKVCache)
{
TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr,
@@ -1686,8 +1731,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
// TODO: set it correctly for contiguous kv buffer (cross-attention).
fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens;
// Device buffer pointers.
fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
: reinterpret_cast<void const*>(attention_input);
fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast<void const*>(fp8_qkv_buffer)
: reinterpret_cast<void const*>(attention_input);
fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
// TODO: add contiguous kv buffer (cross-attention).
fmhaParams.kvPtr = nullptr;
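
The quant/dequant scales wired into mla_param above come in reciprocal pairs (kv_scale_orig_quant to go into FP8, kv_scale_quant_orig to come back out). A small sketch of that relationship; the amax-based derivation is an assumption for illustration, not something this commit computes:

// Sketch only: reciprocal FP8 E4M3 scale pair, assuming an amax-based calibration.
struct Fp8ScalePair
{
    float quant;   // real -> FP8, e.g. 448.f / amax
    float dequant; // FP8 -> real, e.g. amax / 448.f
};

inline Fp8ScalePair makeFp8Scales(float amax)
{
    float const kFp8E4m3Max = 448.0f; // largest finite E4M3 value
    return Fp8ScalePair{kFp8E4m3Max / amax, amax / kFp8E4m3Max};
}
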
@@ -2487,7 +2532,7 @@ int AttentionOp::initialize() noexcept
}

// FP8 FMHA should be used with fp8 workflow together.
if (mFP8ContextFMHA)
if (mFP8ContextFMHA || mFP8ContextMLA)
{
data_type = DATA_TYPE_E4M3;
}
@@ -2520,6 +2565,11 @@ int AttentionOp::initialize() noexcept
fmhaParams.dataTypeOut = DATA_TYPE_BF16;
fmhaParams.dataTypeKv = DATA_TYPE_BF16;
}
if (mFP8ContextMLA && mKVCacheQuantMode.hasFp8KvCache())
{
fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
fmhaParams.dataTypeOut = DATA_TYPE_BF16;
}
// TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to
// bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in runtime.
fmhaParams.forceFp32Acc = false;
@@ -2573,7 +2623,7 @@ int AttentionOp::initialize() noexcept
// Deepseek-V2 Generation needs a different fmha with different arguments
if (mIsMLAEnabled)
{
mEnableXQA = (mSM == kSM_120);
mEnableXQA = (mSM == kSM_120) && mIsGenerationMLA;
if (mUseTllmGen)
{
Data_type qDataType = DATA_TYPE_FP32;
@@ -2836,6 +2886,7 @@ std::string AttentionOp::toString() const
ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl;
ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl;
ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl;
ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl;
ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl;
ss << "mEnableContextFMHA: " << std::boolalpha << mEnableContextFMHA << std::endl;
ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl;
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/common/attentionOp.h
@@ -388,6 +388,7 @@ class AttentionOp
bool mPosShiftEnabled = false;
bool mPagedContextFMHA = false;
bool mFP8ContextFMHA = false;
bool mFP8ContextMLA = false;
bool mFP8GenerationMLA = false;
bool mDenseContextFMHA = false;
bool mHasFullAttentionMask = false;
54 changes: 54 additions & 0 deletions cpp/tensorrt_llm/kernels/mlaKernels.cu
@@ -923,6 +923,49 @@ void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, c
<<<grid, 256, 0, stream>>>(params.attention_input_buf, params.latent_cache, kv_cache_buffer,
params.cos_sin_cache, params.head_num, head_size, params.meta.kv_lora_rank, params.cu_q_seqlens,
params.cache_seq_lens, params.max_input_seq_len, params.cache_type, params.quant_scale_kv);
if (params.attention_input_buf != nullptr && params.quant_attention_input_buf != nullptr
&& params.cache_type == KvCacheDataType::FP8)
{
TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing attention_input_buf to FP8");

int const dim_q_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
int const dim_k_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
int const dim_v_per_head = (params.meta.v_head_dim);

// Total dimension per token across all heads for Q, K, and V components respectively
int const total_q_dim_all_heads = params.head_num * dim_q_per_head;
int const total_k_dim_all_heads
= params.head_num * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
int const total_v_dim_all_heads
= params.head_num * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout

int const num_total_qkv_elements
= params.acc_q_len * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
size_t headDim = params.meta.kv_lora_rank + params.meta.qk_rope_head_dim;
float const* device_qkv_scale_ptr = params.quant_scale_qkv;

if (num_total_qkv_elements > 0)
{
int const threads_per_block = 256;
int const num_blocks = (num_total_qkv_elements + threads_per_block - 1) / threads_per_block;

TLLM_LOG_DEBUG(
"Launching QuantizeCopyInputToFp8Kernel with num_blocks: %d, threads_per_block: %d, elements: %d",
num_blocks, threads_per_block, num_total_qkv_elements);

tensorrt_llm::kernels::QuantizeCopyInputToFp8Kernel<T><<<num_blocks, threads_per_block, 0, stream>>>(
static_cast<T const*>(params.attention_input_buf), // Source
static_cast<__nv_fp8_e4m3*>(params.quant_attention_input_buf), // Destination
num_total_qkv_elements, device_qkv_scale_ptr);
sync_check_cuda_error(stream);

cudaStreamSynchronize(stream);
}
else
{
TLLM_LOG_WARNING("MLA RoPE Context: num_total_qkv_elements is 0, skipping quantization.");
}
}
}

template <typename T, typename KVCacheBuffer>
@@ -1037,6 +1080,17 @@ INSTANTIATE_SET_KVCACHE_MLA(float);
INSTANTIATE_SET_KVCACHE_MLA(half);
INSTANTIATE_SET_KVCACHE_MLA(__nv_bfloat16);

template <typename T_IN>
__global__ void QuantizeCopyInputToFp8Kernel(
T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr)
{
uint element_idx = threadIdx.x + blockDim.x * blockIdx.x;
if (element_idx < num_total_elements)
{
float scale_factor = (device_scale_ptr != nullptr) ? *device_scale_ptr : 1.0f;
output_fp8_buffer[element_idx] = __nv_fp8_e4m3(static_cast<float>(input_buffer[element_idx]) * scale_factor);
}
}
} // namespace kernels

} // namespace tensorrt_llm
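
For reference, a minimal host-side sketch of how the new quantize-copy kernel can be exercised in isolation. The element count, the BF16 input type, and the 1/448 scale are assumptions for illustration, and a local copy of the kernel is used so the sketch compiles on its own:

// Standalone sketch: quantize a BF16 buffer to FP8 E4M3 with a device-side scale,
// mirroring the launch pattern used in invokeMLARopeContext above.
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cstdio>
#include <vector>

// Local copy with the same shape as the kernel added by this commit.
template <typename T_IN>
__global__ void quantizeCopyToFp8(
    T_IN const* input, __nv_fp8_e4m3* output, int num_elements, float const* scale_ptr)
{
    int const idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx < num_elements)
    {
        float const scale = (scale_ptr != nullptr) ? *scale_ptr : 1.0f;
        output[idx] = __nv_fp8_e4m3(static_cast<float>(input[idx]) * scale);
    }
}

int main()
{
    int const n = 1 << 20;               // assumed element count
    float const h_scale = 1.0f / 448.0f; // assumed quant scale (1 / E4M3 max)

    std::vector<__nv_bfloat16> h_in(n, __float2bfloat16(100.0f));
    __nv_bfloat16* d_in;
    __nv_fp8_e4m3* d_out;
    float* d_scale;
    cudaMalloc(&d_in, n * sizeof(__nv_bfloat16));
    cudaMalloc(&d_out, n * sizeof(__nv_fp8_e4m3));
    cudaMalloc(&d_scale, sizeof(float));
    cudaMemcpy(d_in, h_in.data(), n * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
    cudaMemcpy(d_scale, &h_scale, sizeof(float), cudaMemcpyHostToDevice);

    int const threads = 256;
    int const blocks = (n + threads - 1) / threads;
    quantizeCopyToFp8<__nv_bfloat16><<<blocks, threads>>>(d_in, d_out, n, d_scale);
    cudaDeviceSynchronize();
    std::printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));

    cudaFree(d_in);
    cudaFree(d_out);
    cudaFree(d_scale);
    return 0;
}
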
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/kernels/mlaKernels.h
@@ -87,6 +87,8 @@ struct MlaParams
void* context_paged_kv_ptr = nullptr;
void* context_kv_cache_block_offsets_ptr = nullptr;
int32_t context_paged_kv_max_blocks_per_seq = 0;
// for FP8 context qkv quantization
float const* quant_scale_qkv = nullptr;
};

template <typename T, typename KVCacheBuffer>
Expand All @@ -111,5 +113,9 @@ void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* late
float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size,
float const* kv_scale_orig_quant_ptr, cudaStream_t stream);

template <typename T_IN>
__global__ void QuantizeCopyInputToFp8Kernel(
T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr);

} // namespace kernels
} // namespace tensorrt_llm