forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrtllm.py
More file actions
1928 lines (1795 loc) · 90.8 KB
/
trtllm.py
File metadata and controls
1928 lines (1795 loc) · 90.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import math
import os
import weakref
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
if TYPE_CHECKING:
from ..speculative.utils import SpecDecodingTensor
from ..speculative.interface import SpecMetadata
from ..speculative.spec_tree_manager import SpecTreeManager
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.bindings.internal import thop
from tensorrt_llm.functional import AttentionMaskType
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig
from ..utils import (compute_swizzled_sf_shape, get_global_attrs,
get_model_extra_attrs)
from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
AttentionMetadata, KVCacheParams, MLAParams,
PositionalEmbeddingParams, PredefinedAttentionMask,
RopeParams)
# Holder for the arguments of the TRT-LLM attention op.
# Static configuration is stored by __init__, the quantization mode by
# update_quant_config(), and per-call state by plan(). run() passes all of it
# to thop.attention and then calls plan() with no arguments to reset the
# per-call state.
# NOTE: init=False means the dataclass decorator does not generate __init__;
# the hand-written __init__ of this class is used instead.
@dataclass(kw_only=True, init=False)
class TrtllmAttentionWrapper:
# Assigned by plan().
sequence_length: torch.Tensor
host_past_key_value_lengths: torch.Tensor
host_total_kv_lens: torch.Tensor
context_lengths: torch.Tensor
host_context_lengths: torch.Tensor
host_request_types: torch.Tensor
kv_cache_block_offsets: torch.Tensor
host_kv_cache_block_offsets: torch.Tensor
host_kv_cache_pool_pointers: torch.Tensor
host_kv_cache_pool_mapping: torch.Tensor
workspace: Optional[torch.Tensor]
cache_indirection: Optional[torch.Tensor]
kv_scale_orig_quant: Optional[torch.Tensor]
kv_scale_quant_orig: Optional[torch.Tensor]
out_scale: Optional[torch.Tensor]
# Created in __init__ via rope_params.create_rope_const_params(); plan() may recreate them.
rotary_inv_freq: Optional[torch.Tensor]
rotary_cos_sin: Optional[torch.Tensor]
# Assigned by plan().
layer_idx: int
# Assigned by __init__.
num_heads: int
num_kv_heads: int
head_size: int
# Assigned by plan().
tokens_per_block: int
max_num_requests: int
max_context_length: int
attention_window_size: int
sink_token_length: int
beam_width: int
# Assigned by __init__ (1, or mla_params.predicted_tokens_per_seq when MLA is enabled).
predicted_tokens_per_seq: int
# Assigned by update_quant_config().
quant_mode: int
# Assigned by __init__ from the positional-embedding / RoPE parameters.
position_embedding_type: int
rotary_embedding_dim: int
rotary_embedding_base: float
rotary_embedding_scale_type: int
rotary_embedding_scale: float
rotary_embedding_short_m_scale: float
rotary_embedding_long_m_scale: float
rotary_embedding_max_positions: int
rotary_embedding_original_max_positions: int
# Assigned by plan().
use_paged_context_fmha: bool
# Assigned by __init__; the MLA dimensions are None when no mla_params is given.
is_mla_enable: bool
q_lora_rank: Optional[int]
kv_lora_rank: Optional[int]
qk_rope_head_dim: Optional[int]
qk_nope_head_dim: Optional[int]
v_head_dim: Optional[int]
# Assigned by plan().
chunked_prefill_buffer_batch_size: Optional[int]
# Assigned by __init__.
attention_chunk_size: Optional[int]
# Assigned by plan().
softmax_stats_tensor: Optional[torch.Tensor]
use_spec_decoding: bool
is_spec_dec_tree: bool
spec_decoding_position_offsets: Optional[torch.Tensor]
spec_decoding_packed_mask: Optional[torch.Tensor]
spec_decoding_generation_lengths: Optional[torch.Tensor]
spec_decoding_bl_tree_mask_offset: Optional[torch.Tensor]
spec_decoding_bl_tree_mask: Optional[torch.Tensor]
spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor]
# Extra keyword arguments collected by __init__ and plan(); run() logs a warning when it is non-empty.
kwargs: dict
def __init__(
    self,
    num_heads: int,
    head_size: int,
    num_kv_heads: Optional[int] = None,
    pos_embd_params: Optional[PositionalEmbeddingParams] = None,
    q_scaling: Optional[float] = None,
    mla_params: Optional[MLAParams] = None,
    attention_chunk_size: Optional[int] = None,
    **kwargs,
):
    """
    Store the static configuration of the attention wrapper.

    Args:
        num_heads (int): The number of query heads.
        head_size (int): The size of each attention head.
        num_kv_heads (int): The number of kv heads. A falsy value (None or 0) falls back to num_heads.
        pos_embd_params (PositionalEmbeddingParams): Optional positional embedding parameters. Its `rope` member supplies the RoPE settings and its `type` the position embedding type; when omitted, a default RopeParams() and type 0 are used.
        q_scaling (float): Stored as self.q_scaling. A falsy value (None or 0) falls back to 1.0.
        mla_params (MLAParams): Optional MLA parameters. When given, MLA is enabled and the MLA dimensions and predicted_tokens_per_seq are copied from it; otherwise the MLA dimensions are None and predicted_tokens_per_seq is 1.
        attention_chunk_size (int): Stored unchanged as self.attention_chunk_size.
        **kwargs: Extra keyword arguments, kept in self.kwargs.
    """
    # Resolve the RoPE parameters, falling back to defaults when none are supplied.
    rope = pos_embd_params.rope if pos_embd_params is not None else None
    if not rope:
        rope = RopeParams()
    self.rope_params = rope

    mla_enabled = mla_params is not None
    self.is_mla_enable = mla_enabled
    self.q_scaling = q_scaling or 1.0
    self.attention_chunk_size = attention_chunk_size

    # The MLA dimensions are copied from mla_params, or cleared when MLA is disabled.
    for mla_field in ("q_lora_rank", "kv_lora_rank", "qk_nope_head_dim",
                      "qk_rope_head_dim", "v_head_dim"):
        setattr(self, mla_field,
                getattr(mla_params, mla_field) if mla_enabled else None)
    if mla_enabled:
        self.predicted_tokens_per_seq = mla_params.predicted_tokens_per_seq
    else:
        self.predicted_tokens_per_seq = 1

    # Build the RoPE constant tensors before reading the individual RoPE settings.
    self.rotary_inv_freq, self.rotary_cos_sin = rope.create_rope_const_params()

    self.num_heads = num_heads
    self.num_kv_heads = num_kv_heads or num_heads
    self.head_size = head_size

    if pos_embd_params is None:
        self.position_embedding_type = 0
    else:
        self.position_embedding_type = int(pos_embd_params.type)

    self.rotary_embedding_dim = rope.dim
    self.rotary_embedding_base = rope.theta
    self.rotary_embedding_scale_type = int(rope.scale_type)
    self.rotary_embedding_scale = rope.scale
    self.rotary_embedding_short_m_scale = rope.short_m_scale
    self.rotary_embedding_long_m_scale = rope.long_m_scale
    self.rotary_embedding_max_positions = rope.max_positions
    self.rotary_embedding_original_max_positions = rope.original_max_positions

    # Keep a private copy of any extra keyword arguments.
    self.kwargs = dict(kwargs)
def update_quant_config(self, quant_config: Optional[QuantConfig] = None):
    """
    Record the integer quantization mode used by the attention op.

    Args:
        quant_config (QuantConfig): The quantization configuration. A falsy value (e.g. None) is replaced by a default QuantConfig().
    """
    if not quant_config:
        quant_config = QuantConfig()
    layer_mode = quant_config.layer_quant_mode
    self.quant_mode = int(layer_mode)
def plan(
self,
*,
layer_idx: int = 0,
tokens_per_block: Optional[int] = None,
max_num_requests: int = 0,
max_sequence_length: int = 0,
max_context_length: int = 0,
attention_window_size: Optional[int] = None,
sink_token_length: int = 0,
beam_width: int = 1,
sequence_length: torch.Tensor = ...,
host_past_key_value_lengths: torch.Tensor = ...,
host_total_kv_lens: torch.Tensor = ...,
context_lengths: torch.Tensor = ...,
host_context_lengths: torch.Tensor = ...,
host_request_types: torch.Tensor = ...,
kv_cache_block_offsets: Optional[torch.Tensor] = None,
host_kv_cache_block_offsets: Optional[torch.Tensor] = None,
host_kv_cache_pool_pointers: Optional[torch.Tensor] = None,
host_kv_cache_pool_mapping: Optional[torch.Tensor] = None,
block_ids_per_seq: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
cache_indirection: Optional[torch.Tensor] = None,
kv_scale_orig_quant: Optional[torch.Tensor] = None,
kv_scale_quant_orig: Optional[torch.Tensor] = None,
out_scale: Optional[torch.Tensor] = None,
out_scale_sf: Optional[torch.Tensor] = None,
kv_scales_sf: Optional[torch.Tensor] = None,
kv_scales_sf_inv: Optional[torch.Tensor] = None,
use_nvfp4_output: bool = False,
use_paged_context_fmha: bool = False,
attention_input_type: AttentionInputType = AttentionInputType.mixed,
latent_cache: Optional[torch.Tensor] = None,
q_pe: Optional[torch.Tensor] = None,
mrope_config: Optional[dict] = None,
softmax_stats_tensor: Optional[torch.Tensor] = None,
is_spec_decoding_enabled: bool = False,
use_spec_decoding: bool = False,
is_spec_dec_tree: bool = False,
spec_decoding_position_offsets: Optional[torch.Tensor] = None,
spec_decoding_packed_mask: Optional[torch.Tensor] = None,
spec_decoding_generation_lengths: Optional[torch.Tensor] = None,
spec_decoding_bl_tree_mask_offset: Optional[torch.Tensor] = None,
spec_decoding_bl_tree_mask: Optional[torch.Tensor] = None,
spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor] = None,
attention_sinks: Optional[torch.Tensor] = None,
chunked_prefill_buffer_batch_size: int = 1,
sparse_kv_indices: Optional[torch.Tensor] = None,
sparse_kv_offsets: Optional[torch.Tensor] = None,
sparse_attn_indices: Optional[torch.Tensor] = None,
sparse_attn_offsets: Optional[torch.Tensor] = None,
sparse_attn_indices_block_size: int = 1,
sparse_mla_topk: int = 0,
helix_position_offsets: Optional[torch.Tensor] = None,
helix_is_inactive_rank: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Plan the attention operation by storing the per-call state that run() passes to the attention op.
Calling this method without arguments resets the planned states; run() does this after every call.
Required tensor arguments default to ellipsis (...), which represents an invalid (unplanned) state.
Args:
layer_idx (int): The index of the attention layer in the model.
tokens_per_block (int): Token number per KV cache block.
max_num_requests (int): Max request number per batch.
max_sequence_length (int): Max sequence length. Also used as attention_window_size when that argument is None or 0, and to raise rope_params.max_positions when it is larger than the current value.
max_context_length (int): Max context length per context-phase sequence.
attention_window_size (int): Max token number attended in windowed attention. Falls back to max_sequence_length when None or 0.
sink_token_length (int): Sink token number in StreamingLLM.
beam_width (int): Beam width in beam search.
sequence_length (torch.Tensor): The length of each sequence with shape (batch_size) on GPU.
host_past_key_value_lengths (torch.Tensor): Same as sequence_length, but on CPU.
host_total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
context_lengths (torch.Tensor): The context-phase sequence length of each request with shape (batch_size) on GPU.
host_context_lengths (torch.Tensor): Same as context_lengths, but on CPU.
host_request_types (torch.Tensor): The tensor that indicates whether a request is in context or generation phase, with shape (batch_size) on CPU.
kv_cache_block_offsets (torch.Tensor): The offsets to the blocks inside KV cache pools on GPU, its shape is (num_pools, max_batch_size * max_beam_width, 2, max_blocks_per_sequence), one for each block. If kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention.
host_kv_cache_block_offsets (torch.Tensor): Same as kv_cache_block_offsets, but on CPU.
host_kv_cache_pool_pointers (torch.Tensor): The pointers to the KV cache pools on CPU, its shape is (num_pools, 2), one for primary pool in GPU memory, one for secondary pool in CPU memory.
host_kv_cache_pool_mapping (torch.Tensor): The index of the pool used by each attention layer on CPU, its shape is (num_local_attention_layers). The local attention layers mean all attention layers in the current PP stage in the pipeline parallelism case.
block_ids_per_seq (torch.Tensor): Stored and passed to the attention op unchanged (TODO confirm shape and meaning against the op).
workspace (torch.Tensor): An optional workspace tensor on GPU.
cache_indirection (torch.Tensor): A tensor for beam search on GPU, its shape is (batch_size, beam_width, max_seqlen), for a sequence si, a beam bi and a token ti, the element cache_indirection[si][bi][ti] is an integer between 0 and beam_width-1 that indicates which path in the beam to read the K and V elements from in the KV cache.
kv_scale_orig_quant (torch.Tensor): The tensor to store the scaling factor for quantization to INT8/FP8 in the KV cache, with shape (1) on GPU. Ignored when kv_scales_sf_inv is given.
kv_scale_quant_orig (torch.Tensor): The tensor to store the scaling factor for dequantization from INT8/FP8 in the KV cache, with shape (1) on GPU. Ignored when kv_scales_sf is given.
out_scale (torch.Tensor): The tensor to store the scaling factor to quantize output, with shape (1) on GPU.
out_scale_sf (torch.Tensor): The tensor to store the global scale for NVFP4 scaling factors, with shape (1) on GPU.
kv_scales_sf (torch.Tensor): The tensor to store the global scale for KV NVFP4 scaling factors, with shape (2) on GPU. When not None, it is stored as self.kv_scale_quant_orig instead of the kv_scale_quant_orig argument.
kv_scales_sf_inv (torch.Tensor): The tensor to store the inverse of the global scale for KV NVFP4 scaling factors, with shape (2) on GPU. When not None, it is stored as self.kv_scale_orig_quant instead of the kv_scale_orig_quant argument.
use_nvfp4_output (bool): When True, run() passes out_scale_sf instead of out_scale to the attention op.
use_paged_context_fmha (bool): Sets the mPagedContextFMHA attribute in the op runner.
attention_input_type (AttentionInputType): The input type of the attention call, stored as an int. With MLA enabled, run() accepts only context_only or generation_only.
latent_cache (torch.Tensor): Stored and passed to the attention op unchanged (presumably an MLA input; TODO confirm).
q_pe (torch.Tensor): Stored and passed to the attention op unchanged (presumably an MLA input; TODO confirm).
mrope_config (dict): The dictionary containing the mRope configuration. The keys 'mrope_rotary_cos_sin' and 'mrope_position_deltas' are read; a missing key or a None dictionary yields None.
softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum)
is_spec_decoding_enabled (bool): Speculative decoding flag, passed to the attention op together with use_spec_decoding and is_spec_dec_tree.
use_spec_decoding (bool): Speculative decoding flag, see is_spec_decoding_enabled.
is_spec_dec_tree (bool): Speculative decoding flag, see is_spec_decoding_enabled.
spec_decoding_position_offsets (torch.Tensor): Speculative decoding tensor, passed to the attention op unchanged.
spec_decoding_packed_mask (torch.Tensor): Speculative decoding tensor, passed to the attention op unchanged.
spec_decoding_generation_lengths (torch.Tensor): Speculative decoding tensor, passed to the attention op unchanged.
spec_decoding_bl_tree_mask_offset (torch.Tensor): Speculative decoding tensor; run() passes it to the attention op only when get_sm_version() >= 100.
spec_decoding_bl_tree_mask (torch.Tensor): Speculative decoding tensor; run() passes it to the attention op only when get_sm_version() >= 100.
spec_bl_tree_first_sparse_mask_offset_kv (torch.Tensor): Speculative decoding tensor; run() passes it to the attention op only when get_sm_version() >= 100.
attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU.
chunked_prefill_buffer_batch_size (int): Used to allocate the buffer for k and v in FP8 context MLA. The max input kv length is not max_num_tokens in this case; it is chunked_prefill_buffer_batch_size * max_num_tokens.
sparse_kv_indices (torch.Tensor): The sparse indices for the KV cache, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
sparse_kv_offsets (torch.Tensor): The batch offsets for the sparse KV indices, with shape of (num_contexts + 1) on GPU.
sparse_attn_indices (torch.Tensor): The sparse indices for the attention layer, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape of (num_generations + 1) on GPU.
sparse_attn_indices_block_size (int): The granularity of the sparse attention indices, used by block sparse attention.
sparse_mla_topk (int): The topk for the sparse MLA, used by DSA attention.
helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
helix_is_inactive_rank (torch.Tensor): For Helix: whether the current rank is inactive, with shape (batch_size) on GPU. A non-tensor value is converted with torch.tensor(..., dtype=torch.bool, pin_memory=True).
**kwargs: Unrecognized keyword arguments. They are merged into self.kwargs, and run() logs a warning listing them.
"""
self.layer_idx = layer_idx
self.tokens_per_block = tokens_per_block
self.max_num_requests = max_num_requests
self.max_context_length = max_context_length
# A falsy attention_window_size (None or 0) falls back to max_sequence_length.
self.attention_window_size = attention_window_size or max_sequence_length
self.sink_token_length = sink_token_length
self.beam_width = beam_width
self.sequence_length = sequence_length
self.host_past_key_value_lengths = host_past_key_value_lengths
self.host_total_kv_lens = host_total_kv_lens
self.context_lengths = context_lengths
self.host_context_lengths = host_context_lengths
self.host_request_types = host_request_types
self.kv_cache_block_offsets = kv_cache_block_offsets
self.host_kv_cache_block_offsets = host_kv_cache_block_offsets
self.host_kv_cache_pool_pointers = host_kv_cache_pool_pointers
self.host_kv_cache_pool_mapping = host_kv_cache_pool_mapping
self.workspace = workspace
self.cache_indirection = cache_indirection
# The NVFP4 KV scale tensors, when provided, take precedence over the kv_scale_* arguments.
self.kv_scale_orig_quant = kv_scale_orig_quant if kv_scales_sf_inv is None else kv_scales_sf_inv
self.kv_scale_quant_orig = kv_scale_quant_orig if kv_scales_sf is None else kv_scales_sf
self.out_scale = out_scale
self.out_scale_sf = out_scale_sf
self.use_paged_context_fmha = use_paged_context_fmha
self.use_nvfp4_output = use_nvfp4_output
# Stored as a plain int; run() passes it to the attention op in this form.
self.attention_input_type = int(attention_input_type)
self.latent_cache = latent_cache
self.q_pe = q_pe
# Both mRope entries are None when no mrope_config is given.
self.mrope_rotary_cos_sin = mrope_config.get(
'mrope_rotary_cos_sin') if mrope_config is not None else None
self.mrope_position_deltas = mrope_config.get(
'mrope_position_deltas') if mrope_config is not None else None
self.block_ids_per_seq = block_ids_per_seq
self.softmax_stats_tensor = softmax_stats_tensor
self.attention_sinks = attention_sinks
self.sparse_kv_indices = sparse_kv_indices
self.sparse_kv_offsets = sparse_kv_offsets
self.sparse_attn_indices = sparse_attn_indices
self.sparse_attn_offsets = sparse_attn_offsets
self.sparse_attn_indices_block_size = sparse_attn_indices_block_size
self.sparse_mla_topk = sparse_mla_topk
self.helix_position_offsets = helix_position_offsets
self.helix_is_inactive_rank = helix_is_inactive_rank
# Accept a non-tensor value and convert it to a pinned bool tensor.
if self.helix_is_inactive_rank is not None and not isinstance(
self.helix_is_inactive_rank, torch.Tensor):
self.helix_is_inactive_rank = torch.tensor(
self.helix_is_inactive_rank, dtype=torch.bool, pin_memory=True)
# Extend the RoPE range when this plan needs more positions than currently configured,
# and rebuild the RoPE constant tensors.
# NOTE(review): indentation was lost in this copy; presumably the rebuild only happens inside this `if` - confirm.
# NOTE(review): self.rotary_embedding_max_positions (set in __init__) is not updated here - confirm this is intended.
if max_sequence_length > self.rope_params.max_positions:
self.rope_params.max_positions = max_sequence_length
self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
)
self.is_spec_decoding_enabled = is_spec_decoding_enabled
self.use_spec_decoding = use_spec_decoding
self.is_spec_dec_tree = is_spec_dec_tree
self.spec_decoding_position_offsets = spec_decoding_position_offsets
self.spec_decoding_packed_mask = spec_decoding_packed_mask
self.spec_decoding_generation_lengths = spec_decoding_generation_lengths
self.spec_decoding_bl_tree_mask_offset = spec_decoding_bl_tree_mask_offset
self.spec_decoding_bl_tree_mask = spec_decoding_bl_tree_mask
self.spec_bl_tree_first_sparse_mask_offset_kv = spec_bl_tree_first_sparse_mask_offset_kv
self.chunked_prefill_buffer_batch_size = chunked_prefill_buffer_batch_size
# Unrecognized keyword arguments accumulate in self.kwargs; a later plan() call does not clear them.
self.kwargs.update(kwargs)
def create_output(self, q: torch.Tensor, out_dtype: torch.dtype):
    """
    Allocate the output tensor (and, for NVFP4 output, the scaling-factor tensor) for an attention call.

    Args:
        q (torch.Tensor): The query (or fused QKV) tensor. Its first dimension is used as the token count, and the new tensors are created with q.new_empty.
        out_dtype (torch.dtype): The output data type. None means q.dtype; torch.uint8 selects the packed NVFP4 layout.

    Returns:
        A tuple (output, output_sf). output_sf is None unless out_dtype is torch.uint8.
    """
    token_count = q.size(0)
    if self.attention_input_type is None:
        input_type = AttentionInputType.mixed
    else:
        input_type = AttentionInputType(self.attention_input_type)
    if out_dtype is None:
        out_dtype = q.dtype

    # Per-head width of the output: head_size normally; with MLA it is
    # kv_lora_rank for generation-only input and v_head_dim otherwise.
    if not self.is_mla_enable:
        value_head_size = self.head_size
    elif input_type == AttentionInputType.generation_only:
        value_head_size = self.kv_lora_rank
    else:
        value_head_size = self.v_head_dim
    row_width = self.num_heads * value_head_size

    if out_dtype != torch.uint8:
        return q.new_empty((token_count, row_width), dtype=out_dtype), None

    # NVFP4 output: two elements are packed per uint8 container, and the
    # scaling-factor count per token is row_width // 16.
    elements_per_container = 2
    sf_vector_size = 16
    packed_output = q.new_empty(
        (token_count, row_width // elements_per_container),
        dtype=torch.uint8)
    # The scaling factors live in a flat uint8 tensor sized by the swizzled layout.
    sf_rows, sf_cols = compute_swizzled_sf_shape(token_count,
                                                 row_width // sf_vector_size)
    scaling_factors = q.new_empty(sf_rows * sf_cols, dtype=torch.uint8)
    return packed_output, scaling_factors
def run(
self,
q: torch.Tensor,
k: Optional[torch.Tensor] = None,
v: Optional[torch.Tensor] = None,
output: Optional[torch.Tensor] = None,
output_sf: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
is_fused_qkv: bool = True,
update_kv_cache: bool = True,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
cu_q_seqlens: Optional[torch.Tensor] = None,
cu_kv_seqlens: Optional[torch.Tensor] = None,
fmha_scheduler_counter: Optional[torch.Tensor] = None,
mla_bmm1_scale: Optional[torch.Tensor] = None,
mla_bmm2_scale: Optional[torch.Tensor] = None,
quant_q_buffer: Optional[torch.Tensor] = None,
):
"""
Run the attention operation with the state stored by plan(), then reset that state by calling plan() without arguments.
Args:
q (torch.Tensor): Query tensor with shape (num_tokens, num_heads * head_dim) or QKV tensor with shape (num_tokens, (num_heads + 2 * num_kv_heads) * head_dim).
k (Optional[torch.Tensor]): Key tensor with shape (num_tokens, num_kv_heads * head_dim) or None if QKV tensor is provided.
v (Optional[torch.Tensor]): Value tensor with shape (num_tokens, num_kv_heads * head_dim) or None if QKV tensor is provided.
output (Optional[torch.Tensor]): Pre-allocated output tensor. When None, output and output_sf are allocated with create_output().
output_sf (Optional[torch.Tensor]): Pre-allocated scaling-factor tensor for NVFP4 output. Must be None when output is None, and must be provided when output is provided and out_dtype is torch.uint8.
out_dtype (Optional[torch.dtype]): Output data type if provided.
is_fused_qkv (bool): Whether QKV tensor is provided.
update_kv_cache (bool): Whether KV cache is updated.
attention_mask (AttentionMask): Attention mask. See definition of AttentionMask for accepted types. Defaults to predefined causal mask. Only PredefinedAttentionMask.CAUSAL and PredefinedAttentionMask.FULL are accepted here.
cu_q_seqlens (Optional[torch.Tensor]): Passed to the attention op unchanged.
cu_kv_seqlens (Optional[torch.Tensor]): Passed to the attention op unchanged.
fmha_scheduler_counter (Optional[torch.Tensor]): Passed to the attention op unchanged.
mla_bmm1_scale (Optional[torch.Tensor]): Passed to the attention op unchanged.
mla_bmm2_scale (Optional[torch.Tensor]): Passed to the attention op unchanged.
quant_q_buffer (Optional[torch.Tensor]): Passed to the attention op unchanged.
Returns:
A tuple (output, output_sf). When allocated here without MLA and without NVFP4 output, output has shape (num_tokens, num_heads * head_dim) and output_sf is None.
Raises:
ValueError: If attention_mask is neither CAUSAL nor FULL, or if MLA is enabled and the planned attention_input_type is neither context_only nor generation_only.
AssertionError: If one of the argument or planned-state consistency checks fails.
"""
# Unrecognized keyword arguments collected by __init__/plan are only reported, not rejected.
if len(self.kwargs) > 0:
logger.warning(
f"unknown arguments {list(self.kwargs.keys())} in attention wrapper"
)
# Exactly one calling convention is allowed: fused QKV in q (k and v are None), or separate q, k and v.
assert (is_fused_qkv and k is None
and v is None) or (not is_fused_qkv and k is not None
and v is not None)
# Without MLA: check the hidden sizes of q (and of k/v where given) against the configured head counts.
if not self.is_mla_enable:
if is_fused_qkv:
qkv_hidden_size = (self.num_heads +
2 * self.num_kv_heads) * self.head_size
assert q.shape[1] == qkv_hidden_size
else:
q_hidden_size = self.num_heads * self.head_size
assert q.shape[1] == q_hidden_size
if update_kv_cache:
kv_hidden_size = self.num_kv_heads * self.head_size
assert k.shape[1] == kv_hidden_size
assert v.shape[1] == kv_hidden_size
num_tokens = q.shape[0]
if k is not None:
assert k.shape[0] == num_tokens
assert v.shape[0] == num_tokens
# All per-request planned tensors must agree on the batch size.
batch_size = self.sequence_length.shape[0]
assert self.host_past_key_value_lengths.shape[0] == batch_size
assert self.context_lengths.shape[0] == batch_size
assert self.host_context_lengths.shape[0] == batch_size
assert self.host_request_types.shape[0] == batch_size
# Map the predefined mask to the AttentionMaskType passed to the op.
if attention_mask == PredefinedAttentionMask.CAUSAL:
mask_type = AttentionMaskType.causal
elif attention_mask == PredefinedAttentionMask.FULL:
mask_type = AttentionMaskType.padding
else:
raise ValueError("Unexpected attention mask type")
else:
# With MLA: the expected width of q depends on the planned attention_input_type.
# For DSA, use the same qkv hidden size for context and generation phases
is_sparse_attn = self.sparse_attn_indices is not None and self.sparse_attn_indices.numel(
) > 0
if self.attention_input_type == AttentionInputType.context_only and is_sparse_attn:
assert is_fused_qkv
qkv_hidden_size = self.num_heads * (self.kv_lora_rank +
self.qk_rope_head_dim)
elif self.attention_input_type == AttentionInputType.context_only:
assert not is_fused_qkv
qkv_hidden_size = self.num_heads * (self.qk_nope_head_dim +
self.qk_rope_head_dim)
elif self.attention_input_type == AttentionInputType.generation_only:
assert is_fused_qkv
qkv_hidden_size = self.num_heads * (self.kv_lora_rank +
self.qk_rope_head_dim)
else:
raise ValueError(
"In MLA, TrtllmAttention can only support context_only or generation_only, not mixed."
)
assert q.shape[
1] == qkv_hidden_size, f"q.shape[1] must be equal to qkv_hidden_size, got {q.shape[1]=}, {qkv_hidden_size=}"
# Same batch-size and mask-type checks as in the non-MLA case.
batch_size = self.sequence_length.shape[0]
assert self.host_past_key_value_lengths.shape[0] == batch_size
assert self.context_lengths.shape[0] == batch_size
assert self.host_context_lengths.shape[0] == batch_size
assert self.host_request_types.shape[0] == batch_size
if attention_mask == PredefinedAttentionMask.CAUSAL:
mask_type = AttentionMaskType.causal
elif attention_mask == PredefinedAttentionMask.FULL:
mask_type = AttentionMaskType.padding
else:
raise ValueError("Unexpected attention mask type")
# Allocate the output here unless the caller supplied one.
if output is None:
assert output_sf is None
output, output_sf = self.create_output(q, out_dtype)
else:
# output is provided, expect output_sf be provided as well if has NVFP4 output.
assert out_dtype is None or out_dtype != torch.uint8 or output_sf is not None
# packing parameters to avoid maxing out 64 arguments
rotary_embedding_scales = [
self.rotary_embedding_scale, self.rotary_embedding_short_m_scale,
self.rotary_embedding_long_m_scale
]
rotary_embedding_max_position_info = [
self.rotary_embedding_max_positions,
self.rotary_embedding_original_max_positions
]
spec_decoding_bool_params = [
self.is_spec_decoding_enabled, self.use_spec_decoding,
self.is_spec_dec_tree
]
spec_decoding_tensor_params = [
self.spec_decoding_generation_lengths,
self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
]
# The three bl_tree tensors are appended only when get_sm_version() >= 100.
if get_sm_version() >= 100:
spec_decoding_tensor_params.append(
self.spec_decoding_bl_tree_mask_offset)
spec_decoding_tensor_params.append(self.spec_decoding_bl_tree_mask)
spec_decoding_tensor_params.append(
self.spec_bl_tree_first_sparse_mask_offset_kv)
mla_tensor_params = [
self.helix_position_offsets, self.helix_is_inactive_rank
]
# NOTE: every argument below is positional, so the order must match the op's signature.
thop.attention(
q,
k,
v,
output,
output_sf,
out_dtype,
self.workspace,
self.sequence_length,
self.host_past_key_value_lengths,
self.host_total_kv_lens,
self.context_lengths,
self.host_context_lengths,
self.host_request_types,
self.kv_cache_block_offsets,
self.host_kv_cache_block_offsets,
self.host_kv_cache_pool_pointers,
self.host_kv_cache_pool_mapping,
self.cache_indirection,
self.kv_scale_orig_quant,
self.kv_scale_quant_orig,
self.out_scale_sf if self.use_nvfp4_output else self.out_scale,
self.rotary_inv_freq,
self.rotary_cos_sin,
self.latent_cache,
self.q_pe,
self.block_ids_per_seq,
self.attention_sinks,
is_fused_qkv,
update_kv_cache,
self.predicted_tokens_per_seq,
self.layer_idx,
self.num_heads,
self.num_kv_heads,
self.head_size,
self.tokens_per_block,
self.max_num_requests,
self.max_context_length,
self.attention_window_size,
self.sink_token_length,
self.beam_width,
int(mask_type),
self.quant_mode,
self.q_scaling,
self.position_embedding_type,
self.rotary_embedding_dim,
self.rotary_embedding_base,
self.rotary_embedding_scale_type,
rotary_embedding_scales,
rotary_embedding_max_position_info,
self.use_paged_context_fmha,
self.attention_input_type,
self.is_mla_enable,
self.chunked_prefill_buffer_batch_size,
self.q_lora_rank,
self.kv_lora_rank,
self.qk_nope_head_dim,
self.qk_rope_head_dim,
self.v_head_dim,
self.mrope_rotary_cos_sin,
self.mrope_position_deltas,
mla_tensor_params,
self.attention_chunk_size,
self.softmax_stats_tensor,
spec_decoding_bool_params,
spec_decoding_tensor_params,
self.sparse_kv_indices,
self.sparse_kv_offsets,
self.sparse_attn_indices,
self.sparse_attn_offsets,
self.sparse_attn_indices_block_size,
self.sparse_mla_topk,
cu_q_seqlens,
cu_kv_seqlens,
fmha_scheduler_counter,
mla_bmm1_scale,
mla_bmm2_scale,
quant_q_buffer,
)
# reset the planned states (especially tensors) to avoid memory leak
# NOTE(review): this reset is skipped if the attention op raises - confirm that is acceptable.
self.plan()
return output, output_sf
def is_nvfp4_output_kernel_available(
    self,
    *,
    tokens_per_block: Optional[int] = None,
    attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.CAUSAL,
    use_paged_context_fmha: bool = False,
    is_mla_enable: bool = False,
    **kwargs,
):
    """
    Ask the attention op at runtime whether an NVFP4 output kernel exists for this configuration.

    Args:
        tokens_per_block (int): Token number per KV cache block.
        attention_mask (PredefinedAttentionMask): The attention mask type. Only CAUSAL and FULL are accepted.
        use_paged_context_fmha (bool): Whether to use paged context FMHA.
        is_mla_enable (bool): Whether to use MLA.
        **kwargs: Accepted and ignored.

    Returns:
        The result of torch.ops.trtllm.attention_supports_nvfp4_output for the given configuration.

    Raises:
        ValueError: If attention_mask is neither CAUSAL nor FULL.
    """
    # Translate the predefined mask into the mask type expected by the op.
    is_causal = attention_mask == PredefinedAttentionMask.CAUSAL
    if not is_causal and not (attention_mask == PredefinedAttentionMask.FULL):
        raise ValueError("Unexpected attention mask type")
    mask_type = (AttentionMaskType.causal
                 if is_causal else AttentionMaskType.padding)

    supports_nvfp4_output = torch.ops.trtllm.attention_supports_nvfp4_output
    op_args = (
        self.num_heads,
        self.num_kv_heads,
        self.head_size,
        tokens_per_block,
        int(mask_type),
        self.quant_mode,
        use_paged_context_fmha,
        is_mla_enable,
    )
    return supports_nvfp4_output(*op_args)
@dataclass(kw_only=True)
class TrtllmAttentionMetadata(AttentionMetadata):
workspace: Optional[torch.Tensor] = None
cuda_graph_workspace: Optional[torch.Tensor] = None
# TrtllmAttention needs to know the beam width to access the cache indirection buffer
# when beam search is enabled.
beam_width: int = 1
# TrtllmAttention needs to know the max sequence length.
# Implemented as a property (defined further down in this class) to support no cache mode.
max_seq_len: Optional[int]
# Storage for the internal max_seq_len value; written by the max_seq_len setter.
_max_seq_len_storage: Optional[int] = field(default=None,
init=True,
repr=False)
# Flags to enable spec-dec mode (multi-query mode) in TRTLLM XQA kernels.
# Spec decoding mode can be enabled for non-TRTLLM-gen kernels (pre-Blackwell XQA kernels).
# is_spec_decoding_enabled specifies whether spec-dec mode is supported for the entire runtime.
is_spec_decoding_enabled: bool = False
# use_spec_decoding determines whether the attention layer should be run in spec-dec mode at the specific step / layer.
use_spec_decoding: bool = False
# Whether the spec-dec tree is a tree or a chain (linear tree).
is_spec_dec_tree: bool = False
# If the spec-dec tree does not change at all, the mask is not recomputed every step.
is_spec_dec_dynamic_tree: bool = False
# Parameters required for spec-dec mode.
spec_decoding_position_offsets: Optional[torch.Tensor] = None
spec_decoding_packed_mask: Optional[torch.Tensor] = None
spec_decoding_generation_lengths: Optional[torch.Tensor] = None
spec_decoding_bl_tree_mask_offset: Optional[torch.Tensor] = None
spec_decoding_bl_tree_mask: Optional[torch.Tensor] = None
spec_bl_tree_first_sparse_mask_offset_kv: Optional[torch.Tensor] = None
# Whether the current rank is inactive for helix parallelism.
# In helix parallelism, only the active rank appends KV cache for the query token
# and attends to the previously cached tokens as well as the query token. Inactive
# ranks do not append KV cache for the query token and attend to the previously
# cached tokens only.
helix_is_inactive_rank: Optional[torch.Tensor] = None
@property
def max_seq_len(self) -> int:
    """
    Return the max sequence length.

    With a KV cache manager, the value is read from the manager.
    Without one (no cache attention), the value must have been set
    through the setter beforehand.
    """
    manager = self.kv_cache_manager
    if manager is None:
        stored = self._max_seq_len_storage
        assert stored is not None, "max_seq_len should be set for no kv cache attention"
        return stored
    return manager.max_seq_len


@max_seq_len.setter
def max_seq_len(self, value: int) -> None:
    """
    Store the max sequence length used for no cache attention.
    """
    self._max_seq_len_storage = value
@property
def tokens_per_block(self) -> Optional[int]:
    """
    Return the number of tokens per block from the KV cache manager, or None when there is no KV cache manager.
    """
    manager = self.kv_cache_manager
    if manager is None:
        return None
    return manager.tokens_per_block
@property
def host_kv_cache_pool_pointers(self) -> Optional[torch.Tensor]:
    """
    Host KV cache pool pointers taken from the KV cache manager,
    or None when no KV cache manager is attached.
    """
    if self.kv_cache_manager is None:
        return None
    return self.kv_cache_manager.kv_cache_pool_pointers
@property
def host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:
    """
    Host KV cache pool mapping taken from the KV cache manager,
    or None when no KV cache manager is attached.
    """
    if self.kv_cache_manager is None:
        return None
    return self.kv_cache_manager.kv_cache_pool_mapping
def __post_init__(self) -> None:
    """
    Run the parent class's post-init, then set up this metadata's tensors
    via _post_init_with_buffers() using self.cuda_graph_buffers.
    """
    super().__post_init__()
    graph_buffers = self.cuda_graph_buffers
    self._post_init_with_buffers(graph_buffers)
def _post_init_with_buffers(self, buffers) -> None:
    """
    Create the tensors this metadata works with.

    The length tensors and workspaces are always set up. The KV cache block
    offsets are only created when a KV cache manager is attached; the
    FlashMLA block-id tables and the context-MLA indptr tensors additionally
    require their respective feature flags.

    Args:
        buffers: Forwarded unchanged to get_empty() / get_empty_like().
    """
    # Set a default value, as max_num_sequences is not always set.
    if self.max_num_sequences is None:
        self.max_num_sequences = self.max_num_requests

    is_capturing = torch.cuda.is_current_stream_capturing()

    def pinned_host_like(src, factory=torch.empty_like):
        # CPU tensor shaped like `src`, allocated in pinned memory.
        return factory(src, device='cpu', pin_memory=True)

    def empty_int8_workspace():
        # Zero-length int8 CUDA tensor used as the initial workspace.
        return torch.empty((0, ), device='cuda', dtype=torch.int8)

    def block_id_table(cache_name):
        # int32 table of shape [max_batch_size, max_blocks_per_seq].
        return self.get_empty(
            buffers,
            [
                self.kv_cache_manager.max_batch_size,
                self.kv_cache_manager.max_blocks_per_seq
            ],
            cache_name=cache_name,
            dtype=torch.int32,
            capture_graph=is_capturing,
        )

    def indptr_with_host_mirror(cache_name):
        # int64 tensor with max_num_requests + 1 entries, paired with a
        # zero-initialized pinned host tensor of the same shape.
        indptr = self.get_empty(
            buffers,
            (self.max_num_requests + 1, ),
            cache_name=cache_name,
            dtype=torch.int64,
            capture_graph=is_capturing,
        )
        return indptr, pinned_host_like(indptr, factory=torch.zeros_like)

    self.prompt_lens_cuda = self.get_empty(
        buffers,
        (self.max_num_sequences, ),
        cache_name="prompt_lens_cuda",
        dtype=torch.int,
        capture_graph=is_capturing,
    )
    self.prompt_lens_cpu = pinned_host_like(self.prompt_lens_cuda)
    self.kv_lens_cuda = self.get_empty_like(
        buffers,
        self.prompt_lens_cuda,
        cache_name="kv_lens_cuda",
        capture_graph=is_capturing,
    )
    self.kv_lens = pinned_host_like(self.kv_lens_cuda)
    self.host_total_kv_lens = torch.empty(2, device='cpu', dtype=torch.int)
    self.host_request_types = torch.empty_like(self.prompt_lens_cpu)

    # For debugging, can use it to call the wrapper's plan function
    if self.workspace is None:
        self.workspace = empty_int8_workspace()
    if self.cuda_graph_workspace is None:
        self.cuda_graph_workspace = empty_int8_workspace()

    if self.kv_cache_manager is not None:
        self.kv_cache_block_offsets = self.get_empty(
            buffers,
            [
                self.kv_cache_manager.num_pools, self.max_num_sequences, 2,
                self.kv_cache_manager.max_blocks_per_seq
            ],
            cache_name="kv_cache_block_offsets",
            dtype=torch.int32,
            capture_graph=is_capturing,
        )
        self.host_kv_cache_block_offsets = pinned_host_like(
            self.kv_cache_block_offsets)

        # The block-id tables stay None unless FlashMLA is enabled.
        self.block_ids_per_seq = None
        self.kv_block_ids_per_seq = None
        if self.enable_flash_mla:
            self.block_ids_per_seq = block_id_table("block_ids_per_seq")
            self.kv_block_ids_per_seq = block_id_table("kv_block_ids_per_seq")

        if self.enable_context_mla_with_cached_kv:
            # for kv cache reuse/chunked context in MLA
            (self.ctx_cached_token_indptr,
             self.host_ctx_cached_token_indptr
             ) = indptr_with_host_mirror("ctx_cached_token_indptr")
            (self.ctx_uncached_token_indptr,
             self.host_ctx_uncached_token_indptr
             ) = indptr_with_host_mirror("ctx_uncached_token_indptr")
            # context full seqlens include cached tokens and uncached tokens
            (self.ctx_kv_indptr, self.host_ctx_kv_indptr
             ) = indptr_with_host_mirror("ctx_kv_indptr")
def on_update_kv_lens(self):
    """
    Extension point for refreshing other metadata after kv_lens /
    kv_lens_cuda have been changed, especially by the changes made in
    _preprocess_inputs() of model_engine.py.

    This implementation intentionally does nothing.
    """
    return None
def prepare(self) -> None:
    """
    Fill in the per-step attention metadata from the current batch state.

    Copies prompt lengths and KV lengths into the preallocated host/device
    tensors, records the total KV lengths and the request types (0 for the
    first num_contexts entries, 1 for the rest up to num_seqs), copies the
    KV cache block offsets when a KV cache manager is attached, and finally
    publishes the `*_runtime` views trimmed to the first num_seqs entries.

    Without a KV cache manager the metadata is first switched to no-cache
    mode, which requires max_seq_len to have been set beforehand.

    Raises:
        AssertionError: if max_seq_len is unset for no-cache attention, if
            helix is in use without cached token lengths, if request_ids
            is None, or if the largest KV length exceeds the KV cache
            manager's max_seq_len.
    """
    extra_attrs = get_model_extra_attrs()
    # If model extra attrs is set, attention_metadata is setup in executor.
    if extra_attrs is None:
        get_global_attrs().attention_metadata = weakref.ref(self)

    if self.kv_cache_manager is None:
        # Convert the attention metadata to a TRT-LLM no cache attention metadata.
        assert self._max_seq_len_storage is not None, "max_seq_len should be set for no cache attention"

        # setting kv cache params
        self.kv_cache_params = KVCacheParams(use_cache=False, )

        # trtllm attn metadata prepare() requires this
        self.prompt_lens = self.context_lens

        # set params that are used in wrapper.plan()
        self.kv_cache_block_offsets = None
        self.host_kv_cache_block_offsets = None
        self.block_ids_per_seq = None

    prompt_lens = torch.tensor(
        self.prompt_lens,
        dtype=torch.int,
        device='cpu',
    )
    self.prompt_lens_cpu[:self.num_seqs].copy_(prompt_lens)
    self.prompt_lens_cuda[:self.num_seqs].copy_(
        self.prompt_lens_cpu[:self.num_seqs], non_blocking=True)

    # number of tokens in the kv cache for each sequence in the batch
    cached_token_lens = torch.tensor(
        self.kv_cache_params.num_cached_tokens_per_seq,
        dtype=torch.int,
        device='cpu',
    ) if self.kv_cache_params.use_cache else None

    if self.enable_flash_mla:
        self.prepare_flash_mla()

    # number of tokens needed in the kv cache for each sequence after the next pass
    if self.helix_is_inactive_rank is not None and len(
            self.helix_is_inactive_rank):
        # If helix is inactive, attend to the previously cached tokens only.
        assert cached_token_lens is not None, "cached_token_lens should be set for helix"
        kv_lens = cached_token_lens.clone()
        helix_is_inactive_rank_cpu = torch.tensor(
            self.helix_is_inactive_rank,
            dtype=torch.bool,
            device='cpu',
        )
        # Only active ranks account for the new tokens of this step.
        active_rank = ~helix_is_inactive_rank_cpu
        kv_lens[active_rank] += self.seq_lens_kv[active_rank]
    else:
        kv_lens = cached_token_lens + self.seq_lens_kv if cached_token_lens is not None else self.seq_lens_kv

    # self.kv_lens is the valid kv cache length, while the self.kv_lens_cuda is
    # the sequence length including the cached tokens and the input tokens.
    self.kv_lens[:self.num_seqs].copy_(
        kv_lens + self.kv_cache_params.num_extra_kv_tokens)
    self.kv_lens_cuda[:self.num_seqs].copy_(
        kv_lens[:self.num_seqs].pin_memory(), non_blocking=True)

    # total kv lens for context requests and generation requests, without extra tokens
    self.host_total_kv_lens[0] = kv_lens[:self.num_contexts].sum().item()
    self.host_total_kv_lens[1] = kv_lens[self.num_contexts:self.
                                         num_seqs].sum().item()
    self.host_request_types[:self.num_contexts].fill_(0)
    self.host_request_types[self.num_contexts:self.num_seqs].fill_(1)

    # prepare for kv cache reuse/chunked context in MLA
    if self.enable_context_mla_with_cached_kv:
        self.prepare_context_mla_with_cached_kv(cached_token_lens, kv_lens)

    # kv block offsets
    assert self.request_ids is not None
    if self.kv_cache_manager is not None:
        # Copy blocks for all context requests
        self.kv_cache_manager.impl.copy_batch_block_offsets(
            self.host_kv_cache_block_offsets,
            self.request_ids[:self.num_contexts], 1, 0)
        # Copy blocks for all generation requests
        self.kv_cache_manager.impl.copy_batch_block_offsets(
            self.host_kv_cache_block_offsets,
            self.request_ids[self.num_contexts:], self.beam_width,
            self.num_contexts)
        self.kv_cache_block_offsets[:, :self.num_seqs].copy_(
            self.host_kv_cache_block_offsets[:, :self.num_seqs],
            non_blocking=True)

        # Reduce once, and let `assert` build the message only on failure so
        # the passing path does not pay for a second reduction and a tensor
        # repr on every call.
        max_kv_len = self.kv_lens[:self.num_seqs].max()
        max_supported_len = self.kv_cache_manager.max_seq_len
        assert max_kv_len <= max_supported_len, (
            f"The max KV cache length of input sequences ({max_kv_len}) "
            f"exceeds the KV cache manager's maximum supported length "
            f"({max_supported_len}).")

    self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
    # Don't use self.kv_lens here because it includes extra tokens.
    # Use actual KV length (without extra tokens) for kv_lens_runtime,
    # which becomes host_past_key_value_lengths and eventually mMaxSeqLenKv.
    self.kv_lens_runtime = kv_lens[:self.num_seqs]
    self.prompt_lens_cuda_runtime = self.prompt_lens_cuda[:self.num_seqs]
    self.prompt_lens_cpu_runtime = self.prompt_lens_cpu[:self.num_seqs]
    self.host_request_types_runtime = self.host_request_types[:self.
                                                              num_seqs]
def prepare_flash_mla(self) -> None:
    """
    Refresh the two block-id tables from the KV cache manager.

    kv_block_ids_per_seq receives the block ids of the first num_seqs
    sequences; block_ids_per_seq receives only the rows from index
    num_contexts onward, written into its first num_generations rows.
    Both tables are zeroed before the copy.
    """
    pinned_block_ids = self.kv_cache_manager.get_block_ids_per_seq(
        self.request_ids).pin_memory()
    block_cols = pinned_block_ids.shape[1]

    # Table covering every sequence in the batch.
    all_seq_table = self.kv_block_ids_per_seq
    all_seq_table.fill_(0)
    all_seq_table[:self.num_seqs, :block_cols].copy_(pinned_block_ids,
                                                     non_blocking=True)

    # Table restricted to the rows that follow the context requests.
    trailing_table = self.block_ids_per_seq
    trailing_table.fill_(0)
    trailing_rows = pinned_block_ids[self.num_contexts:]
    trailing_table[:self.num_generations, :block_cols].copy_(
        trailing_rows, non_blocking=True)
def pre_process_for_chunked_prefill(
self,
chunked_seq_len: torch.Tensor,
chunked_global_offset: torch.
Tensor, # [chunked_loop_num + 1, num_contexts]
cu_chunked_seq_len: torch.Tensor,
merge_op_tensor: torch.Tensor,
max_chunk_len_per_loop: list[int],
chunked_loop_num: int,
) -> None:
"""
Pre-process the MLA layer for chunked prefill.
This method is called before the forward pass to prepare the MLA layer for chunked prefill.
"""
num_contexts = self.num_contexts
chunk_size = self.runtime_features.chunk_size
chunk_batch_size = self.runtime_features.chunked_prefill_buffer_batch_size
total_chunk_size = chunk_size * chunk_batch_size
remain_buffer_len = total_chunk_size
current_batch_idx = 0
max_chunk_len_per_loop.clear()
max_chunk_len = 0
# cal chunked_seq_len
for batch_idx in range(num_contexts):
cached_kv_len = self.kv_cache_params.num_cached_tokens_per_seq[
batch_idx]
while cached_kv_len > 0:
used_buffer_len = min(remain_buffer_len, cached_kv_len)
chunked_seq_len[current_batch_idx, batch_idx] = used_buffer_len
max_chunk_len = max(max_chunk_len, used_buffer_len)
remain_buffer_len -= used_buffer_len
cached_kv_len -= used_buffer_len
chunked_global_offset[
current_batch_idx + 1, batch_idx] = chunked_global_offset[
current_batch_idx,
batch_idx] + chunked_seq_len[current_batch_idx,
batch_idx]
if remain_buffer_len == 0:
current_batch_idx += 1
remain_buffer_len = total_chunk_size
max_chunk_len_per_loop.append(max_chunk_len)
max_chunk_len = 0
if len(max_chunk_len_per_loop) < chunked_loop_num:
max_chunk_len_per_loop.append(max_chunk_len)
assert len(
max_chunk_len_per_loop
) == chunked_loop_num, f"max_chunk_len_per_loop size {len(max_chunk_len_per_loop)} != chunked_loop_num {chunked_loop_num}"
for loop_idx in range(chunked_loop_num):
cu_chunked_seq_len[loop_idx, 0] = 0
torch.cumsum(chunked_seq_len[loop_idx, :num_contexts],
dim=0,
dtype=torch.int64,
out=cu_chunked_seq_len[loop_idx, 1:num_contexts + 1])
for s in range(num_contexts):
if chunked_seq_len[loop_idx, s] > 0 and (