From 09b012e03cca163a86c963bd8ff3809e36610af2 Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Sun, 13 Jul 2025 23:22:08 +0000 Subject: [PATCH 01/38] [Draft] DeepGEMM Blackwell integration Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- examples/llm-api/quickstart_advanced.py | 2 +- requirements.txt | 1 + tensorrt_llm/_torch/modules/attention.py | 55 +-- .../_torch/modules/fused_moe/create_moe.py | 17 + .../modules/fused_moe/fused_moe_deepgemm.py | 320 ++++++++++++++++++ tensorrt_llm/_torch/modules/linear.py | 29 +- tensorrt_llm/quantization/utils/__init__.py | 4 +- tensorrt_llm/quantization/utils/fp8_utils.py | 75 ++++ tests/unittest/_torch/helpers.py | 35 ++ .../unittest/_torch/modules/test_fused_moe.py | 173 +++++++++- .../_torch/thop/test_fp8_block_scale_gemm.py | 40 ++- 11 files changed, 717 insertions(+), 34 deletions(-) create mode 100644 tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py create mode 100644 tensorrt_llm/quantization/utils/fp8_utils.py diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 5e447e6a0e4..a6397b6711b 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -50,7 +50,7 @@ def add_llm_args(parser): parser.add_argument('--moe_backend', type=str, default='CUTLASS', - choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP']) + choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', 'DEEPGEMM']) parser.add_argument('--enable_attention_dp', default=False, action='store_true') diff --git a/requirements.txt b/requirements.txt index 16c1e4b5f8c..4c8eee09e3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,3 +61,4 @@ etcd3 blake3 llguidance==0.7.29 soundfile +deep_gemm @ git+https://github.com/RayWang96/DeepGEMM.git@multi_arch_support diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 0f2a191a9c0..423f82cec3c 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -365,30 +365,37 @@ def fp8_block_scaling_bmm_out( torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8, mat1_scale, mat2_scale, out) elif sm_version == 100: - low_latency = True - use_deep_seek_fp8 = True - tile_size = 8 - epilogue_tile_m = 64 if use_deep_seek_fp8 else 128 - m_size = mat1.shape[0] - if m_size % tile_size != 0: - tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size - mat1 = torch.nn.functional.pad( - mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0) - - mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102( - mat1) - output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen( - mat1_fp8, - mat2_fp8, - tile_size=tile_size, - epilogue_tile_m=epilogue_tile_m, - use_deep_seek_fp8=use_deep_seek_fp8, - low_latency=low_latency, - dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1), - dq_sfs_b=mat2_scale, - out_dtype=out.dtype, - ) - out.copy_(output[:, :m_size]) + from ..models.modeling_deepseekv3 import weight_dequant + mat2 = weight_dequant( + mat2_fp8.view(-1, mat2_fp8.shape[-1]), + mat2_scale.view(-1, mat2_scale.shape[-1])).view(*mat2_fp8.shape) + output = torch.einsum("mbk,bnk->bmn", mat1, mat2.to(mat1.dtype)) + out.copy_(output) + + # low_latency = True + # use_deep_seek_fp8 = True + # tile_size = 8 + # epilogue_tile_m = 64 if 
use_deep_seek_fp8 else 128 + # m_size = mat1.shape[0] + # if m_size % tile_size != 0: + # tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size + # mat1 = torch.nn.functional.pad( + # mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0) + + # mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102( + # mat1) + # output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen( + # mat1_fp8, + # mat2_fp8, + # tile_size=tile_size, + # epilogue_tile_m=epilogue_tile_m, + # use_deep_seek_fp8=use_deep_seek_fp8, + # low_latency=low_latency, + # dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1), + # dq_sfs_b=mat2_scale, + # out_dtype=out.dtype, + # ) + # out.copy_(output[:, :m_size]) else: raise NotImplementedError(f"SM{sm_version} is not supported") diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 17f7e436b17..0b47e18f60d 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -8,6 +8,7 @@ from ...model_config import ModelConfig from .fused_moe_cute_dsl import CuteDslFusedMoE from .fused_moe_cutlass import CutlassFusedMoE +from .fused_moe_deepgemm import DeepGemmFusedMoE from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE from .fused_moe_vanilla import VanillaMoE from .fused_moe_wide_ep import WideEPMoE @@ -31,6 +32,8 @@ def get_moe_cls( return VanillaMoE elif moe_backend.upper() == "CUTEDSL": return CuteDslFusedMoE + elif moe_backend.upper() == "DEEPGEMM": + return DeepGemmFusedMoE elif moe_backend.upper() == "TRTLLM": if quant_config is not None and ( quant_config.quant_mode.has_fp8_block_scales() @@ -139,5 +142,19 @@ def create_moe( apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, ) + elif moe_cls == DeepGemmFusedMoE: + return moe_cls( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + aux_stream=aux_stream, + weight_loading_mode=weight_loading_mode, + apply_router_weight_on_input=apply_router_weight_on_input, + layer_idx=layer_idx, + ) else: raise ValueError(f"Unsupported moe backend: {moe_cls}") diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py new file mode 100644 index 00000000000..22ff4545243 --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -0,0 +1,320 @@ +from typing import List, Optional, Union + +import deep_gemm +import torch +import torch.nn.functional as F + +import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils +from tensorrt_llm._utils import nvtx_range + +from ...model_config import ModelConfig +from ...utils import Fp4QuantizedTensor +from .fused_moe_cutlass import CutlassFusedMoE +from .quantization import MoEWeightLoadingMode +from .routing import BaseMoeRoutingMethod + + +@nvtx_range("[DG] act") +def swiglu_fused_moe(x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +@nvtx_range("[DG]") +def deepgemm_fp8_group_blockwise_gemm_ref( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + m_indices: torch.Tensor, +) -> torch.Tensor: + + # m, k = a.shape + # num_groups, n, _ = b.shape + + # m_padded = (m + 127) // 128 * 128 + torch.cuda.synchronize() + d = torch.empty((a.shape[0], b.shape[1]), + device=b.device, + dtype=torch.bfloat16) + # 
m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32) + # for idx in range(offset_array.numel() - 1): + # m_indices[offset_array[idx]:offset_array[idx + 1]] = idx + + # for g in range(num_groups): + # aa = a[offset_array[g]:offset_array[g + 1], :].to(torch.bfloat16) + # aa_sf = a_sf[offset_array[g]:offset_array[g + 1], :] + # aa_dq = aa * aa_sf.repeat_interleave(128, dim=1)[:aa.shape[0], :aa.shape[1]] + # bb = b[g, :, :].to(torch.bfloat16) + # bb_sf = b_sf[g, :, :] + # bb_dq = bb * bb_sf.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)[:bb.shape[0], :bb.shape[1]] + # if aa_dq.numel() == 0: + # continue + # d[offset_array[g]:offset_array[g + 1], :] = (aa_dq @ bb_dq.t()) + deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a, a_sf), (b, b_sf), d, + m_indices) + torch.cuda.synchronize() + return d + + +class DeepGemmFusedMoE(CutlassFusedMoE): + """ + Python Flow of Fused Mixture of Experts (MoE) Layer. + + Args: + num_experts (int): Number of experts in the MoE layer. + top_k (int): Number of top experts to select for each input token. + hidden_size (int): Size of the hidden state. + intermediate_size (int): Size of the intermediate state. + aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks. + dtype (Optional[torch.dtype]): Data type for the weights. + reduce_results (bool): Whether to reduce the results across devices. + model_config (ModelConfig): Configuration object for the model. + + This backend is composed of multiple custom ops: + 1. moe_permute_op: permute the input tensor and the expert selected tensor. + 2. cute_dsl_fp8_group_blockwise_gemm_ref: a reference implementation of the cute_dsl_fp8_group_blockwise_gemm. + 3. moe_finalize_scale_op: finalize the scale of the output tensor. + """ + + def __init__( + self, + *, + routing_method: BaseMoeRoutingMethod, + num_experts: int, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + model_config: ModelConfig = ModelConfig(), + aux_stream: Optional[torch.cuda.Stream] = None, + weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. 
+ VANILLA, + apply_router_weight_on_input: bool = False, + layer_idx: Optional[int] = None, + ): + + super().__init__( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + aux_stream=aux_stream, + weight_loading_mode=weight_loading_mode, + apply_router_weight_on_input=apply_router_weight_on_input, + layer_idx=layer_idx, + ) + + @nvtx_range("[DG] forward") + def forward_chunk( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + ) -> torch.Tensor: + if isinstance(x, Fp4QuantizedTensor): + assert output_dtype is not None + output_dtype = output_dtype + else: + output_dtype = x.dtype + + # apply routing + token_selected_experts, token_final_scales = self.routing_method.apply( + router_logits) + assert token_selected_experts.shape[ + 1] == self.routing_method.experts_per_token + assert token_selected_experts.shape == token_final_scales.shape + assert token_selected_experts.shape[0] == router_logits.shape[0] + assert token_final_scales.dtype == torch.float32 + assert token_selected_experts.dtype == torch.int32 + + if self.apply_router_weight_on_input: + assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing" + assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input" + x = x * token_final_scales.to(x.dtype) + # TODO: remove this once we have correct fusedmoe kernel ready + token_final_scales = None + + # quantize inputs + use_deepseek_fp8_block_scale = False + x_sf = None + if self.has_any_quant: + if self.has_deepseek_fp8_block_scales: + use_deepseek_fp8_block_scale = True + else: + raise ValueError( + f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}" + ) + + ( + permuted_row_to_unpermuted_row_tensor, + permuted_token_selected_experts_tensor, + permuted_data_tensor, + expert_first_token_offset_tensor, + permuted_token_final_scales_tensor, + unpermuted_row_to_permuted_row_tensor, + ) = torch.ops.trtllm.moe_permute_op( + x, + token_selected_experts, + token_final_scales, + None, # w3_w1_weight.view(weight_dtype), + None, # w2_weight.view(weight_dtype), + None, # quant_scales, + input_sf=x_sf, + num_experts_on_rank=self.expert_size_per_partition, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + cluster_size=self.cluster_size, + cluster_rank=self.cluster_rank, + min_latency_mode=False, + use_fp8_block_scaling=use_deepseek_fp8_block_scale, + ) + + experts = torch.arange(self.ep_rank * self.expert_size_per_partition, + (self.ep_rank + 1) * + self.expert_size_per_partition, + device=x.device).view(-1, 1, 1) + matches = (token_selected_experts == experts).cpu() + token_per_expert = matches.sum(dim=[-1, -2]).flatten() + token_per_expert_padded = (token_per_expert + 127) // 128 * 128 + token_per_expert_offset_padded = torch.cat( + (torch.tensor([0], dtype=torch.int32), + torch.cumsum(token_per_expert_padded, dim=0))) + + permuted_data_tensor = torch.empty(token_per_expert_padded.sum(), + x.shape[1], + dtype=x.dtype, + device=x.device) + m_indices = torch.empty(permuted_data_tensor.shape[0], + dtype=torch.int32) + token_map = torch.zeros(permuted_data_tensor.shape[0], + dtype=torch.int32) + m = matches.nonzero() 
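+        # matches has shape [num_local_experts, seq_len, top_k]; nonzero() yields
+        # rows of (local_expert, token_idx, top_k_slot) in expert-major order, so
+        # m[:, 1] gathers the source tokens and m[:, 2] picks the matching routing
+        # scale below. m_indices labels every row of the padded buffer with its
+        # local expert id for the grouped GEMM, and token_map marks the rows that
+        # hold real tokens rather than padding.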
+ m_indices = torch.cat([ + torch.full((l, ), i, dtype=torch.int32) + for i, l in enumerate(token_per_expert_padded) + ]) + for idx in range(experts.numel()): + token_map[token_per_expert_offset_padded[idx]: + token_per_expert_offset_padded[idx] + + token_per_expert[idx]] = 1 + permuted_data_tensor[token_map > 0, :] = x[m[:, 1], :] + + # token_final_scales_padded = [] + # token_map = [] + # expert_first_token_offset_tensor = torch.zeros( + # self.expert_size_per_partition + 1, dtype=torch.int32) + + # t_idx = 0 + # accum_t_idx = 0 + # for e_idx in range(self.ep_rank * self.expert_size_per_partition, (self.ep_rank + 1) * self.expert_size_per_partition): + # for idx, token in enumerate(x): + # if e_idx in token_selected_experts[idx]: + # token_final_scales_padded.append( + # token_final_scales[idx][torch.where( + # token_selected_experts[idx] == e_idx)[0].item()]) + # token_map.append(idx) + # t_idx += 1 + # ceil_t_idx = (t_idx + 127) // 128 * 128 + # for _ in range(ceil_t_idx - t_idx): + # token_final_scales_padded.append(0) + # token_map.append(-1) + # t_idx = ceil_t_idx + # accum_t_idx += idx + # expert_first_token_offset_tensor[e_idx - self.ep_rank * self.expert_size_per_partition + 1] = t_idx + # # print(self.ep_rank, x.shape, expert_first_token_offset_tensor[-1]) + # # print("-------------------") + # permuted_data_tensor = torch.zeros(expert_first_token_offset_tensor[-1], x.shape[1], dtype=x.dtype, device=x.device) + # for idx, line in enumerate(permuted_data_tensor): + # token_idx = token_map[idx] + # if token_idx >= 0: + # line.copy_(x[token_idx, :]) + # if len(permuted_data_tensor) == 0: + # # for e_idx in range(self.ep_rank * self.expert_size_per_partition, (self.ep_rank + 1) * self.expert_size_per_partition): + # # for idx, token in enumerate(x): + # # if e_idx in token_selected_experts[idx]: + # # print("Yes!") + # return torch.zeros_like(x) + # # assert False + # # permuted_data_tensor = torch.stack(permuted_data_tensor).contiguous() + # token_final_scales_padded = torch.Tensor(token_final_scales_padded).contiguous() + + # print(permuted_data_tensor.shape, token_final_scales_padded.shape) + # print(permuted_data_tensor[:, 0]) + # print(x[:, 0]) + # print(token_final_scales_padded) + # print(token_final_scales) + # print(token_selected_experts) + # print(expert_first_token_offset_tensor) + # print(token_map) + + if permuted_data_tensor.numel() == 0: + return torch.zeros_like(x) + act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0( + permuted_data_tensor) + # print(f"act_input_fp8, shape: {act_input_fp8.shape}, type: {act_input_fp8.dtype}") + # print(f"act_input_sf, shape: {act_input_sf.shape}, type: {act_input_sf.dtype}") + h1 = deepgemm_fp8_group_blockwise_gemm_ref( + a=act_input_fp8, + b=self.w3_w1_weight, + a_sf=act_input_sf, + b_sf=self.quant_scales[0], + m_indices=m_indices, + ) + h2 = swiglu_fused_moe(h1) + # print(f"h2, shape: {h2.shape}, type: {h2.dtype}") + act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0(h2) + # print(f"act_input_fp8, shape: {act_input_fp8.shape}, type: {act_input_fp8.dtype}") + # print(f"act_input_sf, shape: {act_input_sf.shape}, type: {act_input_sf.dtype}") + + h3 = deepgemm_fp8_group_blockwise_gemm_ref( + a=act_input_fp8, + b=self.w2_weight, + a_sf=act_input_sf, + b_sf=self.quant_scales[1], + m_indices=m_indices, + ) + + # print(m_indices[token_map > 0]) + # for ss in [permuted_data_tensor, h1, h2, h3]: + # print("--") + # print(ss[token_map > 0, 0]) + + # print(111, m.shape, token_final_scales[m[:, 1], m[:, 
2]].unsqueeze(1).shape, h3[token_map, :].shape) + res = (h3[token_map > 0, :] * + token_final_scales[m[:, 1], m[:, 2]].unsqueeze(1)).to(h3.dtype) + + final_hidden_states = torch.zeros_like(x) + indices = m[:, 1].unsqueeze(1).expand(-1, res.size(1)).cuda() # [N, D] + + # 使用scatter_add_进行累加 + # print(final_hidden_states.dtype, res.dtype) + # final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( + # h3, + # None, # biases + # token_final_scales, + # unpermuted_row_to_permuted_row_tensor, + # permuted_row_to_unpermuted_row_tensor, + # token_selected_experts, + # expert_first_token_offset_tensor, + # False, # enable_alltoall + # x.shape[0], # num_rows + # x.shape[1], # hidden_size + # self.routing_method.top_k, + # self.expert_size_per_partition, # num_experts_per_node + # self.tp_size, + # self.tp_rank, + # self.ep_size, + # self.ep_rank, + # ) + final_hidden_states.scatter_add_(0, indices, res) + # final_hidden_states = torch.zeros_like(x) + + return final_hidden_states diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 1ef5be24c8b..3dcfdea1191 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -19,7 +19,9 @@ from tensorrt_llm.quantization.functional import \ preprocess_weights_for_mixed_gemm from tensorrt_llm.quantization.mode import QuantAlgo +from tensorrt_llm.quantization.utils.fp8_utils import per_token_cast_to_fp8_e8m0 +from ..._utils import get_sm_version from ...models.modeling_utils import QuantConfig from ..utils import Fp4QuantizedTensor @@ -570,10 +572,20 @@ def apply(self, module: Linear, input: torch.Tensor, input = input.to(torch.bfloat16) * module.input_scale assert input.dtype == torch.bfloat16 - act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(input) + if get_sm_version() == 100: + import deep_gemm + a_tuple = per_token_cast_to_fp8_e8m0(input) + output = torch.empty((input.shape[0], module.weight.shape[0]), + device=input.device, + dtype=torch.bfloat16) + deep_gemm.fp8_gemm_nt(a_tuple, (module.weight, module.weight_scale), + output) + else: + act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( + input) - output = torch.ops.trtllm.fp8_block_scaling_gemm( - act_input_fp8, module.weight, act_input_sf, module.weight_scale) + output = torch.ops.trtllm.fp8_block_scaling_gemm( + act_input_fp8, module.weight, act_input_sf, module.weight_scale) if bias is not None: output = output + bias return output @@ -593,6 +605,9 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size, module.tp_rank, module.tp_mode).squeeze() + # if get_sm_version == 100: + # weight, weight_scale = resmooth_to_fp8_e8m0(module.weight, weight_scale) + # copy_weight(module.weight, weight) copy_weight(module.weight_scale, weight_scale) if "input_scale" in weights[0]: copy_weight(module.input_scale, weights[0]["input_scale"]) @@ -603,7 +618,6 @@ def load_weights_fused_qkv_linear(self, module: Linear, q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( module, weights) fused_weight = torch.cat((q_weight, k_weight, v_weight)) - copy_weight(module.weight, fused_weight) scale_name = self._get_scale_name(weights) q_scale = load_weight_shard(weights[0][scale_name], module.tp_size, @@ -614,6 +628,9 @@ def load_weights_fused_qkv_linear(self, module: Linear, module.tp_rank, module.tp_mode) fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze() + # if get_sm_version == 100: 
+ # fused_weight, fused_fp8_block_scale = resmooth_to_fp8_e8m0(fused_weight, fused_fp8_block_scale) + copy_weight(module.weight, fused_weight) copy_weight(module.weight_scale, fused_fp8_block_scale) def load_weights_fused_gate_up_linear(self, module: Linear, @@ -621,7 +638,6 @@ def load_weights_fused_gate_up_linear(self, module: Linear, gate_weight, up_weight = load_weights_fused_gate_up_helper( module, weights) fused_weight = torch.cat((gate_weight, up_weight)) - copy_weight(module.weight, fused_weight) scale_name = self._get_scale_name(weights) left_scale = load_weight_shard(weights[0][scale_name], module.tp_size, @@ -629,6 +645,9 @@ def load_weights_fused_gate_up_linear(self, module: Linear, right_scale = load_weight_shard(weights[1][scale_name], module.tp_size, module.tp_rank, module.tp_mode) fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze() + # if get_sm_version == 100: + # fused_weight, fused_scale = resmooth_to_fp8_e8m0(fused_weight, fused_scale) + copy_weight(module.weight, fused_weight) copy_weight(module.weight_scale, fused_scale) diff --git a/tensorrt_llm/quantization/utils/__init__.py b/tensorrt_llm/quantization/utils/__init__.py index a0cc798c42f..a79df9ebcb2 100644 --- a/tensorrt_llm/quantization/utils/__init__.py +++ b/tensorrt_llm/quantization/utils/__init__.py @@ -1,3 +1,3 @@ -from . import fp4_utils +from . import fp4_utils, fp8_utils -__all__ = ['fp4_utils'] +__all__ = ['fp4_utils', 'fp8_utils'] diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py new file mode 100644 index 00000000000..9c74a5a3a45 --- /dev/null +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -0,0 +1,75 @@ +from typing import Tuple + +import torch + + +def ceil_div(x: int, y: int) -> int: + """ + Perform ceiling division of two integers. + + Args: + x: the dividend. + y: the divisor. + + Returns: + The result of the ceiling division. 
+ """ + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_token_cast_to_fp8_e8m0( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, n), sf + + +def per_block_cast_to_fp8_e8m0( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if x.dim() == 2: + m, n = x.shape + x_padded = torch.zeros((align(m, 128), align(n, 128)), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( + x_view.size(0), x_view.size(2)) + else: + g, m, n = x.shape + x_padded = torch.zeros((g, align(m, 128), align(n, 128)), + dtype=x.dtype, + device=x.device) + x_padded[:, :m, :n] = x + x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view( + x_view.size(0), x_view.size(1), x_view.size(3)) + + +def resmooth_to_fp8_e8m0(weight: torch.Tensor, + sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if weight.dim() == 2: + x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave( + 128, dim=1)[:weight.shape[0], :weight.shape[1]] + else: + x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave( + 128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]] + return per_block_cast_to_fp8_e8m0(x) diff --git a/tests/unittest/_torch/helpers.py b/tests/unittest/_torch/helpers.py index 8f4e2459f1b..5e9f2ba1a26 100644 --- a/tests/unittest/_torch/helpers.py +++ b/tests/unittest/_torch/helpers.py @@ -8,6 +8,14 @@ def ceil_div(x: int, y: int) -> int: return (x + y - 1) // y +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 m, n = x.shape @@ -33,6 +41,33 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x_view.size(2)) +def per_token_cast_to_fp8_e8m0( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, n), sf + + +def per_block_cast_to_fp8_e8m0( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((align(m, 128), align(n, 128)), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), 
keepdim=True).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( + x_view.size(0), x_view.size(2)) + + def calc_diff(x, y): x, y = x.double(), y.double() denominator = (x * x + y * y).sum() diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 367f7300b09..32adc513249 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -9,7 +9,8 @@ import pytest import torch import torch.nn as nn -from _torch.helpers import per_block_cast_to_fp8 +from _torch.helpers import (per_block_cast_to_fp8, per_block_cast_to_fp8_e8m0, + per_token_cast_to_fp8_e8m0) from mpi4py import MPI from mpi4py.futures import MPIPoolExecutor from utils.util import (skip_neither_ada_nor_hopper_unittest, @@ -25,6 +26,8 @@ VanillaMoE, WideEPMoE) from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \ CuteDslFusedMoE +from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \ + DeepGemmFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \ AlltoallMethodType from tensorrt_llm._torch.modules.gated_mlp import GatedMLP @@ -379,6 +382,174 @@ def set_tensor_value_4(x, num_row, num_cols): x.copy_(repeated) +@skip_pre_blackwell +@pytest.mark.parametrize( + "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls", + product( + [torch.bfloat16], + [72], + [128, 256, 384, 512, 1024, 2048, 4096, 8192], + [2560], + [DefaultMoeRoutingMethod], + ), +) +def test_fused_moe_fp8_blockwise_deepgemm(dtype, + num_experts, + seq_len, + hidden_size, + RoutingMethodCls, + mapping=None): + SEQ_LEN = seq_len + HIDDEN_SIZE = hidden_size + INTERMEDIATE_SIZE = 256 + NUM_EXPERTS = num_experts + TOP_K = 2 + + routing_method = RoutingMethodCls(top_k=TOP_K) + + mapping = mapping or Mapping() + mapping.rank = mpi_rank() + torch.cuda.set_device(mapping.rank) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() + # Note: we use some special values init x and weight, otherwise the test will false positive failed. 
+ set_tensor_value_2(x, SEQ_LEN, HIDDEN_SIZE) + + x = x.cuda() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda() + + weights = {} + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() / HIDDEN_SIZE + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype).cuda() + w3_weight = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() / HIDDEN_SIZE + set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE) + set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + + w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8_e8m0(w1_weight) + w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8_e8m0(w2_weight) + w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8_e8m0(w3_weight) + w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda() + + weights[f"{expert_id}.w1.weight"] = w1_weight_fp8 + weights[f"{expert_id}.w2.weight"] = w2_weight_fp8 + weights[f"{expert_id}.w3.weight"] = w3_weight_fp8 + weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale + + quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES) + + fused_moe = DeepGemmFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(quant_config=quant_config, mapping=mapping), + ) + fused_moe.cuda() + fused_moe.load_weights([weights]) + + def swiglu_fused_moe(x): + x, gate = x.chunk(2, dim=-1) + return torch.nn.functional.silu(gate) * x + + def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, + b_sf: torch.Tensor, + offset_array: torch.Tensor) -> torch.Tensor: + d = torch.empty((a.shape[0], b.shape[1]), + device=b.device, + dtype=torch.bfloat16) + m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32) + for idx in range(offset_array.numel() - 1): + m_indices[offset_array[idx]:offset_array[idx + 1]] = idx + + num_groups, n, k_ = b.shape + d = torch.empty((a.shape[0], b.shape[1]), + device=b.device, + dtype=torch.bfloat16) + m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32) + for idx in range(offset_array.numel() - 1): + m_indices[offset_array[idx]:offset_array[idx + 1]] = idx + + for g in range(num_groups): + aa = a[offset_array[g]:offset_array[g + 1], :].to(torch.bfloat16) + aa_sf = a_sf[offset_array[g]:offset_array[g + 1], :] + aa_dq = aa * aa_sf.repeat_interleave( + 128, dim=1)[:aa.shape[0], :aa.shape[1]] + bb = b[g, :, :].to(torch.bfloat16) + bb_sf = b_sf[g, :, :] + bb_dq = bb * bb_sf.repeat_interleave(128, dim=0).repeat_interleave( + 128, dim=1)[:bb.shape[0], :bb.shape[1]] + d[offset_array[g]:offset_array[g + 1], :] = (aa_dq @ bb_dq.t()) + return d + + token_selected_experts, token_final_scales = routing_method.apply( + router_logits) + t_idx = 0 + permuted_data_tensor = torch.empty((x.shape[0] * TOP_K, x.shape[1]), + device=x.device, + dtype=torch.bfloat16) + 
expert_first_token_offset_tensor = torch.zeros(NUM_EXPERTS + 1, + dtype=torch.int32) + unpermute_map = [] + scales = [] + for e_idx in range(NUM_EXPERTS): + for idx, token in enumerate(x): + for i, selected_expert in enumerate(token_selected_experts[idx]): + if e_idx == selected_expert: + permuted_data_tensor[t_idx, :] = token + unpermute_map.append(idx) + scales.append(token_final_scales[idx, i]) + t_idx += 1 + expert_first_token_offset_tensor[e_idx + 1] = t_idx + + act_input_fp8, act_input_sf = per_token_cast_to_fp8_e8m0( + permuted_data_tensor) + h1 = grouped_gemm( + a=act_input_fp8, + b=fused_moe.w3_w1_weight, + a_sf=act_input_sf, + b_sf=fused_moe.quant_scales[0], + offset_array=expert_first_token_offset_tensor, + ) + h2 = swiglu_fused_moe(h1) + act_input_fp8, act_input_sf = per_token_cast_to_fp8_e8m0(h2) + h3 = grouped_gemm( + a=act_input_fp8, + b=fused_moe.w2_weight, + a_sf=act_input_sf, + b_sf=fused_moe.quant_scales[1], + offset_array=expert_first_token_offset_tensor, + ) + ref_output = torch.zeros_like(x) + for token_idx, h3_token in enumerate(h3): + original_idx = unpermute_map[token_idx] + ref_output[original_idx, :] += h3_token * scales[token_idx] + + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + + # compare + torch.cuda.synchronize() + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + + @skip_non_hopper_unittest @pytest.mark.parametrize( "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls", diff --git a/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py b/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py index 44662f648a2..0fb78c01f33 100644 --- a/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py +++ b/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py @@ -18,10 +18,48 @@ import pytest import torch -from _torch.helpers import calc_diff, per_block_cast_to_fp8 +from _torch.helpers import (calc_diff, per_block_cast_to_fp8, + per_block_cast_to_fp8_e8m0, + per_token_cast_to_fp8_e8m0) from utils.util import getSMVersion +@pytest.mark.skipif( + getSMVersion() != 100, + reason="The test is for Blackwell only. Current SM is %d." % getSMVersion(), +) +@pytest.mark.parametrize( + "k, n", + [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), + (2048, 7168), (1024, 1024)], +) +@pytest.mark.parametrize( + "m", + [7, 64, 128, 4096], +) +@pytest.mark.parametrize( + "dtype", + [torch.bfloat16], +) +def test_fp8_block_scale_deep_gemm(dtype, m, k, n): + torch.random.manual_seed(0) + a = torch.randn((m, k), device='cuda', dtype=dtype) + b = torch.randn((n, k), device='cuda', dtype=dtype) + + act_a_fp8, act_a_sf = per_token_cast_to_fp8_e8m0(a) + act_b_fp8, act_b_sf = per_block_cast_to_fp8_e8m0(b) + + output_expected = a @ b.t() + import deep_gemm + output = torch.empty((act_a_fp8.shape[0], act_b_fp8.shape[0]), + device=act_a_fp8.device, + dtype=torch.bfloat16) + + deep_gemm.fp8_gemm_nt((act_a_fp8, act_a_sf), (act_b_fp8, act_b_sf), output) + diff = calc_diff(output, output_expected) + assert diff < 1e-2 + + @pytest.mark.skipif( getSMVersion() != 100 and getSMVersion() != 89, reason="The test is for Blackwell and Ada only. Current SM is %d." 
% From ec400abc28018ed5d0ea7b223fbeaae807964f61 Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Sun, 13 Jul 2025 23:37:29 +0000 Subject: [PATCH 02/38] Clean up fused_moe_deepgemm.py Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 137 +----------------- 1 file changed, 2 insertions(+), 135 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 22ff4545243..c7adb73caec 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -28,29 +28,10 @@ def deepgemm_fp8_group_blockwise_gemm_ref( b_sf: torch.Tensor, m_indices: torch.Tensor, ) -> torch.Tensor: - - # m, k = a.shape - # num_groups, n, _ = b.shape - - # m_padded = (m + 127) // 128 * 128 torch.cuda.synchronize() d = torch.empty((a.shape[0], b.shape[1]), device=b.device, dtype=torch.bfloat16) - # m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32) - # for idx in range(offset_array.numel() - 1): - # m_indices[offset_array[idx]:offset_array[idx + 1]] = idx - - # for g in range(num_groups): - # aa = a[offset_array[g]:offset_array[g + 1], :].to(torch.bfloat16) - # aa_sf = a_sf[offset_array[g]:offset_array[g + 1], :] - # aa_dq = aa * aa_sf.repeat_interleave(128, dim=1)[:aa.shape[0], :aa.shape[1]] - # bb = b[g, :, :].to(torch.bfloat16) - # bb_sf = b_sf[g, :, :] - # bb_dq = bb * bb_sf.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)[:bb.shape[0], :bb.shape[1]] - # if aa_dq.numel() == 0: - # continue - # d[offset_array[g]:offset_array[g + 1], :] = (aa_dq @ bb_dq.t()) deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a, a_sf), (b, b_sf), d, m_indices) torch.cuda.synchronize() @@ -70,11 +51,6 @@ class DeepGemmFusedMoE(CutlassFusedMoE): dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. - - This backend is composed of multiple custom ops: - 1. moe_permute_op: permute the input tensor and the expert selected tensor. - 2. cute_dsl_fp8_group_blockwise_gemm_ref: a reference implementation of the cute_dsl_fp8_group_blockwise_gemm. - 3. moe_finalize_scale_op: finalize the scale of the output tensor. 
""" def __init__( @@ -141,42 +117,14 @@ def forward_chunk( token_final_scales = None # quantize inputs - use_deepseek_fp8_block_scale = False - x_sf = None if self.has_any_quant: if self.has_deepseek_fp8_block_scales: - use_deepseek_fp8_block_scale = True + pass else: raise ValueError( - f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}" + f"unsupported quantization mode for DEEPGEMM backend: {self.quant_config.quant_mode}" ) - ( - permuted_row_to_unpermuted_row_tensor, - permuted_token_selected_experts_tensor, - permuted_data_tensor, - expert_first_token_offset_tensor, - permuted_token_final_scales_tensor, - unpermuted_row_to_permuted_row_tensor, - ) = torch.ops.trtllm.moe_permute_op( - x, - token_selected_experts, - token_final_scales, - None, # w3_w1_weight.view(weight_dtype), - None, # w2_weight.view(weight_dtype), - None, # quant_scales, - input_sf=x_sf, - num_experts_on_rank=self.expert_size_per_partition, - tp_size=self.tp_size, - tp_rank=self.tp_rank, - ep_size=self.ep_size, - ep_rank=self.ep_rank, - cluster_size=self.cluster_size, - cluster_rank=self.cluster_rank, - min_latency_mode=False, - use_fp8_block_scaling=use_deepseek_fp8_block_scale, - ) - experts = torch.arange(self.ep_rank * self.expert_size_per_partition, (self.ep_rank + 1) * self.expert_size_per_partition, @@ -207,60 +155,10 @@ def forward_chunk( token_per_expert[idx]] = 1 permuted_data_tensor[token_map > 0, :] = x[m[:, 1], :] - # token_final_scales_padded = [] - # token_map = [] - # expert_first_token_offset_tensor = torch.zeros( - # self.expert_size_per_partition + 1, dtype=torch.int32) - - # t_idx = 0 - # accum_t_idx = 0 - # for e_idx in range(self.ep_rank * self.expert_size_per_partition, (self.ep_rank + 1) * self.expert_size_per_partition): - # for idx, token in enumerate(x): - # if e_idx in token_selected_experts[idx]: - # token_final_scales_padded.append( - # token_final_scales[idx][torch.where( - # token_selected_experts[idx] == e_idx)[0].item()]) - # token_map.append(idx) - # t_idx += 1 - # ceil_t_idx = (t_idx + 127) // 128 * 128 - # for _ in range(ceil_t_idx - t_idx): - # token_final_scales_padded.append(0) - # token_map.append(-1) - # t_idx = ceil_t_idx - # accum_t_idx += idx - # expert_first_token_offset_tensor[e_idx - self.ep_rank * self.expert_size_per_partition + 1] = t_idx - # # print(self.ep_rank, x.shape, expert_first_token_offset_tensor[-1]) - # # print("-------------------") - # permuted_data_tensor = torch.zeros(expert_first_token_offset_tensor[-1], x.shape[1], dtype=x.dtype, device=x.device) - # for idx, line in enumerate(permuted_data_tensor): - # token_idx = token_map[idx] - # if token_idx >= 0: - # line.copy_(x[token_idx, :]) - # if len(permuted_data_tensor) == 0: - # # for e_idx in range(self.ep_rank * self.expert_size_per_partition, (self.ep_rank + 1) * self.expert_size_per_partition): - # # for idx, token in enumerate(x): - # # if e_idx in token_selected_experts[idx]: - # # print("Yes!") - # return torch.zeros_like(x) - # # assert False - # # permuted_data_tensor = torch.stack(permuted_data_tensor).contiguous() - # token_final_scales_padded = torch.Tensor(token_final_scales_padded).contiguous() - - # print(permuted_data_tensor.shape, token_final_scales_padded.shape) - # print(permuted_data_tensor[:, 0]) - # print(x[:, 0]) - # print(token_final_scales_padded) - # print(token_final_scales) - # print(token_selected_experts) - # print(expert_first_token_offset_tensor) - # print(token_map) - if permuted_data_tensor.numel() == 0: return torch.zeros_like(x) 
act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0( permuted_data_tensor) - # print(f"act_input_fp8, shape: {act_input_fp8.shape}, type: {act_input_fp8.dtype}") - # print(f"act_input_sf, shape: {act_input_sf.shape}, type: {act_input_sf.dtype}") h1 = deepgemm_fp8_group_blockwise_gemm_ref( a=act_input_fp8, b=self.w3_w1_weight, @@ -269,10 +167,7 @@ def forward_chunk( m_indices=m_indices, ) h2 = swiglu_fused_moe(h1) - # print(f"h2, shape: {h2.shape}, type: {h2.dtype}") act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0(h2) - # print(f"act_input_fp8, shape: {act_input_fp8.shape}, type: {act_input_fp8.dtype}") - # print(f"act_input_sf, shape: {act_input_sf.shape}, type: {act_input_sf.dtype}") h3 = deepgemm_fp8_group_blockwise_gemm_ref( a=act_input_fp8, @@ -282,39 +177,11 @@ def forward_chunk( m_indices=m_indices, ) - # print(m_indices[token_map > 0]) - # for ss in [permuted_data_tensor, h1, h2, h3]: - # print("--") - # print(ss[token_map > 0, 0]) - - # print(111, m.shape, token_final_scales[m[:, 1], m[:, 2]].unsqueeze(1).shape, h3[token_map, :].shape) res = (h3[token_map > 0, :] * token_final_scales[m[:, 1], m[:, 2]].unsqueeze(1)).to(h3.dtype) final_hidden_states = torch.zeros_like(x) indices = m[:, 1].unsqueeze(1).expand(-1, res.size(1)).cuda() # [N, D] - - # 使用scatter_add_进行累加 - # print(final_hidden_states.dtype, res.dtype) - # final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( - # h3, - # None, # biases - # token_final_scales, - # unpermuted_row_to_permuted_row_tensor, - # permuted_row_to_unpermuted_row_tensor, - # token_selected_experts, - # expert_first_token_offset_tensor, - # False, # enable_alltoall - # x.shape[0], # num_rows - # x.shape[1], # hidden_size - # self.routing_method.top_k, - # self.expert_size_per_partition, # num_experts_per_node - # self.tp_size, - # self.tp_rank, - # self.ep_size, - # self.ep_rank, - # ) final_hidden_states.scatter_add_(0, indices, res) - # final_hidden_states = torch.zeros_like(x) return final_hidden_states From d9a85ac2d828a78b51a40117eea9794fbba9fe6d Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Tue, 15 Jul 2025 06:00:31 +0000 Subject: [PATCH 03/38] Moving permute space allocation to GPU Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/modules/fused_moe/fused_moe_deepgemm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index c7adb73caec..25d8cf99e10 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -129,11 +129,11 @@ def forward_chunk( (self.ep_rank + 1) * self.expert_size_per_partition, device=x.device).view(-1, 1, 1) - matches = (token_selected_experts == experts).cpu() + matches = (token_selected_experts == experts) token_per_expert = matches.sum(dim=[-1, -2]).flatten() token_per_expert_padded = (token_per_expert + 127) // 128 * 128 token_per_expert_offset_padded = torch.cat( - (torch.tensor([0], dtype=torch.int32), + (torch.zeros(1, dtype=torch.int32, device=x.device), torch.cumsum(token_per_expert_padded, dim=0))) permuted_data_tensor = torch.empty(token_per_expert_padded.sum(), @@ -141,9 +141,11 @@ def forward_chunk( dtype=x.dtype, device=x.device) m_indices = 
torch.empty(permuted_data_tensor.shape[0], - dtype=torch.int32) + dtype=torch.int32, + device=x.device) token_map = torch.zeros(permuted_data_tensor.shape[0], - dtype=torch.int32) + dtype=torch.int32, + device=x.device) m = matches.nonzero() m_indices = torch.cat([ torch.full((l, ), i, dtype=torch.int32) From 7c4045c33014460fdd531d266fdade693413922d Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Tue, 15 Jul 2025 17:10:12 -0700 Subject: [PATCH 04/38] optimize padding in deepgemm moe. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 164 +++++++++++++----- 1 file changed, 121 insertions(+), 43 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 25d8cf99e10..292a7103334 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -15,11 +15,64 @@ @nvtx_range("[DG] act") +@torch.compile(dynamic=True) def swiglu_fused_moe(x): x, gate = x.chunk(2, dim=-1) return F.silu(gate) * x +@nvtx_range("[DG] indexing") +@torch.compile(dynamic=True) +def indexing(x, mask): + return x[mask > 0, :].contiguous() + + +@nvtx_range("[DG] copy after permute") +@torch.compile(dynamic=True) +def copy_after( + expert_first_token_offset_tensor, + permuted_data_tensor, + base_indices, + hidden_size, +): + token_per_expert = expert_first_token_offset_tensor[ + 1:] - expert_first_token_offset_tensor[:-1] + token_per_expert_padded = (token_per_expert + 127) // 128 * 128 + expert_first_token_offset_tensor_padded = torch.cat( + (torch.zeros(1, dtype=torch.int32, + device='cuda'), torch.cumsum(token_per_expert_padded, + dim=0))) + + token_num = token_per_expert.sum() + total_tokens_padded = token_per_expert_padded.sum() + m_indices = torch.repeat_interleave(base_indices, + token_per_expert_padded, + dim=0, + output_size=total_tokens_padded) + src_offsets = torch.repeat_interleave(expert_first_token_offset_tensor[:-1], + token_per_expert, + dim=0, + output_size=token_num) + dest_starts = torch.repeat_interleave( + expert_first_token_offset_tensor_padded[:-1], + token_per_expert, + dim=0, + output_size=token_num) + token_j_offset_in_expert = torch.arange(token_num, + device='cuda') - src_offsets + dest_indices = dest_starts + token_j_offset_in_expert + + permuted_data_tensor_padded = torch.empty(total_tokens_padded, + hidden_size, + dtype=permuted_data_tensor.dtype, + device='cuda') + src_indices = torch.arange(dest_indices.shape[0], device='cuda') + permuted_data_tensor_padded.index_copy_(0, dest_indices, + permuted_data_tensor[src_indices]) + + return permuted_data_tensor_padded, m_indices, dest_indices + + @nvtx_range("[DG]") def deepgemm_fp8_group_blockwise_gemm_ref( a: torch.Tensor, @@ -28,6 +81,7 @@ def deepgemm_fp8_group_blockwise_gemm_ref( b_sf: torch.Tensor, m_indices: torch.Tensor, ) -> torch.Tensor: + torch.cuda.synchronize() d = torch.empty((a.shape[0], b.shape[1]), device=b.device, @@ -51,6 +105,11 @@ class DeepGemmFusedMoE(CutlassFusedMoE): dtype (Optional[torch.dtype]): Data type for the weights. reduce_results (bool): Whether to reduce the results across devices. model_config (ModelConfig): Configuration object for the model. + + This backend is composed of multiple custom ops: + 1. moe_permute_op: permute the input tensor and the expert selected tensor. + 2. 
cute_dsl_fp8_group_blockwise_gemm_ref: a reference implementation of the cute_dsl_fp8_group_blockwise_gemm. + 3. moe_finalize_scale_op: finalize the scale of the output tensor. """ def __init__( @@ -84,6 +143,10 @@ def __init__( layer_idx=layer_idx, ) + self.base_indices = torch.arange(self.expert_size_per_partition, + device="cuda", + dtype=torch.int32) + @nvtx_range("[DG] forward") def forward_chunk( self, @@ -117,50 +180,53 @@ def forward_chunk( token_final_scales = None # quantize inputs + use_deepseek_fp8_block_scale = False + x_sf = None if self.has_any_quant: if self.has_deepseek_fp8_block_scales: - pass + use_deepseek_fp8_block_scale = True else: raise ValueError( - f"unsupported quantization mode for DEEPGEMM backend: {self.quant_config.quant_mode}" + f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}" ) - experts = torch.arange(self.ep_rank * self.expert_size_per_partition, - (self.ep_rank + 1) * - self.expert_size_per_partition, - device=x.device).view(-1, 1, 1) - matches = (token_selected_experts == experts) - token_per_expert = matches.sum(dim=[-1, -2]).flatten() - token_per_expert_padded = (token_per_expert + 127) // 128 * 128 - token_per_expert_offset_padded = torch.cat( - (torch.zeros(1, dtype=torch.int32, device=x.device), - torch.cumsum(token_per_expert_padded, dim=0))) - - permuted_data_tensor = torch.empty(token_per_expert_padded.sum(), - x.shape[1], - dtype=x.dtype, - device=x.device) - m_indices = torch.empty(permuted_data_tensor.shape[0], - dtype=torch.int32, - device=x.device) - token_map = torch.zeros(permuted_data_tensor.shape[0], - dtype=torch.int32, - device=x.device) - m = matches.nonzero() - m_indices = torch.cat([ - torch.full((l, ), i, dtype=torch.int32) - for i, l in enumerate(token_per_expert_padded) - ]) - for idx in range(experts.numel()): - token_map[token_per_expert_offset_padded[idx]: - token_per_expert_offset_padded[idx] + - token_per_expert[idx]] = 1 - permuted_data_tensor[token_map > 0, :] = x[m[:, 1], :] - - if permuted_data_tensor.numel() == 0: + ( + permuted_row_to_unpermuted_row_tensor, + permuted_token_selected_experts_tensor, + permuted_data_tensor, + expert_first_token_offset_tensor, + permuted_token_final_scales_tensor, + unpermuted_row_to_permuted_row_tensor, + ) = torch.ops.trtllm.moe_permute_op( + x, + token_selected_experts, + token_final_scales, + None, # w3_w1_weight.view(weight_dtype), + None, # w2_weight.view(weight_dtype), + None, # quant_scales, + input_sf=x_sf, + num_experts_on_rank=self.expert_size_per_partition, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + cluster_size=self.cluster_size, + cluster_rank=self.cluster_rank, + min_latency_mode=False, + use_fp8_block_scaling=use_deepseek_fp8_block_scale, + ) + + permuted_data_tensor_padded, m_indices, dest_indices = copy_after( + expert_first_token_offset_tensor, + permuted_data_tensor, + self.base_indices, + self.hidden_size, + ) + + if permuted_data_tensor_padded.numel() == 0: return torch.zeros_like(x) act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0( - permuted_data_tensor) + permuted_data_tensor_padded) h1 = deepgemm_fp8_group_blockwise_gemm_ref( a=act_input_fp8, b=self.w3_w1_weight, @@ -170,7 +236,6 @@ def forward_chunk( ) h2 = swiglu_fused_moe(h1) act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0(h2) - h3 = deepgemm_fp8_group_blockwise_gemm_ref( a=act_input_fp8, b=self.w2_weight, @@ -179,11 +244,24 @@ def forward_chunk( m_indices=m_indices, ) - res = 
(h3[token_map > 0, :] * - token_final_scales[m[:, 1], m[:, 2]].unsqueeze(1)).to(h3.dtype) - - final_hidden_states = torch.zeros_like(x) - indices = m[:, 1].unsqueeze(1).expand(-1, res.size(1)).cuda() # [N, D] - final_hidden_states.scatter_add_(0, indices, res) + permuted_data_tensor[0:dest_indices.shape[0]].copy_(h3[dest_indices]) + final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( + permuted_data_tensor, + None, # biases + token_final_scales, + unpermuted_row_to_permuted_row_tensor, + permuted_row_to_unpermuted_row_tensor, + token_selected_experts, + expert_first_token_offset_tensor, + False, # enable_alltoall + x.shape[0], # num_rows + x.shape[1], # hidden_size + self.routing_method.top_k, + self.expert_size_per_partition, # num_experts_per_node + self.tp_size, + self.tp_rank, + self.ep_size, + self.ep_rank, + ) return final_hidden_states From 20b25928699300faff4179ea281f3eefd797369f Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Tue, 15 Jul 2025 23:17:40 -0700 Subject: [PATCH 05/38] add torch compile to per_token_cast_to_fp8_e8m0 and rm the two sync. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py | 2 -- tensorrt_llm/quantization/utils/fp8_utils.py | 4 ++++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 292a7103334..e3bb244eab5 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -82,13 +82,11 @@ def deepgemm_fp8_group_blockwise_gemm_ref( m_indices: torch.Tensor, ) -> torch.Tensor: - torch.cuda.synchronize() d = torch.empty((a.shape[0], b.shape[1]), device=b.device, dtype=torch.bfloat16) deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a, a_sf), (b, b_sf), d, m_indices) - torch.cuda.synchronize() return d diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 9c74a5a3a45..7359b8dd0dd 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -2,6 +2,8 @@ import torch +from tensorrt_llm._utils import nvtx_range + def ceil_div(x: int, y: int) -> int: """ @@ -25,6 +27,8 @@ def ceil_to_ue8m0(x: torch.Tensor): return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) +@nvtx_range("[DG] quantization") +@torch.compile(dynamic=True) def per_token_cast_to_fp8_e8m0( x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 From c74a31a187c774114128bffc771046ca02c48d33 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Wed, 16 Jul 2025 06:22:49 +0000 Subject: [PATCH 06/38] Improve bmm. 
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- examples/llm-api/quickstart_advanced.py | 9 ++-- .../_torch/models/modeling_deepseekv3.py | 25 ++++++++++- tensorrt_llm/_torch/modules/attention.py | 45 ++++++++++++++----- 3 files changed, 63 insertions(+), 16 deletions(-) diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index a6397b6711b..0aaf9e5200f 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -47,10 +47,11 @@ def add_llm_args(parser): 'VANILLA', 'TRTLLM', 'FLASHINFER', 'FLASHINFER_STAR_ATTENTION' ]) - parser.add_argument('--moe_backend', - type=str, - default='CUTLASS', - choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', 'DEEPGEMM']) + parser.add_argument( + '--moe_backend', + type=str, + default='CUTLASS', + choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', 'DEEPGEMM', 'CUTEDSL']) parser.add_argument('--enable_attention_dp', default=False, action='store_true') diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index c8523deea2e..b9fcf95935b 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -1197,7 +1197,7 @@ def load_kv_b_proj_and_k_b_proj_trans(module_name: str, weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size local_num_heads = num_heads // weight_divisor - k_nope_weight_trans = k_nope_weight.transpose(2, 1) + k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous() kv_b_proj = torch.concat([ k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim, @@ -1243,7 +1243,7 @@ def load_kv_b_proj_and_k_b_proj_trans_dequant( weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size local_num_heads = num_heads // weight_divisor - k_nope_weight_trans = k_nope_weight.transpose(2, 1) + k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous() kv_b_proj = torch.concat([ k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim, @@ -1337,6 +1337,27 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, attn_module.v_b_proj_scale = nn.Parameter( v_b_proj_scale, requires_grad=False) + if attn_module.k_b_proj_trans_dequant is not None: + attn_module.k_b_proj_trans_dequant.data.copy_( + weight_dequant( + k_b_proj_trans.view( + -1, k_b_proj_trans.shape[-1]).cuda(), + k_b_proj_trans_scale.view( + -1, + k_b_proj_trans_scale.shape[-1]).cuda(), + ).view( + *attn_module.k_b_proj_trans_dequant.shape). 
+ to(attn_module.k_b_proj_trans_dequant.dtype)) + if attn_module.v_b_proj_dequant is not None: + attn_module.v_b_proj_dequant.data.copy_( + weight_dequant( + v_b_proj.view(-1, + v_b_proj.shape[-1]).cuda(), + v_b_proj_scale.view( + -1, v_b_proj_scale.shape[-1]).cuda(), + ).view(*attn_module.v_b_proj_dequant.shape).to( + attn_module.v_b_proj_dequant.dtype)) + elif names[-1] == "fused_a": fused_a = weights[ f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:] diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 423f82cec3c..6820d076a84 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -357,6 +357,7 @@ def fp8_block_scaling_bmm_out( mat2_fp8: torch.Tensor, mat2_scale: torch.Tensor, out: torch.Tensor, + mat2_dequant: Optional[torch.Tensor] = None, ) -> torch.Tensor: sm_version = get_sm_version() if sm_version == 90 or sm_version == 89: @@ -365,11 +366,7 @@ def fp8_block_scaling_bmm_out( torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8, mat1_scale, mat2_scale, out) elif sm_version == 100: - from ..models.modeling_deepseekv3 import weight_dequant - mat2 = weight_dequant( - mat2_fp8.view(-1, mat2_fp8.shape[-1]), - mat2_scale.view(-1, mat2_scale.shape[-1])).view(*mat2_fp8.shape) - output = torch.einsum("mbk,bnk->bmn", mat1, mat2.to(mat1.dtype)) + output = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2)) out.copy_(output) # low_latency = True @@ -683,6 +680,8 @@ def create_weights(self): requires_grad=False, ) + self.k_b_proj_trans_dequant = None + self.v_b_proj_dequant = None if has_fp8_block_scales: self.k_b_proj_trans_scale = nn.Parameter( torch.empty( @@ -708,6 +707,23 @@ def create_weights(self): ), requires_grad=False, ) + if get_sm_version() == 100: + assert self.dtype == torch.bfloat16 + self.k_b_proj_trans_dequant = nn.Parameter( + torch.empty( + (self.num_heads, self.kv_lora_rank, + self.qk_nope_head_dim), + dtype=self.dtype, + ), + requires_grad=False, + ) + self.v_b_proj_dequant = nn.Parameter( + torch.empty( + (self.num_heads, self.v_head_dim, self.kv_lora_rank), + dtype=self.dtype, + ), + requires_grad=False, + ) else: self.k_b_proj_trans_scale = None self.v_b_proj_scale = None @@ -1203,8 +1219,13 @@ def forward_generation( # [num_heads, num_tokens, self.kv_lora_rank] q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1) - fp8_block_scaling_bmm_out(q_nope, self.k_b_proj_trans, - self.k_b_proj_trans_scale, q_nope_out) + fp8_block_scaling_bmm_out( + q_nope, + self.k_b_proj_trans, + self.k_b_proj_trans_scale, + q_nope_out, + self.k_b_proj_trans_dequant, + ) else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.") @@ -1253,9 +1274,13 @@ def forward_generation( self.v_b_proj.transpose(1, 2), attn_output.transpose(0, 1)) elif self.v_b_proj.dtype == torch.float8_e4m3fn: - fp8_block_scaling_bmm_out(attn_out_latent, self.v_b_proj, - self.v_b_proj_scale, - attn_output.transpose(0, 1)) + fp8_block_scaling_bmm_out( + attn_out_latent, + self.v_b_proj, + self.v_b_proj_scale, + attn_output.transpose(0, 1), + self.v_b_proj_dequant, + ) else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.") From d3e1797f1daf43f25f0d0f4aeb6edfaf152df866 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Wed, 16 Jul 2025 17:40:35 +0800 Subject: [PATCH 07/38] Online resmooth for fp8 checkpoint on Blackwell. 
(#2) Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/models/modeling_deepseekv3.py | 16 +++++++++++----- tensorrt_llm/_torch/modules/linear.py | 7 ------- tensorrt_llm/quantization/utils/fp8_utils.py | 2 ++ 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index b9fcf95935b..8939600c0ed 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -44,6 +44,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization.utils.fp8_utils import resmooth_to_fp8_e8m0 from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams @@ -1209,11 +1210,6 @@ def load_kv_b_proj_and_k_b_proj_trans(module_name: str, return kv_b_proj, k_nope_weight_trans - def check_weight_dtype(module_name: str, dtype): - weight_name = "weight" - w_dtype = weights[f"{module_name}.{weight_name}"].dtype - return w_dtype == dtype - def load_kv_b_proj_and_k_b_proj_trans_dequant( module_name: str) -> torch.Tensor: weight_name = "weight" @@ -1286,6 +1282,16 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, params_map = {'gate_up_proj': ['gate_proj', 'up_proj']} all_named_modules = dict(self.named_modules()) + if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( + ) and get_sm_version() == 100: + for name in list(weights.keys()): + if name.endswith("weight_scale_inv"): + weight_name = name.replace("weight_scale_inv", "weight") + weight = weights[weight_name][:] + scale = weights[name][:] + weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( + weight, scale) + for name, module in tqdm(all_named_modules.items(), desc="Loading weights"): if len(module._parameters) > 0: diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 3dcfdea1191..9b98d6df9b4 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -605,9 +605,6 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size, module.tp_rank, module.tp_mode).squeeze() - # if get_sm_version == 100: - # weight, weight_scale = resmooth_to_fp8_e8m0(module.weight, weight_scale) - # copy_weight(module.weight, weight) copy_weight(module.weight_scale, weight_scale) if "input_scale" in weights[0]: copy_weight(module.input_scale, weights[0]["input_scale"]) @@ -628,8 +625,6 @@ def load_weights_fused_qkv_linear(self, module: Linear, module.tp_rank, module.tp_mode) fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze() - # if get_sm_version == 100: - # fused_weight, fused_fp8_block_scale = resmooth_to_fp8_e8m0(fused_weight, fused_fp8_block_scale) copy_weight(module.weight, fused_weight) copy_weight(module.weight_scale, fused_fp8_block_scale) @@ -645,8 +640,6 @@ def load_weights_fused_gate_up_linear(self, module: Linear, right_scale = load_weight_shard(weights[1][scale_name], module.tp_size, module.tp_rank, module.tp_mode) fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze() - # if get_sm_version == 100: - # fused_weight, fused_scale = 
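# NOTE (reference sketch, not part of the patch): resmooth_to_fp8_e8m0 dequantizes an FP8 weight
# with its original 128x128 block scales and requantizes it with power-of-two (UE8M0) scales,
# which the SM100 DeepGEMM path requires. A minimal 2D version, assuming both dimensions are
# multiples of 128 (the real helper in fp8_utils.py also pads and handles 3D expert weights):
import torch

def resmooth_ref(w_fp8: torch.Tensor, sf: torch.Tensor):
    n, k = w_fp8.shape
    x = w_fp8.float() * sf.repeat_interleave(128, 0).repeat_interleave(128, 1)[:n, :k]
    xb = x.view(n // 128, 128, k // 128, 128)
    amax = xb.abs().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    new_sf = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))    # UE8M0: a pure power of two
    q = (xb / new_sf).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).view(n, k)
    return q, new_sf.view(n // 128, k // 128)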
resmooth_to_fp8_e8m0(fused_weight, fused_scale) copy_weight(module.weight, fused_weight) copy_weight(module.weight_scale, fused_scale) diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 7359b8dd0dd..a0f2dd4b4b7 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -70,6 +70,8 @@ def per_block_cast_to_fp8_e8m0( def resmooth_to_fp8_e8m0(weight: torch.Tensor, sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + weight = weight.cuda() + sf = sf.cuda() if weight.dim() == 2: x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave( 128, dim=1)[:weight.shape[0], :weight.shape[1]] From d83cc2533a47be061b4f1f96e662ea60fc1cb9ec Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Thu, 17 Jul 2025 16:21:09 +0800 Subject: [PATCH 08/38] Fix OOM issue for fp8 resmooth. (#4) Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/models/modeling_deepseekv3.py | 8 +++++++- .../_torch/modules/fused_moe/quantization.py | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 8939600c0ed..a3d0dbe929c 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -38,6 +38,7 @@ from tqdm import tqdm from transformers import PretrainedConfig +from tensorrt_llm import logger from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import PositionEmbeddingType @@ -1285,12 +1286,17 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( ) and get_sm_version() == 100: for name in list(weights.keys()): - if name.endswith("weight_scale_inv"): + # Use ".experts." to exclude shared_experts. + if name.endswith( + "weight_scale_inv") and ".experts." 
not in name: weight_name = name.replace("weight_scale_inv", "weight") + logger.debug(f"Resmoothing {weight_name}") weight = weights[weight_name][:] scale = weights[name][:] weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( weight, scale) + weights[weight_name] = weights[weight_name].cpu() + weights[name] = weights[name].cpu() for name, module in tqdm(all_named_modules.items(), desc="Loading weights"): diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index f957712e3e5..508bb7eba45 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -4,10 +4,12 @@ import torch from torch import nn +from tensorrt_llm import logger from tensorrt_llm._utils import get_sm_version from tensorrt_llm.quantization.utils.fp4_utils import ( float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices, get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices) +from tensorrt_llm.quantization.utils.fp8_utils import resmooth_to_fp8_e8m0 from ..linear import TensorParallelMode, load_weight_shard from .interface import MoEWeightLoadingMode @@ -463,6 +465,20 @@ def create_weights(self, module: torch.nn.Module): self.setup_quant_scales(module) + def load_weights(self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode): + + if get_sm_version() == 100: + for name in list(weights.keys()): + if name.endswith("weight_scale_inv"): + weight_name = name.replace("weight_scale_inv", "weight") + logger.debug(f"Resmoothing {weight_name}") + weight = weights[weight_name][:] + scale = weights[name][:] + weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( + weight, scale) + super().load_weights(module, weights, weight_loading_mode) + def setup_quant_scales(self, module: torch.nn.Module): module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales( fc_weight_scales=module.w3_w1_weight_scaling_factor, From e1e96fd6475740ecf320c99fa0992cf84dcd3bf6 Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Fri, 18 Jul 2025 10:03:48 +0800 Subject: [PATCH 09/38] Enbale masked grouped GEMM (#5) * add triton masked index copy for deepgemm moe. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> * rm slice. 
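# NOTE (reference sketch, not part of the patch): the Triton kernels in the hunks below scatter the
# permuted activations into a per-expert padded buffer of shape [num_experts, m_max, hidden] and
# gather the grouped-GEMM output back. A plain-PyTorch equivalent of that copy/gather, handy as a
# check against the kernels (offsets = expert_first_token_offset, masked_m = tokens per expert):
import torch

def masked_index_copy_ref(padded: torch.Tensor, src: torch.Tensor,
                          masked_m: torch.Tensor, offsets: torch.Tensor) -> None:
    for e in range(padded.shape[0]):
        n = int(masked_m[e])
        padded[e, :n] = src[int(offsets[e]):int(offsets[e]) + n]

def masked_index_gather_ref(dst: torch.Tensor, padded: torch.Tensor,
                            masked_m: torch.Tensor, offsets: torch.Tensor) -> None:
    for e in range(padded.shape[0]):
        n = int(masked_m[e])
        dst[int(offsets[e]):int(offsets[e]) + n] = padded[e, :n]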
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> * Enable masked grouped GEMM Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> --------- Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Co-authored-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 206 ++++++++++++------ tensorrt_llm/quantization/utils/fp8_utils.py | 26 ++- 2 files changed, 160 insertions(+), 72 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index e3bb244eab5..1b57fa19ef6 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -3,6 +3,8 @@ import deep_gemm import torch import torch.nn.functional as F +import triton +import triton.language as tl import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm._utils import nvtx_range @@ -14,6 +16,105 @@ from .routing import BaseMoeRoutingMethod +@triton.jit +def masked_index_copy_kernel(output_ptr, input_ptr, masked_m_ptr, + start_offsets_ptr, col_size, dim_size, + BLOCK_SIZE: tl.constexpr): + # get program id and block offset + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # compute mask and pointers + token_idx = offsets // dim_size + row_idx = token_idx // col_size + col_idx = token_idx % col_size + elem_idx = offsets % dim_size + num_cols = tl.load(masked_m_ptr + row_idx) + valid = col_idx < num_cols + + # load start offset and input data + start_offset = tl.load(start_offsets_ptr + row_idx) + input_offset = (start_offset + col_idx) * dim_size + elem_idx + input = tl.load(input_ptr + input_offset, mask=valid) + + # write output + output_offsets = row_idx * col_size * dim_size + col_idx * dim_size + elem_idx + tl.store(output_ptr + output_offsets, input, mask=valid) + + +def triton_masked_index_copy(output, input, masked_m, start_offsets): + assert output.dtype == input.dtype, "Output and input must have the same dtype" + assert output.ndim == 3, "Input must be a 3D tensor, [row, col, dim]" + assert input.ndim == 2, "Input must be a 2D tensor" + + row_size = output.shape[0] + col_size = output.shape[1] + dim_size = output.shape[2] + total_elems = row_size * col_size * dim_size + + # launch kernel + grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), ) + masked_index_copy_kernel[grid](output, + input, + masked_m, + start_offsets, + col_size, + dim_size, + BLOCK_SIZE=1024) + return output + + +@triton.jit +def masked_index_gather_kernel(output_ptr, input_ptr, masked_m_ptr, + start_offsets_ptr, col_size, dim_size, + BLOCK_SIZE: tl.constexpr): + # get program id and block offset + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # compute mask and pointers + token_idx = offsets // dim_size + row_idx = token_idx // col_size + col_idx = token_idx % col_size + elem_idx = offsets % dim_size + num_cols = tl.load(masked_m_ptr + row_idx) + valid = col_idx < num_cols + + # input data + input_offsets = row_idx * col_size * dim_size + col_idx * dim_size + elem_idx + input_vals = 
tl.load(input_ptr + input_offsets, mask=valid) + + # get gather indices and store to output + start_offset = tl.load(start_offsets_ptr + row_idx) + gather_offset = (start_offset + col_idx) * dim_size + elem_idx + tl.store(output_ptr + gather_offset, input_vals, mask=valid) + + +@torch.no_grad() +def triton_masked_index_gather(output, input, masked_m, start_offsets): + assert output.ndim == 2, "Output must be a 2D tensor" + assert input.ndim == 3, "Input must be a 3D tensor, [row, col, dim]" + assert masked_m.ndim == 1, "Indices must be a 1D tensor" + + row_size = input.shape[0] + col_size = input.shape[1] + dim_size = input.shape[2] + total_elems = row_size * col_size * dim_size + + # launch kernel + grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), ) + masked_index_gather_kernel[grid](output, + input, + masked_m, + start_offsets, + col_size, + dim_size, + BLOCK_SIZE=1024) + return output + + @nvtx_range("[DG] act") @torch.compile(dynamic=True) def swiglu_fused_moe(x): @@ -29,64 +130,31 @@ def indexing(x, mask): @nvtx_range("[DG] copy after permute") @torch.compile(dynamic=True) -def copy_after( - expert_first_token_offset_tensor, - permuted_data_tensor, - base_indices, - hidden_size, -): - token_per_expert = expert_first_token_offset_tensor[ +def preprocess_after_permute(expert_first_token_offset_tensor, ): + # get tokens per expert + masked_m = expert_first_token_offset_tensor[ 1:] - expert_first_token_offset_tensor[:-1] - token_per_expert_padded = (token_per_expert + 127) // 128 * 128 - expert_first_token_offset_tensor_padded = torch.cat( - (torch.zeros(1, dtype=torch.int32, - device='cuda'), torch.cumsum(token_per_expert_padded, - dim=0))) - - token_num = token_per_expert.sum() - total_tokens_padded = token_per_expert_padded.sum() - m_indices = torch.repeat_interleave(base_indices, - token_per_expert_padded, - dim=0, - output_size=total_tokens_padded) - src_offsets = torch.repeat_interleave(expert_first_token_offset_tensor[:-1], - token_per_expert, - dim=0, - output_size=token_num) - dest_starts = torch.repeat_interleave( - expert_first_token_offset_tensor_padded[:-1], - token_per_expert, - dim=0, - output_size=token_num) - token_j_offset_in_expert = torch.arange(token_num, - device='cuda') - src_offsets - dest_indices = dest_starts + token_j_offset_in_expert - - permuted_data_tensor_padded = torch.empty(total_tokens_padded, - hidden_size, - dtype=permuted_data_tensor.dtype, - device='cuda') - src_indices = torch.arange(dest_indices.shape[0], device='cuda') - permuted_data_tensor_padded.index_copy_(0, dest_indices, - permuted_data_tensor[src_indices]) - - return permuted_data_tensor_padded, m_indices, dest_indices + masked_m_shift = torch.zeros_like(masked_m) + masked_m_shift[1:] = masked_m[:-1] + start_offsets = torch.cumsum(masked_m_shift, dim=0) + return masked_m.to(torch.int32), start_offsets @nvtx_range("[DG]") -def deepgemm_fp8_group_blockwise_gemm_ref( +def deepgemm_fp8_group_blockwise_gemm( a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, b_sf: torch.Tensor, - m_indices: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, ) -> torch.Tensor: - d = torch.empty((a.shape[0], b.shape[1]), + d = torch.empty((a.shape[0], a.shape[1], b.shape[1]), device=b.device, dtype=torch.bfloat16) - deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a, a_sf), (b, b_sf), d, - m_indices) + deep_gemm.fp8_m_grouped_gemm_nt_masked((a, a_sf), (b, b_sf), d, + masked_m, expected_m) return d @@ -141,10 +209,6 @@ def __init__( layer_idx=layer_idx, ) - self.base_indices = 
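# NOTE (reference sketch, not part of the patch): semantics of the masked grouped GEMM used above.
# For each expert g only the first masked_m[g] rows of a[g] hold real tokens; rows beyond that are
# padding whose output is ignored, and expected_m is only a tuning hint for kernel selection.
# A dense bf16 reference (the real call consumes FP8 tensors plus their scale factors):
import torch

def masked_grouped_gemm_ref(a: torch.Tensor,        # [G, M_max, K]
                            b: torch.Tensor,        # [G, N, K]
                            masked_m: torch.Tensor) -> torch.Tensor:
    d = torch.zeros(a.shape[0], a.shape[1], b.shape[1], dtype=torch.bfloat16, device=a.device)
    for g in range(a.shape[0]):
        m = int(masked_m[g])
        d[g, :m] = (a[g, :m].float() @ b[g].float().T).to(torch.bfloat16)
    return d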
torch.arange(self.expert_size_per_partition, - device="cuda", - dtype=torch.int32) - @nvtx_range("[DG] forward") def forward_chunk( self, @@ -214,35 +278,53 @@ def forward_chunk( use_fp8_block_scaling=use_deepseek_fp8_block_scale, ) - permuted_data_tensor_padded, m_indices, dest_indices = copy_after( - expert_first_token_offset_tensor, - permuted_data_tensor, - self.base_indices, - self.hidden_size, + if permuted_data_tensor.numel() == 0: + return torch.zeros_like(x) + + max_padded_tokens = (x.shape[0] + 128) // 128 * 128 + permuted_data_tensor_padded = torch.empty( + (self.expert_size_per_partition, max_padded_tokens, + self.hidden_size), + dtype=self.dtype, + device='cuda') + + masked_m, start_offsets = preprocess_after_permute( + expert_first_token_offset_tensor ) + m_max = (x.shape[0] + 127) // 128 * 128 + expected_m = (token_selected_experts.numel() + self.expert_size_per_partition - 1) // self.expert_size_per_partition + permuted_data_tensor_padded = torch.empty(self.expert_size_per_partition, + m_max, + self.hidden_size, + dtype=self.dtype, + device='cuda') + triton_masked_index_copy(permuted_data_tensor_padded, + permuted_data_tensor, masked_m, start_offsets) - if permuted_data_tensor_padded.numel() == 0: - return torch.zeros_like(x) act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0( permuted_data_tensor_padded) - h1 = deepgemm_fp8_group_blockwise_gemm_ref( + h1 = deepgemm_fp8_group_blockwise_gemm( a=act_input_fp8, b=self.w3_w1_weight, a_sf=act_input_sf, b_sf=self.quant_scales[0], - m_indices=m_indices, + masked_m=masked_m, + expected_m=expected_m, ) h2 = swiglu_fused_moe(h1) act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0(h2) - h3 = deepgemm_fp8_group_blockwise_gemm_ref( + h3 = deepgemm_fp8_group_blockwise_gemm( a=act_input_fp8, b=self.w2_weight, a_sf=act_input_sf, b_sf=self.quant_scales[1], - m_indices=m_indices, + masked_m=masked_m, + expected_m=expected_m, ) - permuted_data_tensor[0:dest_indices.shape[0]].copy_(h3[dest_indices]) + triton_masked_index_gather(permuted_data_tensor, h3, masked_m, + start_offsets) + final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( permuted_data_tensor, None, # biases diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index a0f2dd4b4b7..8f338ec0b74 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -28,16 +28,24 @@ def ceil_to_ue8m0(x: torch.Tensor): @nvtx_range("[DG] quantization") -@torch.compile(dynamic=True) def per_token_cast_to_fp8_e8m0( x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 and x.size(1) % 128 == 0 - m, n = x.shape - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - sf = ceil_to_ue8m0(x_amax / 448.0) - return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view( - m, n), sf + if x.dim() == 2: + assert x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, n), sf + else: + assert x.size(2) % 128 == 0 + g, m, n = x.shape + x_view = x.view(g, m, -1, 128) + x_amax = x_view.abs().float().amax(dim=3).view(g, m, -1).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + return (x_view * (1.0 / sf.unsqueeze(3))).to(torch.float8_e4m3fn).view( + g, m, n), sf def per_block_cast_to_fp8_e8m0( @@ -70,8 +78,6 @@ 
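# NOTE (reference sketch, not part of the patch): the per-token cast above splits every row into
# groups of 128 values and gives each group a power-of-two (UE8M0) scale so the quantized values
# fit the E4M3 range of +/-448. A compact 2D equivalent (the 3D branch just adds a group dim):
import torch

def per_token_cast_ref(x: torch.Tensor):
    m, n = x.shape                                   # n must be a multiple of 128
    xv = x.view(m, n // 128, 128)
    amax = xv.abs().float().amax(dim=2, keepdim=True).clamp(1e-4)
    sf = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))
    return (xv / sf).to(torch.float8_e4m3fn).view(m, n), sf.squeeze(-1)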
def per_block_cast_to_fp8_e8m0( def resmooth_to_fp8_e8m0(weight: torch.Tensor, sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - weight = weight.cuda() - sf = sf.cuda() if weight.dim() == 2: x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave( 128, dim=1)[:weight.shape[0], :weight.shape[1]] From 09b0465467a9974173c3207fe40c9f2393795304 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Fri, 18 Jul 2025 12:51:56 +0800 Subject: [PATCH 10/38] Pin DeepGEMM's version to commit cc416ee. (#6) Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- requirements.txt | 2 +- tensorrt_llm/quantization/utils/fp8_utils.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4c8eee09e3e..efb10ecf159 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,4 @@ etcd3 blake3 llguidance==0.7.29 soundfile -deep_gemm @ git+https://github.com/RayWang96/DeepGEMM.git@multi_arch_support +deep_gemm @ git+https://github.com/RayWang96/DeepGEMM.git@cc416ee diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 8f338ec0b74..8d31e4a27e9 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -78,6 +78,8 @@ def per_block_cast_to_fp8_e8m0( def resmooth_to_fp8_e8m0(weight: torch.Tensor, sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + weight = weight.cuda() + sf = sf.cuda() if weight.dim() == 2: x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave( 128, dim=1)[:weight.shape[0], :weight.shape[1]] From 35b4e23fb69787d0075eef33ef551317e3f5993a Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Fri, 18 Jul 2025 15:57:30 +0800 Subject: [PATCH 11/38] Improve resmooth. 
(#7) Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py | 4 +++- tensorrt_llm/_torch/modules/fused_moe/quantization.py | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py index 1277e25b42f..4b0c4594ad5 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py @@ -12,7 +12,7 @@ BaseWeightLoader from tensorrt_llm._torch.models.modeling_utils import ( register_checkpoint_weight_loader, run_concurrently) -from tensorrt_llm._utils import local_mpi_rank, local_mpi_size +from tensorrt_llm._utils import local_mpi_rank, local_mpi_size, mpi_barrier from tensorrt_llm.logger import logger @@ -121,3 +121,5 @@ def prefetch_files(self, file_names: List[str]): len(local_file_names)) with multiprocessing.Pool(processes=max_processes) as pool: pool.map(self._prefetch_one_file, local_file_names) + # Ensure that all ranks have finished prefetching before loading weights + mpi_barrier() diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 508bb7eba45..d66306f48f4 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -469,8 +469,14 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict], weight_loading_mode: MoEWeightLoadingMode): if get_sm_version() == 100: + expert_ids = set(module.initial_local_expert_ids) + if self.need_load_shared_weights(module): + expert_ids.update( + module.layer_load_balancer.get_load_expert_ids()) for name in list(weights.keys()): if name.endswith("weight_scale_inv"): + if int(name.split(".")[0]) not in expert_ids: + continue weight_name = name.replace("weight_scale_inv", "weight") logger.debug(f"Resmoothing {weight_name}") weight = weights[weight_name][:] From dce291f155d2cc3ab03b1da819376e3f133eb359 Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Fri, 18 Jul 2025 16:37:10 +0800 Subject: [PATCH 12/38] Add compile for quantization kernels (#8) Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/quantization/utils/fp8_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 8d31e4a27e9..41f4314822f 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -28,6 +28,7 @@ def ceil_to_ue8m0(x: torch.Tensor): @nvtx_range("[DG] quantization") +@torch.compile(dynamic=True) def per_token_cast_to_fp8_e8m0( x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if x.dim() == 2: From b3ab47d5bee6d2c813d0e2a6d0687de35e54a508 Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Mon, 21 Jul 2025 10:05:01 +0800 Subject: [PATCH 13/38] Move SF transform to TRTLLM (#11) Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Fanrong Li 
<23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 82 ++++++++--- tensorrt_llm/quantization/utils/fp8_utils.py | 128 +++++++++++++++++- 2 files changed, 191 insertions(+), 19 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 1b57fa19ef6..8299782c257 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -1,10 +1,12 @@ +import functools from typing import List, Optional, Union -import deep_gemm import torch import torch.nn.functional as F import triton import triton.language as tl +from deep_gemm.jit_kernels.impls import sm100_fp8_gemm_1d1d +from deep_gemm.utils.layout import MajorTypeAB import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm._utils import nvtx_range @@ -144,8 +146,8 @@ def preprocess_after_permute(expert_first_token_offset_tensor, ): def deepgemm_fp8_group_blockwise_gemm( a: torch.Tensor, b: torch.Tensor, - a_sf: torch.Tensor, - b_sf: torch.Tensor, + sfa: torch.Tensor, + sfb: torch.Tensor, masked_m: torch.Tensor, expected_m: int, ) -> torch.Tensor: @@ -153,8 +155,50 @@ def deepgemm_fp8_group_blockwise_gemm( d = torch.empty((a.shape[0], a.shape[1], b.shape[1]), device=b.device, dtype=torch.bfloat16) - deep_gemm.fp8_m_grouped_gemm_nt_masked((a, a_sf), (b, b_sf), d, - masked_m, expected_m) + compiled_dims = 'nk' + + # NOTES: shape must be `[G, M, K] @ [G, N, K].mT` + assert a.stride(-1) == 1 + assert b.stride(-1) == 1 + assert masked_m.is_contiguous() + + num_groups, m, k = a.shape + num_groups_, n, k_ = b.shape + num_groups__, m_, n_ = d.shape + num_groups___ = masked_m.numel() + + # Type and shape checks + assert num_groups == num_groups_ == num_groups__ == num_groups___ + assert m == m_ and n == n_ and k == k_ + assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 + assert a.dtype == torch.float8_e4m3fn + assert b.dtype == torch.float8_e4m3fn + assert d.dtype == torch.bfloat16 + assert masked_m.dtype == torch.int32 + + # D must be N-major + assert d.stride(-1) == 1 + + # Transform SFA and SFB into compute-required layout + recipe = (1, 128, 128) + sfa = fp8_utils.transform_sf_into_required_layout(sfa, + mn=m, + k=k, + recipe=recipe, + num_groups=num_groups, + is_sfa=True) + sfb = fp8_utils.transform_sf_into_required_layout(sfb, + mn=n, + k=k, + recipe=recipe, + num_groups=num_groups, + is_sfa=False) + + impl = functools.partial(sm100_fp8_gemm_1d1d.fp8_m_grouped_gemm_nt_masked, + major_a=MajorTypeAB.KMajor, + major_b=MajorTypeAB.KMajor, + compiled_dims=compiled_dims) + impl(a, sfa, b, sfb, d, masked_m, expected_m) return d @@ -289,15 +333,17 @@ def forward_chunk( device='cuda') masked_m, start_offsets = preprocess_after_permute( - expert_first_token_offset_tensor - ) + expert_first_token_offset_tensor) m_max = (x.shape[0] + 127) // 128 * 128 - expected_m = (token_selected_experts.numel() + self.expert_size_per_partition - 1) // self.expert_size_per_partition - permuted_data_tensor_padded = torch.empty(self.expert_size_per_partition, - m_max, - self.hidden_size, - dtype=self.dtype, - device='cuda') + expected_m = (token_selected_experts.numel() + + self.expert_size_per_partition - + 1) // self.expert_size_per_partition + permuted_data_tensor_padded = torch.empty( + self.expert_size_per_partition, + m_max, + self.hidden_size, + dtype=self.dtype, + device='cuda') 
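# NOTE (reference sketch, not part of the patch): the SM100 grouped GEMM wants the scale factors as
# UE8M0 exponent bytes, four consecutive K-groups packed into one int32, laid out MN-major. A
# minimal packing sketch for a 2D scale tensor whose group count is a multiple of 4; the TMA
# alignment padding (see get_tma_aligned_size in the fp8_utils.py hunk below) is left out:
import torch

def pack_ue8m0_sketch(sf: torch.Tensor) -> torch.Tensor:
    mn, kg = sf.shape                                         # float32 power-of-two scales
    exp = (sf.contiguous().view(torch.int32) >> 23).to(torch.uint8)         # keep the exponent byte
    packed = exp.contiguous().view(-1).view(torch.int32).view(mn, kg // 4)  # 4 bytes -> 1 int32
    return packed.t().contiguous().t()                        # MN-major (stride 1 along mn)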
triton_masked_index_copy(permuted_data_tensor_padded, permuted_data_tensor, masked_m, start_offsets) @@ -306,8 +352,8 @@ def forward_chunk( h1 = deepgemm_fp8_group_blockwise_gemm( a=act_input_fp8, b=self.w3_w1_weight, - a_sf=act_input_sf, - b_sf=self.quant_scales[0], + sfa=act_input_sf, + sfb=self.quant_scales[0], masked_m=masked_m, expected_m=expected_m, ) @@ -316,14 +362,14 @@ def forward_chunk( h3 = deepgemm_fp8_group_blockwise_gemm( a=act_input_fp8, b=self.w2_weight, - a_sf=act_input_sf, - b_sf=self.quant_scales[1], + sfa=act_input_sf, + sfb=self.quant_scales[1], masked_m=masked_m, expected_m=expected_m, ) triton_masked_index_gather(permuted_data_tensor, h3, masked_m, - start_offsets) + start_offsets) final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( permuted_data_tensor, diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 41f4314822f..5d277c8b828 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple import torch @@ -88,3 +88,129 @@ def resmooth_to_fp8_e8m0(weight: torch.Tensor, x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave( 128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]] return per_block_cast_to_fp8_e8m0(x) + + +def get_m_alignment_for_contiguous_layout(): + return 128 + + +def get_tma_aligned_size(x: int, element_size: int) -> int: + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return align(x, alignment) + + +def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor: + # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + assert x.dtype == torch.float and x.dim() in (2, 3) + + # First, convert into UE8M0 `uint8_t` + ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8) + + # Second, make padded packed tensors + mn, k = x.shape[-2], x.shape[-1] + remove_dim = False + if x.dim() == 2: + x, remove_dim = x.unsqueeze(0), True + b = x.shape[0] + aligned_mn = get_tma_aligned_size(mn, 4) + aligned_k = align(k, 4) + padded = torch.zeros((b, aligned_mn, aligned_k), + device=x.device, + dtype=torch.uint8) + padded[:, :mn, :k] = ue8m0_tensor + padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, + aligned_k // 4) + + # Finally, transpose + transposed = torch.transpose( + torch.empty((b, aligned_k // 4, aligned_mn), + device=x.device, + dtype=torch.int), 1, 2) + transposed[:, :, :] = padded + aligned_x = transposed[:, :mn, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x + + +def check_sf_layout(sf: torch.Tensor, + mn: int, + k: int, + gran: Tuple[int, int], + num_groups: Optional[int], + tma_stride_check: bool = False, + type_check: Optional[torch.dtype] = None) -> torch.Tensor: + # Type check + if type_check is not None: + assert sf.dtype == type_check + + # Always do shape checks + assert sf.dtype in (torch.float, torch.int) + assert sf.dim() == int(num_groups is not None) + 2 + if num_groups is not None: + assert sf.size(-3) == num_groups + assert sf.size(-2) == ceil_div(mn, gran[0]) + assert sf.size(-1) == ceil_div( + k, gran[1] * (1 if sf.dtype == torch.float else 4)) + + # TMA stride checks: TMA aligned and MN-major + if tma_stride_check: + if num_groups is not None: + assert sf.stride(-3) == sf.stride(-1) * sf.size(-1) + assert sf.stride(-2) == 1 + assert sf.stride(-1) == 
get_tma_aligned_size(mn, sf.element_size()) + + return sf + + +@nvtx_range("[DG] transform_sf_into_required_layout") +@torch.compile(dynamic=True) +def transform_sf_into_required_layout(sf: torch.Tensor, + mn: int, + k: int, + recipe: Tuple[int, int, int], + num_groups: Optional[int] = None, + is_sfa: bool = False): + gran = (recipe[0 if is_sfa else 1], recipe[2]) + + should_skip_transform = ((sf.dtype == torch.int and gran == (1, 128)) + or (sf.dtype == torch.int and gran == (128, 128))) + + if not should_skip_transform: + # Pre-transform checks + check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups) + + # (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if sf.dtype == torch.float and gran == (1, 128): + sf = get_col_major_tma_aligned_packed_tensor(sf) + return check_sf_layout(sf, + mn=mn, + k=k, + gran=(1, 128), + num_groups=num_groups, + tma_stride_check=True, + type_check=torch.int) + + # (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if sf.dtype == torch.float and gran == (128, 128): + sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128) + sf = get_col_major_tma_aligned_packed_tensor(sf) + return check_sf_layout(sf, + mn=mn, + k=k, + gran=(1, 128), + num_groups=num_groups, + tma_stride_check=True, + type_check=torch.int) + + if should_skip_transform: + # TODO: add transpose kernel if SF layout is not satisfied + return check_sf_layout(sf, + mn=mn, + k=k, + gran=(1, 128), + num_groups=num_groups, + tma_stride_check=True, + type_check=torch.int) + + assert False, f'Unknown cases: {sf.dtype=}, {gran=}' From 65d05d6de88c0a889b75a98687ab64c6b3153951 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Mon, 21 Jul 2025 11:09:17 +0800 Subject: [PATCH 14/38] Use local barrier to avoid multi-node hang issue. (#12) * Use local barrier to avoid multi-node hang issue. Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> * Fix hang issue in the single-node case. Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --------- Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/models/checkpoints/hf/weight_loader.py | 5 ++++- tensorrt_llm/_utils.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py index 4b0c4594ad5..2b90996cb6f 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py @@ -12,7 +12,8 @@ BaseWeightLoader from tensorrt_llm._torch.models.modeling_utils import ( register_checkpoint_weight_loader, run_concurrently) -from tensorrt_llm._utils import local_mpi_rank, local_mpi_size, mpi_barrier +from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank, + local_mpi_size) from tensorrt_llm.logger import logger @@ -38,6 +39,8 @@ def load_weights(self, checkpoint_dir: str) -> dict[str, Any]: f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files." 
) self.prefetch_files(weight_files) + # Ensure that all local ranks have finished prefetching before loading weights + local_mpi_barrier() return self._load_weights_in_parallel( weight_files, self._load_safetensors_file, diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index b07430224af..e733a0331f6 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -470,6 +470,10 @@ def mpi_comm(): local_comm = mpi_comm().Split_type(split_type=OMPI_COMM_TYPE_HOST) +def local_mpi_comm(): + return local_comm + + def mpi_rank(): return mpi_comm().Get_rank() if ENABLE_MULTI_DEVICE else 0 @@ -508,6 +512,11 @@ def mpi_barrier(): mpi_comm().Barrier() +def local_mpi_barrier(): + if ENABLE_MULTI_DEVICE: + local_comm.Barrier() + + def mpi_broadcast(obj, root=0): return mpi_comm().bcast(obj, root) if is_multi_device_enable() else obj From d65bdac4d112282d46a727d1c9abe69e4f35c62d Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 21 Jul 2025 11:21:32 +0800 Subject: [PATCH 15/38] optimize the masked index copy and index gather (#13) * optimize the masked index copy and index gather. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> * rm torch.compile for preprocess_after_permute duo to the compatibility issue. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --------- Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 90 ++++++++++--------- 1 file changed, 49 insertions(+), 41 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 8299782c257..964e51e50e0 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -19,8 +19,8 @@ @triton.jit -def masked_index_copy_kernel(output_ptr, input_ptr, masked_m_ptr, - start_offsets_ptr, col_size, dim_size, +def masked_index_copy_kernel(output_ptr, input_ptr, start_offsets_ptr, + row_indices_ptr, row_size, col_size, dim_size, BLOCK_SIZE: tl.constexpr): # get program id and block offset pid = tl.program_id(0) @@ -28,48 +28,50 @@ def masked_index_copy_kernel(output_ptr, input_ptr, masked_m_ptr, offsets = block_start + tl.arange(0, BLOCK_SIZE) # compute mask and pointers + num_tokens = tl.load(start_offsets_ptr + row_size) token_idx = offsets // dim_size - row_idx = token_idx // col_size - col_idx = token_idx % col_size + valid = token_idx < num_tokens + row_idx = tl.load(row_indices_ptr + token_idx) + start_offset = tl.load(start_offsets_ptr + row_idx, mask=valid) + col_idx = token_idx - start_offset elem_idx = offsets % dim_size - num_cols = tl.load(masked_m_ptr + row_idx) - valid = col_idx < num_cols - # load start offset and input data - start_offset = tl.load(start_offsets_ptr + row_idx) - input_offset = (start_offset + col_idx) * dim_size + elem_idx - input = tl.load(input_ptr + input_offset, mask=valid) + # load input data + input = tl.load(input_ptr + offsets, mask=valid) # write output output_offsets = row_idx * col_size * dim_size + col_idx * dim_size + elem_idx tl.store(output_ptr + output_offsets, input, mask=valid) -def triton_masked_index_copy(output, input, masked_m, start_offsets): - assert output.dtype == input.dtype, "Output and input must have the same dtype" +def triton_masked_index_copy(output, input, start_offsets, row_indices): assert 
output.ndim == 3, "Input must be a 3D tensor, [row, col, dim]" assert input.ndim == 2, "Input must be a 2D tensor" + assert start_offsets.shape[ + 0] == output.shape[0] + 1, "Start offsets must be (num_experts + 1)" + num_tokens = input.shape[0] row_size = output.shape[0] col_size = output.shape[1] dim_size = output.shape[2] - total_elems = row_size * col_size * dim_size + total_elems = num_tokens * dim_size # launch kernel grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), ) masked_index_copy_kernel[grid](output, input, - masked_m, start_offsets, + row_indices, + row_size, col_size, dim_size, BLOCK_SIZE=1024) - return output + return @triton.jit -def masked_index_gather_kernel(output_ptr, input_ptr, masked_m_ptr, - start_offsets_ptr, col_size, dim_size, +def masked_index_gather_kernel(output_ptr, input_ptr, start_offsets_ptr, + row_indices_ptr, row_size, col_size, dim_size, BLOCK_SIZE: tl.constexpr): # get program id and block offset pid = tl.program_id(0) @@ -77,44 +79,46 @@ def masked_index_gather_kernel(output_ptr, input_ptr, masked_m_ptr, offsets = block_start + tl.arange(0, BLOCK_SIZE) # compute mask and pointers + num_tokens = tl.load(start_offsets_ptr + row_size) token_idx = offsets // dim_size - row_idx = token_idx // col_size - col_idx = token_idx % col_size + valid = token_idx < num_tokens + row_idx = tl.load(row_indices_ptr + token_idx) + start_offset = tl.load(start_offsets_ptr + row_idx, mask=valid) + col_idx = token_idx - start_offset elem_idx = offsets % dim_size - num_cols = tl.load(masked_m_ptr + row_idx) - valid = col_idx < num_cols # input data input_offsets = row_idx * col_size * dim_size + col_idx * dim_size + elem_idx input_vals = tl.load(input_ptr + input_offsets, mask=valid) # get gather indices and store to output - start_offset = tl.load(start_offsets_ptr + row_idx) - gather_offset = (start_offset + col_idx) * dim_size + elem_idx - tl.store(output_ptr + gather_offset, input_vals, mask=valid) + tl.store(output_ptr + offsets, input_vals, mask=valid) @torch.no_grad() -def triton_masked_index_gather(output, input, masked_m, start_offsets): +def triton_masked_index_gather(output, input, start_offsets, row_indices): assert output.ndim == 2, "Output must be a 2D tensor" assert input.ndim == 3, "Input must be a 3D tensor, [row, col, dim]" - assert masked_m.ndim == 1, "Indices must be a 1D tensor" + assert start_offsets.shape[ + 0] == input.shape[0] + 1, "Start offsets must be (num_experts + 1)" row_size = input.shape[0] col_size = input.shape[1] dim_size = input.shape[2] - total_elems = row_size * col_size * dim_size + num_tokens = output.shape[0] + total_elems = num_tokens * dim_size # launch kernel grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), ) masked_index_gather_kernel[grid](output, input, - masked_m, start_offsets, + row_indices, + row_size, col_size, dim_size, BLOCK_SIZE=1024) - return output + return @nvtx_range("[DG] act") @@ -130,16 +134,17 @@ def indexing(x, mask): return x[mask > 0, :].contiguous() -@nvtx_range("[DG] copy after permute") -@torch.compile(dynamic=True) -def preprocess_after_permute(expert_first_token_offset_tensor, ): +@nvtx_range("[DG] preprocess_after_permute") +def preprocess_after_permute(expert_first_token_offset_tensor, + permuted_data_tensor): # get tokens per expert masked_m = expert_first_token_offset_tensor[ 1:] - expert_first_token_offset_tensor[:-1] - masked_m_shift = torch.zeros_like(masked_m) - masked_m_shift[1:] = masked_m[:-1] - start_offsets = torch.cumsum(masked_m_shift, dim=0) - return 
masked_m.to(torch.int32), start_offsets + token_to_expert_map = torch.searchsorted( + expert_first_token_offset_tensor[1:], + torch.arange(permuted_data_tensor.shape[0], device='cuda'), + right=True) + return masked_m.to(torch.int32), token_to_expert_map @nvtx_range("[DG]") @@ -332,8 +337,8 @@ def forward_chunk( dtype=self.dtype, device='cuda') - masked_m, start_offsets = preprocess_after_permute( - expert_first_token_offset_tensor) + masked_m, token_to_expert_map = preprocess_after_permute( + expert_first_token_offset_tensor, permuted_data_tensor) m_max = (x.shape[0] + 127) // 128 * 128 expected_m = (token_selected_experts.numel() + self.expert_size_per_partition - @@ -345,7 +350,9 @@ def forward_chunk( dtype=self.dtype, device='cuda') triton_masked_index_copy(permuted_data_tensor_padded, - permuted_data_tensor, masked_m, start_offsets) + permuted_data_tensor, + expert_first_token_offset_tensor, + token_to_expert_map) act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0( permuted_data_tensor_padded) @@ -368,8 +375,9 @@ def forward_chunk( expected_m=expected_m, ) - triton_masked_index_gather(permuted_data_tensor, h3, masked_m, - start_offsets) + triton_masked_index_gather(permuted_data_tensor, h3, + expert_first_token_offset_tensor, + token_to_expert_map) final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( permuted_data_tensor, From 0af69acccfdd7fab41b69fd86117750c08719d79 Mon Sep 17 00:00:00 2001 From: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> Date: Mon, 21 Jul 2025 13:03:02 +0800 Subject: [PATCH 16/38] Fix adp for deepgemm moe backend (#10) Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 964e51e50e0..ade55455606 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -11,6 +11,7 @@ import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm._utils import nvtx_range +from ...distributed import allgather from ...model_config import ModelConfig from ...utils import Fp4QuantizedTensor from .fused_moe_cutlass import CutlassFusedMoE @@ -301,6 +302,14 @@ def forward_chunk( f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}" ) + use_allgather = self.use_dp and self.parallel_size > 1 + if use_allgather: + x, x_sf, token_selected_experts, token_final_scales = allgather( + [x, x_sf, token_selected_experts, token_final_scales], + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens) + ( permuted_row_to_unpermuted_row_tensor, permuted_token_selected_experts_tensor, @@ -330,15 +339,9 @@ def forward_chunk( if permuted_data_tensor.numel() == 0: return torch.zeros_like(x) - max_padded_tokens = (x.shape[0] + 128) // 128 * 128 - permuted_data_tensor_padded = torch.empty( - (self.expert_size_per_partition, max_padded_tokens, - self.hidden_size), - dtype=self.dtype, - device='cuda') - masked_m, token_to_expert_map = preprocess_after_permute( expert_first_token_offset_tensor, permuted_data_tensor) + m_max = (x.shape[0] + 127) // 128 * 128 expected_m = (token_selected_experts.numel() + 
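# NOTE (toy example, not part of the patch): how preprocess_after_permute above maps each permuted
# row to its expert. right=True in searchsorted makes boundary rows skip over empty experts:
import torch

offsets = torch.tensor([0, 3, 3, 7])                 # expert_first_token_offset, expert 1 is empty
masked_m = offsets[1:] - offsets[:-1]                # tokens per expert -> [3, 0, 4]
rows = torch.arange(int(offsets[-1]))
token_to_expert = torch.searchsorted(offsets[1:], rows, right=True)
assert token_to_expert.tolist() == [0, 0, 0, 2, 2, 2, 2]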
self.expert_size_per_partition - From 6f431f646bbe687a6738212550954bf3332e2fae Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Mon, 21 Jul 2025 05:11:58 +0000 Subject: [PATCH 17/38] Use DeepGEMM main branch instead. Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- requirements.txt | 2 +- .../modules/fused_moe/fused_moe_deepgemm.py | 51 +------ tensorrt_llm/quantization/utils/fp8_utils.py | 128 +----------------- 3 files changed, 5 insertions(+), 176 deletions(-) diff --git a/requirements.txt b/requirements.txt index efb10ecf159..a87e7f82c27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,4 @@ etcd3 blake3 llguidance==0.7.29 soundfile -deep_gemm @ git+https://github.com/RayWang96/DeepGEMM.git@cc416ee +deep_gemm @ git+https://github.com/deepseek-ai/DeepGEMM.git@187656694f7f69e3e7975617a68bc3387680a7e1 diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index ade55455606..ce456590824 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -1,12 +1,10 @@ -import functools from typing import List, Optional, Union +import deep_gemm import torch import torch.nn.functional as F import triton import triton.language as tl -from deep_gemm.jit_kernels.impls import sm100_fp8_gemm_1d1d -from deep_gemm.utils.layout import MajorTypeAB import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm._utils import nvtx_range @@ -157,54 +155,11 @@ def deepgemm_fp8_group_blockwise_gemm( masked_m: torch.Tensor, expected_m: int, ) -> torch.Tensor: - d = torch.empty((a.shape[0], a.shape[1], b.shape[1]), device=b.device, dtype=torch.bfloat16) - compiled_dims = 'nk' - - # NOTES: shape must be `[G, M, K] @ [G, N, K].mT` - assert a.stride(-1) == 1 - assert b.stride(-1) == 1 - assert masked_m.is_contiguous() - - num_groups, m, k = a.shape - num_groups_, n, k_ = b.shape - num_groups__, m_, n_ = d.shape - num_groups___ = masked_m.numel() - - # Type and shape checks - assert num_groups == num_groups_ == num_groups__ == num_groups___ - assert m == m_ and n == n_ and k == k_ - assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 - assert a.dtype == torch.float8_e4m3fn - assert b.dtype == torch.float8_e4m3fn - assert d.dtype == torch.bfloat16 - assert masked_m.dtype == torch.int32 - - # D must be N-major - assert d.stride(-1) == 1 - - # Transform SFA and SFB into compute-required layout - recipe = (1, 128, 128) - sfa = fp8_utils.transform_sf_into_required_layout(sfa, - mn=m, - k=k, - recipe=recipe, - num_groups=num_groups, - is_sfa=True) - sfb = fp8_utils.transform_sf_into_required_layout(sfb, - mn=n, - k=k, - recipe=recipe, - num_groups=num_groups, - is_sfa=False) - - impl = functools.partial(sm100_fp8_gemm_1d1d.fp8_m_grouped_gemm_nt_masked, - major_a=MajorTypeAB.KMajor, - major_b=MajorTypeAB.KMajor, - compiled_dims=compiled_dims) - impl(a, sfa, b, sfb, d, masked_m, expected_m) + deep_gemm.fp8_m_grouped_gemm_nt_masked((a, sfa), (b, sfb), d, masked_m, + expected_m) return d diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 5d277c8b828..41f4314822f 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ 
b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Tuple import torch @@ -88,129 +88,3 @@ def resmooth_to_fp8_e8m0(weight: torch.Tensor, x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave( 128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]] return per_block_cast_to_fp8_e8m0(x) - - -def get_m_alignment_for_contiguous_layout(): - return 128 - - -def get_tma_aligned_size(x: int, element_size: int) -> int: - tma_alignment_bytes = 16 - assert tma_alignment_bytes % element_size == 0 - alignment = tma_alignment_bytes // element_size - return align(x, alignment) - - -def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor: - # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA - assert x.dtype == torch.float and x.dim() in (2, 3) - - # First, convert into UE8M0 `uint8_t` - ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8) - - # Second, make padded packed tensors - mn, k = x.shape[-2], x.shape[-1] - remove_dim = False - if x.dim() == 2: - x, remove_dim = x.unsqueeze(0), True - b = x.shape[0] - aligned_mn = get_tma_aligned_size(mn, 4) - aligned_k = align(k, 4) - padded = torch.zeros((b, aligned_mn, aligned_k), - device=x.device, - dtype=torch.uint8) - padded[:, :mn, :k] = ue8m0_tensor - padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, - aligned_k // 4) - - # Finally, transpose - transposed = torch.transpose( - torch.empty((b, aligned_k // 4, aligned_mn), - device=x.device, - dtype=torch.int), 1, 2) - transposed[:, :, :] = padded - aligned_x = transposed[:, :mn, :] - return aligned_x.squeeze(0) if remove_dim else aligned_x - - -def check_sf_layout(sf: torch.Tensor, - mn: int, - k: int, - gran: Tuple[int, int], - num_groups: Optional[int], - tma_stride_check: bool = False, - type_check: Optional[torch.dtype] = None) -> torch.Tensor: - # Type check - if type_check is not None: - assert sf.dtype == type_check - - # Always do shape checks - assert sf.dtype in (torch.float, torch.int) - assert sf.dim() == int(num_groups is not None) + 2 - if num_groups is not None: - assert sf.size(-3) == num_groups - assert sf.size(-2) == ceil_div(mn, gran[0]) - assert sf.size(-1) == ceil_div( - k, gran[1] * (1 if sf.dtype == torch.float else 4)) - - # TMA stride checks: TMA aligned and MN-major - if tma_stride_check: - if num_groups is not None: - assert sf.stride(-3) == sf.stride(-1) * sf.size(-1) - assert sf.stride(-2) == 1 - assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()) - - return sf - - -@nvtx_range("[DG] transform_sf_into_required_layout") -@torch.compile(dynamic=True) -def transform_sf_into_required_layout(sf: torch.Tensor, - mn: int, - k: int, - recipe: Tuple[int, int, int], - num_groups: Optional[int] = None, - is_sfa: bool = False): - gran = (recipe[0 if is_sfa else 1], recipe[2]) - - should_skip_transform = ((sf.dtype == torch.int and gran == (1, 128)) - or (sf.dtype == torch.int and gran == (128, 128))) - - if not should_skip_transform: - # Pre-transform checks - check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups) - - # (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major - if sf.dtype == torch.float and gran == (1, 128): - sf = get_col_major_tma_aligned_packed_tensor(sf) - return check_sf_layout(sf, - mn=mn, - k=k, - gran=(1, 128), - num_groups=num_groups, - tma_stride_check=True, - type_check=torch.int) - - # (FP32, 128, 128) on SM100: transform to (INT, 1, 128), 
TMA-aligned and MN-major - if sf.dtype == torch.float and gran == (128, 128): - sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128) - sf = get_col_major_tma_aligned_packed_tensor(sf) - return check_sf_layout(sf, - mn=mn, - k=k, - gran=(1, 128), - num_groups=num_groups, - tma_stride_check=True, - type_check=torch.int) - - if should_skip_transform: - # TODO: add transpose kernel if SF layout is not satisfied - return check_sf_layout(sf, - mn=mn, - k=k, - gran=(1, 128), - num_groups=num_groups, - tma_stride_check=True, - type_check=torch.int) - - assert False, f'Unknown cases: {sf.dtype=}, {gran=}' From 481fd500b6cfd3b61bbc6343d080e9e6d3258afa Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Mon, 21 Jul 2025 13:51:48 +0800 Subject: [PATCH 18/38] Revert "Use DeepGEMM main branch instead." Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- requirements.txt | 2 +- .../modules/fused_moe/fused_moe_deepgemm.py | 51 ++++++- tensorrt_llm/quantization/utils/fp8_utils.py | 128 +++++++++++++++++- 3 files changed, 176 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index a87e7f82c27..efb10ecf159 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,4 @@ etcd3 blake3 llguidance==0.7.29 soundfile -deep_gemm @ git+https://github.com/deepseek-ai/DeepGEMM.git@187656694f7f69e3e7975617a68bc3387680a7e1 +deep_gemm @ git+https://github.com/RayWang96/DeepGEMM.git@cc416ee diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index ce456590824..ade55455606 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -1,10 +1,12 @@ +import functools from typing import List, Optional, Union -import deep_gemm import torch import torch.nn.functional as F import triton import triton.language as tl +from deep_gemm.jit_kernels.impls import sm100_fp8_gemm_1d1d +from deep_gemm.utils.layout import MajorTypeAB import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm._utils import nvtx_range @@ -155,11 +157,54 @@ def deepgemm_fp8_group_blockwise_gemm( masked_m: torch.Tensor, expected_m: int, ) -> torch.Tensor: + d = torch.empty((a.shape[0], a.shape[1], b.shape[1]), device=b.device, dtype=torch.bfloat16) - deep_gemm.fp8_m_grouped_gemm_nt_masked((a, sfa), (b, sfb), d, masked_m, - expected_m) + compiled_dims = 'nk' + + # NOTES: shape must be `[G, M, K] @ [G, N, K].mT` + assert a.stride(-1) == 1 + assert b.stride(-1) == 1 + assert masked_m.is_contiguous() + + num_groups, m, k = a.shape + num_groups_, n, k_ = b.shape + num_groups__, m_, n_ = d.shape + num_groups___ = masked_m.numel() + + # Type and shape checks + assert num_groups == num_groups_ == num_groups__ == num_groups___ + assert m == m_ and n == n_ and k == k_ + assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 + assert a.dtype == torch.float8_e4m3fn + assert b.dtype == torch.float8_e4m3fn + assert d.dtype == torch.bfloat16 + assert masked_m.dtype == torch.int32 + + # D must be N-major + assert d.stride(-1) == 1 + + # Transform SFA and SFB into compute-required layout + recipe = (1, 128, 128) + sfa = fp8_utils.transform_sf_into_required_layout(sfa, + mn=m, + k=k, + recipe=recipe, + num_groups=num_groups, + is_sfa=True) + sfb = fp8_utils.transform_sf_into_required_layout(sfb, + mn=n, + k=k, 
+ recipe=recipe, + num_groups=num_groups, + is_sfa=False) + + impl = functools.partial(sm100_fp8_gemm_1d1d.fp8_m_grouped_gemm_nt_masked, + major_a=MajorTypeAB.KMajor, + major_b=MajorTypeAB.KMajor, + compiled_dims=compiled_dims) + impl(a, sfa, b, sfb, d, masked_m, expected_m) return d diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 41f4314822f..5d277c8b828 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple import torch @@ -88,3 +88,129 @@ def resmooth_to_fp8_e8m0(weight: torch.Tensor, x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave( 128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]] return per_block_cast_to_fp8_e8m0(x) + + +def get_m_alignment_for_contiguous_layout(): + return 128 + + +def get_tma_aligned_size(x: int, element_size: int) -> int: + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return align(x, alignment) + + +def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor: + # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + assert x.dtype == torch.float and x.dim() in (2, 3) + + # First, convert into UE8M0 `uint8_t` + ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8) + + # Second, make padded packed tensors + mn, k = x.shape[-2], x.shape[-1] + remove_dim = False + if x.dim() == 2: + x, remove_dim = x.unsqueeze(0), True + b = x.shape[0] + aligned_mn = get_tma_aligned_size(mn, 4) + aligned_k = align(k, 4) + padded = torch.zeros((b, aligned_mn, aligned_k), + device=x.device, + dtype=torch.uint8) + padded[:, :mn, :k] = ue8m0_tensor + padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, + aligned_k // 4) + + # Finally, transpose + transposed = torch.transpose( + torch.empty((b, aligned_k // 4, aligned_mn), + device=x.device, + dtype=torch.int), 1, 2) + transposed[:, :, :] = padded + aligned_x = transposed[:, :mn, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x + + +def check_sf_layout(sf: torch.Tensor, + mn: int, + k: int, + gran: Tuple[int, int], + num_groups: Optional[int], + tma_stride_check: bool = False, + type_check: Optional[torch.dtype] = None) -> torch.Tensor: + # Type check + if type_check is not None: + assert sf.dtype == type_check + + # Always do shape checks + assert sf.dtype in (torch.float, torch.int) + assert sf.dim() == int(num_groups is not None) + 2 + if num_groups is not None: + assert sf.size(-3) == num_groups + assert sf.size(-2) == ceil_div(mn, gran[0]) + assert sf.size(-1) == ceil_div( + k, gran[1] * (1 if sf.dtype == torch.float else 4)) + + # TMA stride checks: TMA aligned and MN-major + if tma_stride_check: + if num_groups is not None: + assert sf.stride(-3) == sf.stride(-1) * sf.size(-1) + assert sf.stride(-2) == 1 + assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()) + + return sf + + +@nvtx_range("[DG] transform_sf_into_required_layout") +@torch.compile(dynamic=True) +def transform_sf_into_required_layout(sf: torch.Tensor, + mn: int, + k: int, + recipe: Tuple[int, int, int], + num_groups: Optional[int] = None, + is_sfa: bool = False): + gran = (recipe[0 if is_sfa else 1], recipe[2]) + + should_skip_transform = ((sf.dtype == torch.int and gran == (1, 128)) + or (sf.dtype == torch.int and gran == (128, 128))) + + if not should_skip_transform: + # 
Pre-transform checks + check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups) + + # (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if sf.dtype == torch.float and gran == (1, 128): + sf = get_col_major_tma_aligned_packed_tensor(sf) + return check_sf_layout(sf, + mn=mn, + k=k, + gran=(1, 128), + num_groups=num_groups, + tma_stride_check=True, + type_check=torch.int) + + # (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if sf.dtype == torch.float and gran == (128, 128): + sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128) + sf = get_col_major_tma_aligned_packed_tensor(sf) + return check_sf_layout(sf, + mn=mn, + k=k, + gran=(1, 128), + num_groups=num_groups, + tma_stride_check=True, + type_check=torch.int) + + if should_skip_transform: + # TODO: add transpose kernel if SF layout is not satisfied + return check_sf_layout(sf, + mn=mn, + k=k, + gran=(1, 128), + num_groups=num_groups, + tma_stride_check=True, + type_check=torch.int) + + assert False, f'Unknown cases: {sf.dtype=}, {gran=}' From ab7175f242f723971e6053bf7019bdc7503c0a2a Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Mon, 21 Jul 2025 15:54:21 +0800 Subject: [PATCH 19/38] Use DeepGEMM main branch and disable ue8m0 cast. (#16) Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- requirements.txt | 2 +- .../modules/fused_moe/fused_moe_deepgemm.py | 16 ++++++---------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/requirements.txt b/requirements.txt index efb10ecf159..a87e7f82c27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,4 @@ etcd3 blake3 llguidance==0.7.29 soundfile -deep_gemm @ git+https://github.com/RayWang96/DeepGEMM.git@cc416ee +deep_gemm @ git+https://github.com/deepseek-ai/DeepGEMM.git@187656694f7f69e3e7975617a68bc3387680a7e1 diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index ade55455606..829eb998e64 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -1,12 +1,10 @@ -import functools from typing import List, Optional, Union +import deep_gemm import torch import torch.nn.functional as F import triton import triton.language as tl -from deep_gemm.jit_kernels.impls import sm100_fp8_gemm_1d1d -from deep_gemm.utils.layout import MajorTypeAB import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm._utils import nvtx_range @@ -157,11 +155,9 @@ def deepgemm_fp8_group_blockwise_gemm( masked_m: torch.Tensor, expected_m: int, ) -> torch.Tensor: - d = torch.empty((a.shape[0], a.shape[1], b.shape[1]), device=b.device, dtype=torch.bfloat16) - compiled_dims = 'nk' # NOTES: shape must be `[G, M, K] @ [G, N, K].mT` assert a.stride(-1) == 1 @@ -200,11 +196,11 @@ def deepgemm_fp8_group_blockwise_gemm( num_groups=num_groups, is_sfa=False) - impl = functools.partial(sm100_fp8_gemm_1d1d.fp8_m_grouped_gemm_nt_masked, - major_a=MajorTypeAB.KMajor, - major_b=MajorTypeAB.KMajor, - compiled_dims=compiled_dims) - impl(a, sfa, b, sfb, d, masked_m, expected_m) + deep_gemm.fp8_m_grouped_gemm_nt_masked((a, sfa), (b, sfb), + d, + masked_m, + expected_m, + disable_ue8m0_cast=True) return d From 
97a21fdfa94dd2f89ba3bc54796caff0ff0620dc Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 21 Jul 2025 00:40:29 -0700 Subject: [PATCH 20/38] fuse maskec index_copy and grouped fp8 quantization. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 123 +++++++++++++----- 1 file changed, 88 insertions(+), 35 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 829eb998e64..5e01b081be3 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -18,34 +18,77 @@ @triton.jit -def masked_index_copy_kernel(output_ptr, input_ptr, start_offsets_ptr, - row_indices_ptr, row_size, col_size, dim_size, - BLOCK_SIZE: tl.constexpr): +def _masked_index_copy_group_quant_fp8( + input_ptr, + out_q_ptr, + out_s_ptr, + # mask indices + start_offsets_ptr, + row_indices_ptr, + # group size + group_size, + # output size + row_size, + col_size, + dim_size, + # avoid to divide zero + eps, + # block size + BLOCK: tl.constexpr, +): # get program id and block offset pid = tl.program_id(0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) + block_start = pid * group_size # compute mask and pointers + offsets = block_start + tl.arange(0, BLOCK) + mask = offsets < (block_start + group_size) num_tokens = tl.load(start_offsets_ptr + row_size) token_idx = offsets // dim_size - valid = token_idx < num_tokens - row_idx = tl.load(row_indices_ptr + token_idx) + valid = (token_idx < num_tokens) & mask + row_idx = tl.load(row_indices_ptr + token_idx, mask=valid) start_offset = tl.load(start_offsets_ptr + row_idx, mask=valid) col_idx = token_idx - start_offset elem_idx = offsets % dim_size # load input data - input = tl.load(input_ptr + offsets, mask=valid) - - # write output - output_offsets = row_idx * col_size * dim_size + col_idx * dim_size + elem_idx - tl.store(output_ptr + output_offsets, input, mask=valid) + input = tl.load(input_ptr + offsets, mask=valid, other=0.0).to(tl.float32) + # quant + _absmax = tl.maximum(tl.max(tl.abs(input)), eps) + output_s = _absmax / 448.0 + output_s_inv = 1.0 / output_s + output_q = tl.clamp(input * output_s_inv, -448.0, + 448.0).to(out_q_ptr.dtype.element_ty) + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) -def triton_masked_index_copy(output, input, start_offsets, row_indices): - assert output.ndim == 3, "Input must be a 3D tensor, [row, col, dim]" + # write output + s_dim_size = dim_size // group_size + out_offsets = row_idx * col_size * dim_size + col_idx * dim_size + elem_idx + group_in_token = elem_idx // group_size + out_s_offset = row_idx * col_size * s_dim_size + col_idx * s_dim_size + group_in_token + + # Only store scaling factor for the first element in each group to avoid race conditions + is_first_in_group = elem_idx % group_size == 0 + tl.store(out_q_ptr + out_offsets, output_q, mask=valid) + tl.store(out_s_ptr + out_s_offset, output_s, mask=valid & is_first_in_group) + + +def masked_index_copy_group_quant_fp8( + output: torch.Tensor, + output_s: torch.Tensor, + input: torch.Tensor, + start_offsets: torch.Tensor, + row_indices: torch.Tensor, + group_size: int, + eps: float = 1e-10, +): + assert ( + input.shape[-1] % group_size == 0 + ), "the last dimension of `input` cannot be divisible by `group_size`" + assert 
input.is_contiguous(), "`input` is not contiguous" assert input.ndim == 2, "Input must be a 2D tensor" + assert output.ndim == 3, "Input must be a 3D tensor, [row, col, dim]" assert start_offsets.shape[ 0] == output.shape[0] + 1, "Start offsets must be (num_experts + 1)" @@ -55,16 +98,24 @@ def triton_masked_index_copy(output, input, start_offsets, row_indices): dim_size = output.shape[2] total_elems = num_tokens * dim_size - # launch kernel - grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), ) - masked_index_copy_kernel[grid](output, - input, - start_offsets, - row_indices, - row_size, - col_size, - dim_size, - BLOCK_SIZE=1024) + M = total_elems // group_size + BLOCK = triton.next_power_of_2(group_size) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + _masked_index_copy_group_quant_fp8[(M, )]( + input, + output, + output_s, + start_offsets, + row_indices, + group_size, + row_size, + col_size, + dim_size, + eps, + BLOCK=BLOCK, + num_warps=num_warps, + ) return @@ -342,19 +393,21 @@ def forward_chunk( expected_m = (token_selected_experts.numel() + self.expert_size_per_partition - 1) // self.expert_size_per_partition - permuted_data_tensor_padded = torch.empty( - self.expert_size_per_partition, - m_max, - self.hidden_size, - dtype=self.dtype, + act_input_fp8 = torch.empty( + (self.expert_size_per_partition, m_max, self.hidden_size), + dtype=torch.float8_e4m3fn, + device='cuda') + act_input_sf = torch.empty( + (self.expert_size_per_partition, m_max, self.hidden_size // 128), + dtype=torch.float32, device='cuda') - triton_masked_index_copy(permuted_data_tensor_padded, - permuted_data_tensor, - expert_first_token_offset_tensor, - token_to_expert_map) + masked_index_copy_group_quant_fp8(act_input_fp8, + act_input_sf, + permuted_data_tensor, + expert_first_token_offset_tensor, + token_to_expert_map, + group_size=128) - act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0( - permuted_data_tensor_padded) h1 = deepgemm_fp8_group_blockwise_gemm( a=act_input_fp8, b=self.w3_w1_weight, From f668fa78ba5b3c0a78a4e893f6c4cec2813afb1b Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 21 Jul 2025 05:33:45 -0700 Subject: [PATCH 21/38] fix quantization accuracy issue. 
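The kernel previously quantized with the raw per-group scale (absmax / 448) and only rounded the stored scale up to a power of two afterwards, so the GEMM dequantized with a larger scale than the one actually applied. Rounding to UE8M0 before dividing keeps the stored and applied scales identical. A minimal plain-PyTorch sketch of the corrected order (illustrative only, not the Triton kernel; names are placeholders):

    import torch

    def group_quant_dequant(x: torch.Tensor, round_scale_first: bool) -> torch.Tensor:
        # Per-group absmax scale for FP8 E4M3 (448 is the E4M3 max normal value).
        s = (x.abs().amax() / 448.0).clamp_min(1e-10)
        s_stored = torch.exp2(torch.ceil(torch.log2(s)))  # UE8M0 scale that is written out
        s_used = s_stored if round_scale_first else s     # scale used for quantization
        q = (x / s_used).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
        return q.to(torch.float32) * s_stored             # dequantize with the stored scale

    x = torch.randn(128)
    err_old = (group_quant_dequant(x, round_scale_first=False) - x).abs().max()
    err_new = (group_quant_dequant(x, round_scale_first=True) - x).abs().max()
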
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 5e01b081be3..f2b7275c6d5 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -57,10 +57,10 @@ def _masked_index_copy_group_quant_fp8( # quant _absmax = tl.maximum(tl.max(tl.abs(input)), eps) output_s = _absmax / 448.0 + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) output_s_inv = 1.0 / output_s output_q = tl.clamp(input * output_s_inv, -448.0, 448.0).to(out_q_ptr.dtype.element_ty) - output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) # write output s_dim_size = dim_size // group_size From c6b898556c2e699f2d1b10d73d07a674774dced1 Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Mon, 21 Jul 2025 22:57:54 +0800 Subject: [PATCH 22/38] Fuse swiglu and quant 2 (#18) Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 18 ++- tensorrt_llm/quantization/utils/fp8_utils.py | 146 ++++++++++++++++++ 2 files changed, 162 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index f2b7275c6d5..25c844eba04 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -416,8 +416,22 @@ def forward_chunk( masked_m=masked_m, expected_m=expected_m, ) - h2 = swiglu_fused_moe(h1) - act_input_fp8, act_input_sf = fp8_utils.per_token_cast_to_fp8_e8m0(h2) + act_input_fp8 = torch.empty(h1.shape[0], + h1.shape[1], + h1.shape[2] // 2, + dtype=torch.float8_e4m3fn, + device='cuda') + act_input_sf = torch.empty(h1.shape[0], + h1.shape[1], + h1.shape[2] // 256, + dtype=torch.float32, + device='cuda') + fp8_utils.silu_and_mul_masked_post_quant_fwd(input=h1, + output=act_input_fp8, + output_scale=act_input_sf, + quant_group_size=128, + masked_m=masked_m, + scale_ue8m0=True) h3 = deepgemm_fp8_group_blockwise_gemm( a=act_input_fp8, b=self.w2_weight, diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 5d277c8b828..5e6c00b996b 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -1,6 +1,8 @@ from typing import Optional, Tuple import torch +import triton +import triton.language as tl from tensorrt_llm._utils import nvtx_range @@ -214,3 +216,147 @@ def transform_sf_into_required_layout(sf: torch.Tensor, type_check=torch.int) assert False, f'Unknown cases: {sf.dtype=}, {gran=}' + + +# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py +@triton.jit +def _silu_and_mul_post_quant_kernel( + input_ptr, + stride_input_0, + stride_input_1, + stride_input_2, + output_ptr, + stride_output_0, + stride_output_1, + stride_output_2, + output_scale_ptr, + stride_output_scale_0, + stride_output_scale_1, + stride_output_scale_2, + masked_m_ptr, + size_n, + 
fp8_max, + fp8_min, + BLOCK_N: tl.constexpr, + NUM_STAGE: tl.constexpr, + SCALE_UE8M0: tl.constexpr, +): + expert_id = tl.program_id(2) + token_id = tl.program_id(1) + hidden_dim_block_index = tl.program_id(0) + + block_num_per_expert = tl.num_programs(1) + + token_num_cur_expert = tl.load(masked_m_ptr + expert_id) + + stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64) + stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64) + stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) + stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) + + offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N) + input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d + output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d + output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 + + hidden_dim_block_index * stride_output_scale_2) + + for token_index in tl.range(token_id, + token_num_cur_expert, + block_num_per_expert, + num_stages=NUM_STAGE): + up = tl.load( + input_ptr_offs + token_index * stride_input_1, + mask=offs_in_d < size_n, + other=0.0, + ) + gate = tl.load( + input_ptr_offs + token_index * stride_input_1 + size_n, + mask=offs_in_d < size_n, + other=0.0, + ).to(tl.float32) + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + gate_up = up * gate + _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) + output_s = _absmax / fp8_max + if SCALE_UE8M0: + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) + output_q = tl.clamp(gate_up / output_s, fp8_min, + fp8_max).to(output_ptr.dtype.element_ty) + tl.store( + output_ptr_offs + token_index * stride_output_1, + output_q, + mask=offs_in_d < size_n, + ) + tl.store( + output_scale_offs + token_index * stride_output_scale_1, + output_s, + ) + + +def silu_and_mul_masked_post_quant_fwd( + input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, + quant_group_size: int, + masked_m: torch.Tensor, + scale_ue8m0: bool = False, +): + """ + input shape [expert_num, token_num_padded, hidden_dim] + output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8 + output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32 + quant_group_size int, + masked_m shape [expert_num], + """ + + assert input.is_contiguous() + assert output.dtype == torch.float8_e4m3fn + assert output.is_contiguous() + assert len(input.shape) == 3 + assert input.shape[0] == masked_m.shape[0] + assert input.shape[-1] % 2 == 0 + + size_n = input.shape[-1] // 2 + assert size_n % quant_group_size == 0 + + expert_num = len(masked_m) + + if expert_num < 4: + BLOCK_NUM_PER_EXPERT = 64 + else: + BLOCK_NUM_PER_EXPERT = 32 + + BLOCK_N = quant_group_size + num_warps = 1 + NUM_STAGES = 6 + hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N) + assert BLOCK_N % quant_group_size == 0 + + grid = ( + hidden_dim_split_block_num, + BLOCK_NUM_PER_EXPERT, + expert_num, + ) + + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + fp8_min = -fp8_max + + _silu_and_mul_post_quant_kernel[grid]( + input, + *input.stride(), + output, + *output.stride(), + output_scale, + *output_scale.stride(), + masked_m, + size_n, + fp8_max, + fp8_min, + BLOCK_N=BLOCK_N, + NUM_STAGE=NUM_STAGES, + num_warps=num_warps, + SCALE_UE8M0=scale_ue8m0, + ) + return From 11053b7f27dc40f5d25501d1c5f7810568fc259b Mon Sep 17 00:00:00 2001 From: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> Date: Thu, 24 Jul 2025 10:10:08 +0800 Subject: [PATCH 
23/38] Opt gather kernel (#19) Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 25c844eba04..63b90036dad 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -125,24 +125,29 @@ def masked_index_gather_kernel(output_ptr, input_ptr, start_offsets_ptr, BLOCK_SIZE: tl.constexpr): # get program id and block offset pid = tl.program_id(0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - # compute mask and pointers num_tokens = tl.load(start_offsets_ptr + row_size) - token_idx = offsets // dim_size - valid = token_idx < num_tokens + + token_idx = pid + valid_token = token_idx < num_tokens + if not valid_token: + return + row_idx = tl.load(row_indices_ptr + token_idx) - start_offset = tl.load(start_offsets_ptr + row_idx, mask=valid) + start_offset = tl.load(start_offsets_ptr + row_idx) col_idx = token_idx - start_offset - elem_idx = offsets % dim_size - # input data - input_offsets = row_idx * col_size * dim_size + col_idx * dim_size + elem_idx - input_vals = tl.load(input_ptr + input_offsets, mask=valid) + # Process elements in blocks + for hidden_start in tl.range(0, dim_size, BLOCK_SIZE): + hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE) + valid_hidden = hidden_indices < dim_size - # get gather indices and store to output - tl.store(output_ptr + offsets, input_vals, mask=valid) + input_offset = row_idx * col_size * dim_size + col_idx * dim_size + hidden_indices + input_val = tl.load(input_ptr + input_offset, + mask=valid_hidden, + other=0.0) + + output_offset = pid * dim_size + hidden_indices + tl.store(output_ptr + output_offset, input_val, mask=valid_hidden) @torch.no_grad() @@ -156,10 +161,9 @@ def triton_masked_index_gather(output, input, start_offsets, row_indices): col_size = input.shape[1] dim_size = input.shape[2] num_tokens = output.shape[0] - total_elems = num_tokens * dim_size + grid = (num_tokens, ) # launch kernel - grid = lambda meta: (triton.cdiv(total_elems, meta['BLOCK_SIZE']), ) masked_index_gather_kernel[grid](output, input, start_offsets, From 017383668d48dd758844337b6a1764209424dc52 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Tue, 22 Jul 2025 17:39:14 -0700 Subject: [PATCH 24/38] optimize the perf of masked_index_copy_group_quant_fp8. 
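Only the schedule changes here, not the result: program axis 0 now walks quantization groups while axis 1 covers a fixed number of token blocks, and each program iterates over the expert-padded tokens with tl.range(..., num_stages=...) so loads are software-pipelined. For reference, the semantics the kernel implements can be written in eager PyTorch as below (illustrative sketch; argument names approximate the wrapper's, and the real kernel stores the scales in a different layout):

    import torch

    def masked_copy_group_quant_ref(x, start_offsets, token_to_expert, m_max, group=128):
        # x: [num_tokens, hidden] activations already permuted by expert.
        # start_offsets: [num_experts + 1]; start_offsets[e] is the first token of expert e.
        # token_to_expert: [num_tokens]; expert id owning each permuted token.
        experts, hidden = start_offsets.numel() - 1, x.shape[1]
        q = torch.zeros(experts, m_max, hidden, dtype=torch.float8_e4m3fn)
        s = torch.ones(experts, m_max, hidden // group)
        for tok in range(int(start_offsets[-1])):
            e = int(token_to_expert[tok])
            slot = tok - int(start_offsets[e])
            g = x[tok].float().view(-1, group)
            scale = (g.abs().amax(1, keepdim=True) / 448.0).clamp_min(1e-10)
            scale = torch.exp2(torch.ceil(torch.log2(scale)))            # UE8M0
            q[e, slot] = (g / scale).clamp(-448.0, 448.0).view(-1).to(torch.float8_e4m3fn)
            s[e, slot] = scale.squeeze(1)
        return q, s
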
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 113 ++++++++++-------- 1 file changed, 66 insertions(+), 47 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 63b90036dad..e836db69996 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -25,53 +25,60 @@ def _masked_index_copy_group_quant_fp8( # mask indices start_offsets_ptr, row_indices_ptr, - # group size - group_size, - # output size + # dimensions + num_groups, row_size, col_size, dim_size, - # avoid to divide zero + group_size, + # quantization parameters eps, + fp8_max, # block size BLOCK: tl.constexpr, + NUM_STAGE: tl.constexpr, ): - # get program id and block offset - pid = tl.program_id(0) - block_start = pid * group_size + group_block = tl.program_id(0) + token_block = tl.program_id(1) + block_num_per_token = tl.num_programs(1) - # compute mask and pointers - offsets = block_start + tl.arange(0, BLOCK) - mask = offsets < (block_start + group_size) + # calculate group and element offsets num_tokens = tl.load(start_offsets_ptr + row_size) - token_idx = offsets // dim_size - valid = (token_idx < num_tokens) & mask - row_idx = tl.load(row_indices_ptr + token_idx, mask=valid) - start_offset = tl.load(start_offsets_ptr + row_idx, mask=valid) - col_idx = token_idx - start_offset - elem_idx = offsets % dim_size - - # load input data - input = tl.load(input_ptr + offsets, mask=valid, other=0.0).to(tl.float32) - - # quant - _absmax = tl.maximum(tl.max(tl.abs(input)), eps) - output_s = _absmax / 448.0 - output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) - output_s_inv = 1.0 / output_s - output_q = tl.clamp(input * output_s_inv, -448.0, - 448.0).to(out_q_ptr.dtype.element_ty) - - # write output - s_dim_size = dim_size // group_size - out_offsets = row_idx * col_size * dim_size + col_idx * dim_size + elem_idx - group_in_token = elem_idx // group_size - out_s_offset = row_idx * col_size * s_dim_size + col_idx * s_dim_size + group_in_token - - # Only store scaling factor for the first element in each group to avoid race conditions - is_first_in_group = elem_idx % group_size == 0 - tl.store(out_q_ptr + out_offsets, output_q, mask=valid) - tl.store(out_s_ptr + out_s_offset, output_s, mask=valid & is_first_in_group) + group_start = group_block * group_size + elem_offsets = group_start + tl.arange(0, BLOCK) + valid_elem = elem_offsets < (group_start + group_size) + input_ptr_offs = input_ptr + elem_offsets + row_indices_ptr_offs = row_indices_ptr + elem_offsets // dim_size + output_ptr_offs = out_q_ptr + elem_offsets + output_s_offs = out_s_ptr + group_block + + # process tokens + for token_index in tl.range(token_block, + num_tokens, + block_num_per_token, + num_stages=NUM_STAGE): + # load input and indices + input_data = tl.load(input_ptr_offs + token_index * dim_size, + mask=valid_elem, + other=0.0) + row_idx = tl.load(row_indices_ptr_offs + token_index, + mask=valid_elem, + other=0) + start_offset = tl.load(start_offsets_ptr + row_idx, + mask=valid_elem, + other=0) + idx = row_idx * col_size + token_index - start_offset + + # quantization + _absmax = tl.maximum(tl.max(tl.abs(input_data)), eps) + output_s = _absmax / fp8_max + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) + output_q = tl.clamp(input_data / output_s, -fp8_max, + 
fp8_max).to(out_q_ptr.dtype.element_ty) + + # store quantized values and scaling factor + tl.store(output_ptr_offs + idx * dim_size, output_q, mask=valid_elem) + tl.store(output_s_offs + idx * num_groups, output_s, mask=valid_elem) def masked_index_copy_group_quant_fp8( @@ -88,32 +95,44 @@ def masked_index_copy_group_quant_fp8( ), "the last dimension of `input` cannot be divisible by `group_size`" assert input.is_contiguous(), "`input` is not contiguous" assert input.ndim == 2, "Input must be a 2D tensor" - assert output.ndim == 3, "Input must be a 3D tensor, [row, col, dim]" + assert output.ndim == 3, "Output must be a 3D tensor, [row, col, dim]" assert start_offsets.shape[ 0] == output.shape[0] + 1, "Start offsets must be (num_experts + 1)" - num_tokens = input.shape[0] row_size = output.shape[0] col_size = output.shape[1] dim_size = output.shape[2] - total_elems = num_tokens * dim_size + num_groups = (dim_size + group_size - 1) // group_size + + # get block/grid/stage/warp + BLOCK = group_size + BLOCK_NUM_PER_TOKEN = 128 + NUM_STAGES = 2 + num_warps = 4 + grid = ( + num_groups, + BLOCK_NUM_PER_TOKEN, + ) - M = total_elems // group_size - BLOCK = triton.next_power_of_2(group_size) - # heuristics for number of warps - num_warps = min(max(BLOCK // 256, 1), 8) - _masked_index_copy_group_quant_fp8[(M, )]( + # FP8 quantization parameters + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + + _masked_index_copy_group_quant_fp8[grid]( input, output, output_s, start_offsets, row_indices, - group_size, + num_groups, row_size, col_size, dim_size, + group_size, eps, + fp8_max, BLOCK=BLOCK, + NUM_STAGE=NUM_STAGES, num_warps=num_warps, ) return From bd94e37082995e78263594891ba4c91482cae6ed Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Wed, 23 Jul 2025 08:39:33 -0700 Subject: [PATCH 25/38] fix duplicate load. 
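In the two-dimensional grid the expert id and expert start offset were loaded with the per-element mask, i.e. one redundant load per element of the 128-wide group, even though both are scalars per token; they are now loaded once per token. Roughly, in plain PyTorch terms (illustrative only):

    import torch

    row_indices = torch.randint(0, 8, (1024,))
    token_index, group_size = 5, 128

    # before: a vector of `group_size` identical index loads per token
    dup = row_indices[torch.full((group_size,), token_index)]
    # after: a single scalar load, broadcast inside the kernel for free
    one = row_indices[token_index]
    assert bool((dup == one).all())
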
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index e836db69996..5ea793565aa 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -40,7 +40,7 @@ def _masked_index_copy_group_quant_fp8( ): group_block = tl.program_id(0) token_block = tl.program_id(1) - block_num_per_token = tl.num_programs(1) + token_block_num = tl.num_programs(1) # calculate group and element offsets num_tokens = tl.load(start_offsets_ptr + row_size) @@ -48,25 +48,20 @@ def _masked_index_copy_group_quant_fp8( elem_offsets = group_start + tl.arange(0, BLOCK) valid_elem = elem_offsets < (group_start + group_size) input_ptr_offs = input_ptr + elem_offsets - row_indices_ptr_offs = row_indices_ptr + elem_offsets // dim_size output_ptr_offs = out_q_ptr + elem_offsets output_s_offs = out_s_ptr + group_block # process tokens for token_index in tl.range(token_block, num_tokens, - block_num_per_token, + token_block_num, num_stages=NUM_STAGE): # load input and indices input_data = tl.load(input_ptr_offs + token_index * dim_size, mask=valid_elem, other=0.0) - row_idx = tl.load(row_indices_ptr_offs + token_index, - mask=valid_elem, - other=0) - start_offset = tl.load(start_offsets_ptr + row_idx, - mask=valid_elem, - other=0) + row_idx = tl.load(row_indices_ptr + token_index) + start_offset = tl.load(start_offsets_ptr + row_idx) idx = row_idx * col_size + token_index - start_offset # quantization @@ -78,7 +73,7 @@ def _masked_index_copy_group_quant_fp8( # store quantized values and scaling factor tl.store(output_ptr_offs + idx * dim_size, output_q, mask=valid_elem) - tl.store(output_s_offs + idx * num_groups, output_s, mask=valid_elem) + tl.store(output_s_offs + idx * num_groups, output_s) def masked_index_copy_group_quant_fp8( @@ -99,6 +94,7 @@ def masked_index_copy_group_quant_fp8( assert start_offsets.shape[ 0] == output.shape[0] + 1, "Start offsets must be (num_experts + 1)" + num_tokens = input.shape[0] row_size = output.shape[0] col_size = output.shape[1] dim_size = output.shape[2] @@ -106,12 +102,17 @@ def masked_index_copy_group_quant_fp8( # get block/grid/stage/warp BLOCK = group_size - BLOCK_NUM_PER_TOKEN = 128 - NUM_STAGES = 2 - num_warps = 4 + if num_tokens <= 4096: + TOKEN_BLOCK_NUM = 128 + NUM_STAGES = 4 + num_warps = 2 + else: + TOKEN_BLOCK_NUM = 64 + NUM_STAGES = 6 + num_warps = 1 grid = ( num_groups, - BLOCK_NUM_PER_TOKEN, + TOKEN_BLOCK_NUM, ) # FP8 quantization parameters From f1d311503e9394434b7e4e6ad2b5c91c3eb34f40 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Wed, 23 Jul 2025 22:53:14 -0700 Subject: [PATCH 26/38] fuse scaling factor transform to _masked_index_copy_group_quant_fp8. 
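DeepGEMM's masked grouped GEMM consumes activation scales as packed UE8M0 exponents, four 8-bit exponents per int32 word, MN-major (token-major) and TMA-aligned. Instead of writing float32 scales and repacking them afterwards with transform_sf_into_required_layout, the copy-and-quant kernel now emits that layout directly: each group's power-of-two scale is reduced to its biased exponent (>> 23) and OR-ed into its byte of the shared int32 word via tl.atomic_or. The byte packing in plain PyTorch, assuming little-endian byte order as the kernel's (group_block % 4) * 8 shifts do (illustrative sketch):

    import torch

    def pack_ue8m0(scales: torch.Tensor) -> torch.Tensor:
        # scales: four power-of-two float32 group scales.
        # Keep only the 8-bit biased exponent of each float (UE8M0) and pack
        # the four exponent bytes into one int32, least-significant byte first.
        assert scales.dtype == torch.float32 and scales.numel() == 4
        exponents = ((scales.view(torch.int32) >> 23) & 0xFF).to(torch.uint8)
        return exponents.view(torch.int32)  # 4 bytes -> 1 packed int32

    packed = pack_ue8m0(torch.tensor([1.0, 0.5, 0.25, 2.0]))
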
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 5ea793565aa..8c596488300 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -26,11 +26,13 @@ def _masked_index_copy_group_quant_fp8( start_offsets_ptr, row_indices_ptr, # dimensions - num_groups, row_size, col_size, dim_size, group_size, + # output scale factor size + aligned_col, + aligned_dim, # quantization parameters eps, fp8_max, @@ -49,7 +51,8 @@ def _masked_index_copy_group_quant_fp8( valid_elem = elem_offsets < (group_start + group_size) input_ptr_offs = input_ptr + elem_offsets output_ptr_offs = out_q_ptr + elem_offsets - output_s_offs = out_s_ptr + group_block + output_s_offs = out_s_ptr + (group_block // 4) * aligned_col + shift = (group_block % 4) * 8 # process tokens for token_index in tl.range(token_block, @@ -63,6 +66,7 @@ def _masked_index_copy_group_quant_fp8( row_idx = tl.load(row_indices_ptr + token_index) start_offset = tl.load(start_offsets_ptr + row_idx) idx = row_idx * col_size + token_index - start_offset + idx_s = row_idx * aligned_dim * aligned_col + token_index - start_offset # quantization _absmax = tl.maximum(tl.max(tl.abs(input_data)), eps) @@ -70,15 +74,16 @@ def _masked_index_copy_group_quant_fp8( output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) output_q = tl.clamp(input_data / output_s, -fp8_max, fp8_max).to(out_q_ptr.dtype.element_ty) + output_s = (output_s.to(tl.int32, bitcast=True) >> 23).to(tl.uint8) # store quantized values and scaling factor tl.store(output_ptr_offs + idx * dim_size, output_q, mask=valid_elem) - tl.store(output_s_offs + idx * num_groups, output_s) + tl.atomic_or(output_s_offs + idx_s, output_s << shift) def masked_index_copy_group_quant_fp8( output: torch.Tensor, - output_s: torch.Tensor, + # output_s: torch.Tensor, input: torch.Tensor, start_offsets: torch.Tensor, row_indices: torch.Tensor, @@ -100,6 +105,15 @@ def masked_index_copy_group_quant_fp8( dim_size = output.shape[2] num_groups = (dim_size + group_size - 1) // group_size + # create padded output_s + alignment = 4 + scale_dim = (dim_size + group_size - 1) // group_size + padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment + padded_col_size = (col_size + alignment - 1) // alignment * alignment + output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size), + dtype=torch.int32, + device='cuda') + # get block/grid/stage/warp BLOCK = group_size if num_tokens <= 4096: @@ -125,18 +139,20 @@ def masked_index_copy_group_quant_fp8( output_s, start_offsets, row_indices, - num_groups, row_size, col_size, dim_size, group_size, + padded_col_size, + padded_dim_size // 4, eps, fp8_max, BLOCK=BLOCK, NUM_STAGE=NUM_STAGES, num_warps=num_warps, ) - return + output_s = output_s.transpose(1, 2)[:, :col_size, :] + return output_s @triton.jit @@ -421,16 +437,12 @@ def forward_chunk( (self.expert_size_per_partition, m_max, self.hidden_size), dtype=torch.float8_e4m3fn, device='cuda') - act_input_sf = torch.empty( - (self.expert_size_per_partition, m_max, self.hidden_size // 128), - dtype=torch.float32, - device='cuda') - masked_index_copy_group_quant_fp8(act_input_fp8, - act_input_sf, - permuted_data_tensor, - 
expert_first_token_offset_tensor, - token_to_expert_map, - group_size=128) + act_input_sf = masked_index_copy_group_quant_fp8( + act_input_fp8, + permuted_data_tensor, + expert_first_token_offset_tensor, + token_to_expert_map, + group_size=128) h1 = deepgemm_fp8_group_blockwise_gemm( a=act_input_fp8, From acd43816f677dbb330f5223303919dbfa384694c Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Thu, 24 Jul 2025 00:09:37 -0700 Subject: [PATCH 27/38] fix. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 8c596488300..ac8687776e0 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -83,7 +83,6 @@ def _masked_index_copy_group_quant_fp8( def masked_index_copy_group_quant_fp8( output: torch.Tensor, - # output_s: torch.Tensor, input: torch.Tensor, start_offsets: torch.Tensor, row_indices: torch.Tensor, From 2d5beab60e69944c5d6ab4bba314b68fe1117ee4 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Thu, 24 Jul 2025 02:54:55 -0700 Subject: [PATCH 28/38] add another for loop on the group dim. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 66 +++++++++++-------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index ac8687776e0..8a7a8459dd7 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -46,39 +46,43 @@ def _masked_index_copy_group_quant_fp8( # calculate group and element offsets num_tokens = tl.load(start_offsets_ptr + row_size) - group_start = group_block * group_size - elem_offsets = group_start + tl.arange(0, BLOCK) - valid_elem = elem_offsets < (group_start + group_size) - input_ptr_offs = input_ptr + elem_offsets - output_ptr_offs = out_q_ptr + elem_offsets - output_s_offs = out_s_ptr + (group_block // 4) * aligned_col - shift = (group_block % 4) * 8 + elem_offsets = group_block * group_size * 4 + tl.arange(0, BLOCK) + output_s_offs = out_s_ptr + group_block * aligned_col # process tokens for token_index in tl.range(token_block, num_tokens, token_block_num, num_stages=NUM_STAGE): - # load input and indices - input_data = tl.load(input_ptr_offs + token_index * dim_size, - mask=valid_elem, - other=0.0) + # load indices row_idx = tl.load(row_indices_ptr + token_index) start_offset = tl.load(start_offsets_ptr + row_idx) idx = row_idx * col_size + token_index - start_offset idx_s = row_idx * aligned_dim * aligned_col + token_index - start_offset - # quantization - _absmax = tl.maximum(tl.max(tl.abs(input_data)), eps) - output_s = _absmax / fp8_max - output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) - output_q = tl.clamp(input_data / output_s, -fp8_max, - fp8_max).to(out_q_ptr.dtype.element_ty) - output_s = (output_s.to(tl.int32, bitcast=True) >> 23).to(tl.uint8) - - # store quantized values and scaling factor - tl.store(output_ptr_offs + idx * dim_size, output_q, mask=valid_elem) - tl.atomic_or(output_s_offs + 
idx_s, output_s << shift) + output_s_int32 = 0 + for group_index in tl.range(4): + # load input data + dim_offset = elem_offsets + group_index * group_size + valid = dim_offset < dim_size + input_data = tl.load(input_ptr + token_index * dim_size + + dim_offset, + mask=valid, + other=0.0) + # quantization + _absmax = tl.maximum(tl.max(tl.abs(input_data)), eps) + output_s = _absmax / fp8_max + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) + output_q = tl.clamp(input_data / output_s, -fp8_max, + fp8_max).to(out_q_ptr.dtype.element_ty) + output_s = output_s.to(tl.int32, bitcast=True) >> 23 + output_s_int32 += output_s << (group_index * 8) + + # store quantized values and scaling factor + tl.store(out_q_ptr + idx * dim_size + dim_offset, + output_q, + mask=valid) + tl.store(output_s_offs + idx_s, output_s_int32) def masked_index_copy_group_quant_fp8( @@ -102,7 +106,6 @@ def masked_index_copy_group_quant_fp8( row_size = output.shape[0] col_size = output.shape[1] dim_size = output.shape[2] - num_groups = (dim_size + group_size - 1) // group_size # create padded output_s alignment = 4 @@ -114,17 +117,22 @@ def masked_index_copy_group_quant_fp8( device='cuda') # get block/grid/stage/warp + num_groups = (dim_size + group_size - 1) // group_size BLOCK = group_size - if num_tokens <= 4096: - TOKEN_BLOCK_NUM = 128 + if num_tokens <= 1000 or col_size <= 256: # Small workload + TOKEN_BLOCK_NUM = 256 NUM_STAGES = 4 num_warps = 2 - else: - TOKEN_BLOCK_NUM = 64 - NUM_STAGES = 6 + elif num_tokens <= 10000 or col_size <= 2048: # Medium workload + TOKEN_BLOCK_NUM = 1024 + NUM_STAGES = 2 + num_warps = 1 + else: # Large workload + TOKEN_BLOCK_NUM = 2048 + NUM_STAGES = 2 num_warps = 1 grid = ( - num_groups, + (num_groups + 3) // 4, TOKEN_BLOCK_NUM, ) From 5653eea07057339369e14b4bcc4cb17b791b5687 Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Fri, 25 Jul 2025 08:51:25 +0800 Subject: [PATCH 29/38] Remove SFB transform from forward process (#23) Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> --- .../_torch/models/modeling_deepseekv3.py | 15 +++++++++++- .../modules/fused_moe/fused_moe_deepgemm.py | 6 ----- .../_torch/modules/fused_moe/quantization.py | 24 ++++++++++++++++++- tensorrt_llm/_torch/modules/linear.py | 16 +++++++++---- 4 files changed, 49 insertions(+), 12 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index a3d0dbe929c..9e689f44073 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -45,7 +45,8 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig -from tensorrt_llm.quantization.utils.fp8_utils import resmooth_to_fp8_e8m0 +from tensorrt_llm.quantization.utils.fp8_utils import ( + resmooth_to_fp8_e8m0, transform_sf_into_required_layout) from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams @@ -1417,6 +1418,18 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, for n, p in module.named_parameters(): p.data.copy_(module_weights[n][:]) + if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( + ) and get_sm_version() == 100 and hasattr( + module, "weight_scale"): + transfromed_scale = transform_sf_into_required_layout( + module.weight_scale, + 
mn=module.weight.shape[0], + k=module.weight.shape[1], + recipe=(1, 128, 128), + is_sfa=False) + module.weight_scale = nn.Parameter(transfromed_scale, + requires_grad=False) + for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 8a7a8459dd7..c7bc7b4f2a1 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -287,12 +287,6 @@ def deepgemm_fp8_group_blockwise_gemm( recipe=recipe, num_groups=num_groups, is_sfa=True) - sfb = fp8_utils.transform_sf_into_required_layout(sfb, - mn=n, - k=k, - recipe=recipe, - num_groups=num_groups, - is_sfa=False) deep_gemm.fp8_m_grouped_gemm_nt_masked((a, sfa), (b, sfb), d, diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index d66306f48f4..18e9c7cc98a 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -9,7 +9,8 @@ from tensorrt_llm.quantization.utils.fp4_utils import ( float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices, get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices) -from tensorrt_llm.quantization.utils.fp8_utils import resmooth_to_fp8_e8m0 +from tensorrt_llm.quantization.utils.fp8_utils import ( + resmooth_to_fp8_e8m0, transform_sf_into_required_layout) from ..linear import TensorParallelMode, load_weight_shard from .interface import MoEWeightLoadingMode @@ -485,6 +486,27 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict], weight, scale) super().load_weights(module, weights, weight_loading_mode) + if get_sm_version() == 100: + transfromed_w3_w1_scale = transform_sf_into_required_layout( + module.quant_scales[0], + mn=module.w3_w1_weight.shape[1], + k=module.w3_w1_weight.shape[2], + recipe=(1, 128, 128), + num_groups=module.w3_w1_weight.shape[0], + is_sfa=False) + module.w3_w1_weight_scaling_factor = nn.Parameter( + transfromed_w3_w1_scale, requires_grad=False) + transfromed_w2_scale = transform_sf_into_required_layout( + module.quant_scales[1], + mn=module.w2_weight.shape[1], + k=module.w2_weight.shape[2], + recipe=(1, 128, 128), + num_groups=module.w3_w1_weight.shape[0], + is_sfa=False) + module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale, + requires_grad=False) + self.setup_quant_scales(module) + def setup_quant_scales(self, module: torch.nn.Module): module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales( fc_weight_scales=module.w3_w1_weight_scaling_factor, diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 9b98d6df9b4..b4dfb1d6456 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils +import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm._torch.peft.lora.layer import LoraLayer from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, AllReduceStrategy) @@ -19,7 +20,6 @@ from tensorrt_llm.quantization.functional import \ preprocess_weights_for_mixed_gemm from tensorrt_llm.quantization.mode import QuantAlgo -from tensorrt_llm.quantization.utils.fp8_utils import per_token_cast_to_fp8_e8m0 
from ..._utils import get_sm_version from ...models.modeling_utils import QuantConfig @@ -574,12 +574,20 @@ def apply(self, module: Linear, input: torch.Tensor, if get_sm_version() == 100: import deep_gemm - a_tuple = per_token_cast_to_fp8_e8m0(input) + a, a_sf = fp8_utils.per_token_cast_to_fp8_e8m0(input) + a_sf = fp8_utils.transform_sf_into_required_layout(a_sf, + mn=a.shape[0], + k=a.shape[1], + recipe=(1, 128, + 128), + is_sfa=True) output = torch.empty((input.shape[0], module.weight.shape[0]), device=input.device, dtype=torch.bfloat16) - deep_gemm.fp8_gemm_nt(a_tuple, (module.weight, module.weight_scale), - output) + deep_gemm.fp8_gemm_nt((a, a_sf), + (module.weight, module.weight_scale), + output, + disable_ue8m0_cast=True) else: act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( input) From 49dcb983a7f9d25ab76c6d57d8983ce9ec8f4be8 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Fri, 25 Jul 2025 09:11:50 +0800 Subject: [PATCH 30/38] change deepgeem to a new commit that with torch dependency. (#24) Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a87e7f82c27..6ea0c9af4d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,4 @@ etcd3 blake3 llguidance==0.7.29 soundfile -deep_gemm @ git+https://github.com/deepseek-ai/DeepGEMM.git@187656694f7f69e3e7975617a68bc3387680a7e1 +deep_gemm @ git+https://github.com/yuxianq/DeepGEMM.git@417c5924b0a2a9410b4a1368f06f63a195081911 From 99970069cf647796db69e7b66d8594c8b0a6b785 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Thu, 24 Jul 2025 19:57:41 -0700 Subject: [PATCH 31/38] fix format and rebase bug. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- examples/llm-api/quickstart_advanced.py | 12 +++++++----- .../_torch/models/checkpoints/hf/weight_loader.py | 2 -- tensorrt_llm/llmapi/llm_args.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 0aaf9e5200f..202d9944d8e 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -47,11 +47,13 @@ def add_llm_args(parser): 'VANILLA', 'TRTLLM', 'FLASHINFER', 'FLASHINFER_STAR_ATTENTION' ]) - parser.add_argument( - '--moe_backend', - type=str, - default='CUTLASS', - choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', 'DEEPGEMM', 'CUTEDSL']) + parser.add_argument('--moe_backend', + type=str, + default='CUTLASS', + choices=[ + 'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', + 'DEEPGEMM', 'CUTEDSL' + ]) parser.add_argument('--enable_attention_dp', default=False, action='store_true') diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py index 2b90996cb6f..ba4703875e6 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py @@ -124,5 +124,3 @@ def prefetch_files(self, file_names: List[str]): len(local_file_names)) with multiprocessing.Pool(processes=max_processes) as pool: pool.map(self._prefetch_one_file, local_file_names) - # Ensure that all ranks have finished prefetching before loading weights - mpi_barrier() diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 6614391b452..54983fea2f2 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -93,7 +93,7 @@ class MoeConfig(BaseModel): """ Configuration for MoE. 
""" - backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", + backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM", "VANILLA"] = Field(default='CUTLASS', description="MoE backend to use.") From d8ae02caed59708512e5aa578a09bb09f365bfed Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 28 Jul 2025 14:00:12 +0800 Subject: [PATCH 32/38] fix dummy requests when estimate kv cache with attention DP enabled to avoid OOM (#25) Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 4754e693fc5..3ad200ae0c4 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -143,6 +143,8 @@ def _create_dummy_context_requests( end_id=-1) requests.append(request) remaining_tokens -= input_seq_len + if self._mapping.enable_attention_dp: + requests = requests * self._mapping.tp_size return requests def _get_token_num_for_estimation(self) -> int: From fb3e46761982ca882a8b8e0a96c643bf3c9e298c Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Mon, 28 Jul 2025 14:15:22 +0800 Subject: [PATCH 33/38] Fuse quantize and transform e8m0 scales (#26) Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 25 +- tensorrt_llm/_torch/modules/linear.py | 8 +- tensorrt_llm/quantization/utils/fp8_utils.py | 271 ++++++++++++++---- 3 files changed, 223 insertions(+), 81 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index c7bc7b4f2a1..3721a5d2afd 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -280,13 +280,6 @@ def deepgemm_fp8_group_blockwise_gemm( assert d.stride(-1) == 1 # Transform SFA and SFB into compute-required layout - recipe = (1, 128, 128) - sfa = fp8_utils.transform_sf_into_required_layout(sfa, - mn=m, - k=k, - recipe=recipe, - num_groups=num_groups, - is_sfa=True) deep_gemm.fp8_m_grouped_gemm_nt_masked((a, sfa), (b, sfb), d, @@ -453,22 +446,8 @@ def forward_chunk( masked_m=masked_m, expected_m=expected_m, ) - act_input_fp8 = torch.empty(h1.shape[0], - h1.shape[1], - h1.shape[2] // 2, - dtype=torch.float8_e4m3fn, - device='cuda') - act_input_sf = torch.empty(h1.shape[0], - h1.shape[1], - h1.shape[2] // 256, - dtype=torch.float32, - device='cuda') - fp8_utils.silu_and_mul_masked_post_quant_fwd(input=h1, - output=act_input_fp8, - output_scale=act_input_sf, - quant_group_size=128, - masked_m=masked_m, - scale_ue8m0=True) + act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd( + input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True) h3 = deepgemm_fp8_group_blockwise_gemm( a=act_input_fp8, b=self.w2_weight, diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index b4dfb1d6456..1131fe42275 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -574,13 +574,7 @@ def apply(self, module: Linear, input: torch.Tensor, if get_sm_version() == 100: import deep_gemm - 
a, a_sf = fp8_utils.per_token_cast_to_fp8_e8m0(input) - a_sf = fp8_utils.transform_sf_into_required_layout(a_sf, - mn=a.shape[0], - k=a.shape[1], - recipe=(1, 128, - 128), - is_sfa=True) + a, a_sf = fp8_utils.per_token_quant_and_transform(input) output = torch.empty((input.shape[0], module.weight.shape[0]), device=input.device, dtype=torch.bfloat16) diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 5e6c00b996b..b653d10365c 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -234,10 +234,10 @@ def _silu_and_mul_post_quant_kernel( stride_output_scale_1, stride_output_scale_2, masked_m_ptr, - size_n, + size_k, fp8_max, fp8_min, - BLOCK_N: tl.constexpr, + BLOCK: tl.constexpr, NUM_STAGE: tl.constexpr, SCALE_UE8M0: tl.constexpr, ): @@ -254,109 +254,278 @@ def _silu_and_mul_post_quant_kernel( stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) - offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N) + offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4) input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 + - hidden_dim_block_index * stride_output_scale_2) + hidden_dim_block_index * stride_output_scale_1) for token_index in tl.range(token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE): - up = tl.load( - input_ptr_offs + token_index * stride_input_1, - mask=offs_in_d < size_n, - other=0.0, - ) - gate = tl.load( - input_ptr_offs + token_index * stride_input_1 + size_n, - mask=offs_in_d < size_n, - other=0.0, - ).to(tl.float32) - gate = gate / (1 + tl.exp(-gate)) - gate = gate.to(input_ptr.dtype.element_ty) - gate_up = up * gate - _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) - output_s = _absmax / fp8_max - if SCALE_UE8M0: - output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) - output_q = tl.clamp(gate_up / output_s, fp8_min, - fp8_max).to(output_ptr.dtype.element_ty) - tl.store( - output_ptr_offs + token_index * stride_output_1, - output_q, - mask=offs_in_d < size_n, - ) + output_s_int32 = 0 + for pack_index in tl.range(4): + local_mask = offs_in_d + pack_index * 128 + up = tl.load( + input_ptr_offs + token_index * stride_input_1 + + pack_index * 128, + mask=local_mask < size_k, + other=0.0, + ) + gate = tl.load( + input_ptr_offs + token_index * stride_input_1 + size_k + + pack_index * 128, + mask=local_mask < size_k, + other=0.0, + ).to(tl.float32) + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + gate_up = up * gate + _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) + output_s = _absmax / fp8_max + if SCALE_UE8M0: + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) + output_q = tl.clamp(gate_up / output_s, fp8_min, + fp8_max).to(output_ptr.dtype.element_ty) + output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) << + (8 * pack_index)) + tl.store( + output_ptr_offs + token_index * stride_output_1 + + pack_index * 128, + output_q, + mask=local_mask < size_k, + ) tl.store( - output_scale_offs + token_index * stride_output_scale_1, - output_s, + output_scale_offs + token_index * stride_output_scale_2, + output_s_int32, ) def silu_and_mul_masked_post_quant_fwd( input: torch.Tensor, - output: torch.Tensor, - output_scale: torch.Tensor, 
quant_group_size: int, masked_m: torch.Tensor, scale_ue8m0: bool = False, ): """ - input shape [expert_num, token_num_padded, hidden_dim] - output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8 - output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32 - quant_group_size int, - masked_m shape [expert_num], + input shape [g, m, k] + output shape [g, m, k // 2], dtype fp8 + output_scale [g, k // 4, m // 2 // 128], dtype int32 + quant_group_size int + masked_m shape [g] """ assert input.is_contiguous() - assert output.dtype == torch.float8_e4m3fn - assert output.is_contiguous() assert len(input.shape) == 3 assert input.shape[0] == masked_m.shape[0] assert input.shape[-1] % 2 == 0 - size_n = input.shape[-1] // 2 - assert size_n % quant_group_size == 0 + # FP8 quantization parameters + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + fp8_min = finfo.min + g, m, k = input.shape + k = k // 2 + + # Create output + output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda") + + # Create output scale + alignment = 4 + scale_k = ceil_div(k, quant_group_size) + m_padded = align(m, alignment) + scale_k_padded = align(scale_k, alignment) + output_scale = torch.zeros((g, scale_k_padded // 4, m_padded), + dtype=torch.int32, + device='cuda') + + # Get block/grid/stage/warp expert_num = len(masked_m) if expert_num < 4: BLOCK_NUM_PER_EXPERT = 64 else: - BLOCK_NUM_PER_EXPERT = 32 + BLOCK_NUM_PER_EXPERT = 128 - BLOCK_N = quant_group_size + BLOCK = quant_group_size * 4 num_warps = 1 NUM_STAGES = 6 - hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N) - assert BLOCK_N % quant_group_size == 0 - + hidden_dim_split_block_num = triton.cdiv(k, BLOCK) grid = ( hidden_dim_split_block_num, BLOCK_NUM_PER_EXPERT, expert_num, ) + _silu_and_mul_post_quant_kernel[grid]( + input, + *input.stride(), + output, + *output.stride(), + output_scale, + *output_scale.stride(), + masked_m, + k, + fp8_max, + fp8_min, + BLOCK=BLOCK, + NUM_STAGE=NUM_STAGES, + num_warps=num_warps, + SCALE_UE8M0=scale_ue8m0, + ) + output_scale = output_scale.transpose(1, 2)[:, :m, :] + check_sf_layout( + output_scale, + m, + k, + (1, 128), + g, + tma_stride_check=True, + ) + return output, output_scale + + +@triton.jit +def _per_token_quant_and_transform_kernel( + input_ptr, + stride_input_0, + stride_input_1, + output_ptr, + stride_output_0, + stride_output_1, + output_scale_ptr, + stride_output_scale_0, + stride_output_scale_1, + token_num_cur_expert, + size_k, + fp8_max, + fp8_min, + BLOCK: tl.constexpr, + NUM_STAGE: tl.constexpr, + SCALE_UE8M0: tl.constexpr, +): + tl.program_id(2) + token_id = tl.program_id(1) + hidden_dim_block_index = tl.program_id(0) + + block_num_per_expert = tl.num_programs(1) + + stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64) + stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64) + stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) + stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) + + offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4) + input_ptr_offs = input_ptr + offs_in_d + output_ptr_offs = output_ptr + offs_in_d + output_scale_offs = (output_scale_ptr + + hidden_dim_block_index * stride_output_scale_0) + + for token_index in tl.range(token_id, + token_num_cur_expert, + block_num_per_expert, + num_stages=NUM_STAGE): + output_s_int32 = 0 + for pack_index in tl.range(4): + local_mask = offs_in_d + pack_index * 128 + act = tl.load( + input_ptr_offs + token_index * stride_input_0 + + pack_index * 128, 
+ mask=local_mask < size_k, + other=0.0, + ).to(tl.float32) + _absmax = tl.maximum(tl.max(tl.abs(act)), 1e-10) + output_s = _absmax / fp8_max + if SCALE_UE8M0: + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) + output_q = tl.clamp(act / output_s, fp8_min, + fp8_max).to(output_ptr.dtype.element_ty) + output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) << + (8 * pack_index)) + tl.store( + output_ptr_offs + token_index * stride_output_0 + + pack_index * 128, + output_q, + mask=local_mask < size_k, + ) + tl.store( + output_scale_offs + token_index * stride_output_scale_1, + output_s_int32, + ) + + +def per_token_quant_and_transform( + input: torch.Tensor, + quant_group_size: int = 128, + scale_ue8m0: bool = True, +): + """ + input shape [g, m, k] + output shape [g, m, k // 2], dtype fp8 + output_scale [g, k // 4, m // 2 // 128], dtype int32 + quant_group_size int + masked_m shape [g] + """ + assert input.is_contiguous() + assert len(input.shape) == 2 + assert input.shape[-1] % 2 == 0 + + # FP8 quantization parameters finfo = torch.finfo(torch.float8_e4m3fn) fp8_max = finfo.max fp8_min = -fp8_max - _silu_and_mul_post_quant_kernel[grid]( + m, k = input.shape + + # Create output + output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda") + + # Create output scale + alignment = 4 + scale_k = ceil_div(k, quant_group_size) + m_padded = align(m, alignment) + scale_k_padded = align(scale_k, alignment) + output_scale = torch.zeros((scale_k_padded // 4, m_padded), + dtype=torch.int32, + device='cuda') + + # Get block/grid/stage/warp + BLOCK_NUM_PER_EXPERT = 64 + + BLOCK = quant_group_size * 4 + num_warps = 1 + NUM_STAGES = 6 + hidden_dim_split_block_num = triton.cdiv(k, BLOCK) + grid = ( + hidden_dim_split_block_num, + BLOCK_NUM_PER_EXPERT, + 1, + ) + _per_token_quant_and_transform_kernel[grid]( input, *input.stride(), output, *output.stride(), output_scale, *output_scale.stride(), - masked_m, - size_n, + m, + k, fp8_max, fp8_min, - BLOCK_N=BLOCK_N, + BLOCK=BLOCK, NUM_STAGE=NUM_STAGES, num_warps=num_warps, SCALE_UE8M0=scale_ue8m0, ) - return + output_scale = output_scale.transpose(0, 1)[:m, :] + check_sf_layout( + output_scale, + m, + k, + (1, 128), + num_groups=None, + tma_stride_check=True, + ) + return output, output_scale From 9107cfaa3a3af52b755c05ad1b29474520f7cb6d Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Mon, 28 Jul 2025 14:54:22 +0800 Subject: [PATCH 34/38] Revert "Fuse quantize and transform e8m0 scales (#26)" (#27) This reverts commit fe01f0261145619025c939197985878516bcbfbb. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 25 +- tensorrt_llm/_torch/modules/linear.py | 8 +- tensorrt_llm/quantization/utils/fp8_utils.py | 271 ++++-------------- 3 files changed, 81 insertions(+), 223 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 3721a5d2afd..c7bc7b4f2a1 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -280,6 +280,13 @@ def deepgemm_fp8_group_blockwise_gemm( assert d.stride(-1) == 1 # Transform SFA and SFB into compute-required layout + recipe = (1, 128, 128) + sfa = fp8_utils.transform_sf_into_required_layout(sfa, + mn=m, + k=k, + recipe=recipe, + num_groups=num_groups, + is_sfa=True) deep_gemm.fp8_m_grouped_gemm_nt_masked((a, sfa), (b, sfb), d, @@ -446,8 +453,22 @@ def forward_chunk( masked_m=masked_m, expected_m=expected_m, ) - act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd( - input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True) + act_input_fp8 = torch.empty(h1.shape[0], + h1.shape[1], + h1.shape[2] // 2, + dtype=torch.float8_e4m3fn, + device='cuda') + act_input_sf = torch.empty(h1.shape[0], + h1.shape[1], + h1.shape[2] // 256, + dtype=torch.float32, + device='cuda') + fp8_utils.silu_and_mul_masked_post_quant_fwd(input=h1, + output=act_input_fp8, + output_scale=act_input_sf, + quant_group_size=128, + masked_m=masked_m, + scale_ue8m0=True) h3 = deepgemm_fp8_group_blockwise_gemm( a=act_input_fp8, b=self.w2_weight, diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 1131fe42275..b4dfb1d6456 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -574,7 +574,13 @@ def apply(self, module: Linear, input: torch.Tensor, if get_sm_version() == 100: import deep_gemm - a, a_sf = fp8_utils.per_token_quant_and_transform(input) + a, a_sf = fp8_utils.per_token_cast_to_fp8_e8m0(input) + a_sf = fp8_utils.transform_sf_into_required_layout(a_sf, + mn=a.shape[0], + k=a.shape[1], + recipe=(1, 128, + 128), + is_sfa=True) output = torch.empty((input.shape[0], module.weight.shape[0]), device=input.device, dtype=torch.bfloat16) diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index b653d10365c..5e6c00b996b 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -234,10 +234,10 @@ def _silu_and_mul_post_quant_kernel( stride_output_scale_1, stride_output_scale_2, masked_m_ptr, - size_k, + size_n, fp8_max, fp8_min, - BLOCK: tl.constexpr, + BLOCK_N: tl.constexpr, NUM_STAGE: tl.constexpr, SCALE_UE8M0: tl.constexpr, ): @@ -254,278 +254,109 @@ def _silu_and_mul_post_quant_kernel( stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) - offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4) + offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N) input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 + - hidden_dim_block_index * stride_output_scale_1) + hidden_dim_block_index * stride_output_scale_2) for token_index in 
tl.range(token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE): - output_s_int32 = 0 - for pack_index in tl.range(4): - local_mask = offs_in_d + pack_index * 128 - up = tl.load( - input_ptr_offs + token_index * stride_input_1 + - pack_index * 128, - mask=local_mask < size_k, - other=0.0, - ) - gate = tl.load( - input_ptr_offs + token_index * stride_input_1 + size_k + - pack_index * 128, - mask=local_mask < size_k, - other=0.0, - ).to(tl.float32) - gate = gate / (1 + tl.exp(-gate)) - gate = gate.to(input_ptr.dtype.element_ty) - gate_up = up * gate - _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) - output_s = _absmax / fp8_max - if SCALE_UE8M0: - output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) - output_q = tl.clamp(gate_up / output_s, fp8_min, - fp8_max).to(output_ptr.dtype.element_ty) - output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) << - (8 * pack_index)) - tl.store( - output_ptr_offs + token_index * stride_output_1 + - pack_index * 128, - output_q, - mask=local_mask < size_k, - ) + up = tl.load( + input_ptr_offs + token_index * stride_input_1, + mask=offs_in_d < size_n, + other=0.0, + ) + gate = tl.load( + input_ptr_offs + token_index * stride_input_1 + size_n, + mask=offs_in_d < size_n, + other=0.0, + ).to(tl.float32) + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + gate_up = up * gate + _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) + output_s = _absmax / fp8_max + if SCALE_UE8M0: + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) + output_q = tl.clamp(gate_up / output_s, fp8_min, + fp8_max).to(output_ptr.dtype.element_ty) + tl.store( + output_ptr_offs + token_index * stride_output_1, + output_q, + mask=offs_in_d < size_n, + ) tl.store( - output_scale_offs + token_index * stride_output_scale_2, - output_s_int32, + output_scale_offs + token_index * stride_output_scale_1, + output_s, ) def silu_and_mul_masked_post_quant_fwd( input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, quant_group_size: int, masked_m: torch.Tensor, scale_ue8m0: bool = False, ): """ - input shape [g, m, k] - output shape [g, m, k // 2], dtype fp8 - output_scale [g, k // 4, m // 2 // 128], dtype int32 - quant_group_size int - masked_m shape [g] + input shape [expert_num, token_num_padded, hidden_dim] + output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8 + output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32 + quant_group_size int, + masked_m shape [expert_num], """ assert input.is_contiguous() + assert output.dtype == torch.float8_e4m3fn + assert output.is_contiguous() assert len(input.shape) == 3 assert input.shape[0] == masked_m.shape[0] assert input.shape[-1] % 2 == 0 - # FP8 quantization parameters - finfo = torch.finfo(torch.float8_e4m3fn) - fp8_max = finfo.max - fp8_min = finfo.min + size_n = input.shape[-1] // 2 + assert size_n % quant_group_size == 0 - g, m, k = input.shape - k = k // 2 - - # Create output - output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda") - - # Create output scale - alignment = 4 - scale_k = ceil_div(k, quant_group_size) - m_padded = align(m, alignment) - scale_k_padded = align(scale_k, alignment) - output_scale = torch.zeros((g, scale_k_padded // 4, m_padded), - dtype=torch.int32, - device='cuda') - - # Get block/grid/stage/warp expert_num = len(masked_m) if expert_num < 4: BLOCK_NUM_PER_EXPERT = 64 else: - BLOCK_NUM_PER_EXPERT = 128 + BLOCK_NUM_PER_EXPERT = 32 - BLOCK = quant_group_size * 4 + BLOCK_N 
= quant_group_size num_warps = 1 NUM_STAGES = 6 - hidden_dim_split_block_num = triton.cdiv(k, BLOCK) + hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N) + assert BLOCK_N % quant_group_size == 0 + grid = ( hidden_dim_split_block_num, BLOCK_NUM_PER_EXPERT, expert_num, ) - _silu_and_mul_post_quant_kernel[grid]( - input, - *input.stride(), - output, - *output.stride(), - output_scale, - *output_scale.stride(), - masked_m, - k, - fp8_max, - fp8_min, - BLOCK=BLOCK, - NUM_STAGE=NUM_STAGES, - num_warps=num_warps, - SCALE_UE8M0=scale_ue8m0, - ) - output_scale = output_scale.transpose(1, 2)[:, :m, :] - check_sf_layout( - output_scale, - m, - k, - (1, 128), - g, - tma_stride_check=True, - ) - return output, output_scale - - -@triton.jit -def _per_token_quant_and_transform_kernel( - input_ptr, - stride_input_0, - stride_input_1, - output_ptr, - stride_output_0, - stride_output_1, - output_scale_ptr, - stride_output_scale_0, - stride_output_scale_1, - token_num_cur_expert, - size_k, - fp8_max, - fp8_min, - BLOCK: tl.constexpr, - NUM_STAGE: tl.constexpr, - SCALE_UE8M0: tl.constexpr, -): - tl.program_id(2) - token_id = tl.program_id(1) - hidden_dim_block_index = tl.program_id(0) - - block_num_per_expert = tl.num_programs(1) - - stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64) - stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64) - stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) - stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) - - offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4) - input_ptr_offs = input_ptr + offs_in_d - output_ptr_offs = output_ptr + offs_in_d - output_scale_offs = (output_scale_ptr + - hidden_dim_block_index * stride_output_scale_0) - - for token_index in tl.range(token_id, - token_num_cur_expert, - block_num_per_expert, - num_stages=NUM_STAGE): - output_s_int32 = 0 - for pack_index in tl.range(4): - local_mask = offs_in_d + pack_index * 128 - act = tl.load( - input_ptr_offs + token_index * stride_input_0 + - pack_index * 128, - mask=local_mask < size_k, - other=0.0, - ).to(tl.float32) - _absmax = tl.maximum(tl.max(tl.abs(act)), 1e-10) - output_s = _absmax / fp8_max - if SCALE_UE8M0: - output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) - output_q = tl.clamp(act / output_s, fp8_min, - fp8_max).to(output_ptr.dtype.element_ty) - output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) << - (8 * pack_index)) - tl.store( - output_ptr_offs + token_index * stride_output_0 + - pack_index * 128, - output_q, - mask=local_mask < size_k, - ) - tl.store( - output_scale_offs + token_index * stride_output_scale_1, - output_s_int32, - ) - - -def per_token_quant_and_transform( - input: torch.Tensor, - quant_group_size: int = 128, - scale_ue8m0: bool = True, -): - """ - input shape [g, m, k] - output shape [g, m, k // 2], dtype fp8 - output_scale [g, k // 4, m // 2 // 128], dtype int32 - quant_group_size int - masked_m shape [g] - """ - assert input.is_contiguous() - assert len(input.shape) == 2 - assert input.shape[-1] % 2 == 0 - - # FP8 quantization parameters finfo = torch.finfo(torch.float8_e4m3fn) fp8_max = finfo.max fp8_min = -fp8_max - m, k = input.shape - - # Create output - output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda") - - # Create output scale - alignment = 4 - scale_k = ceil_div(k, quant_group_size) - m_padded = align(m, alignment) - scale_k_padded = align(scale_k, alignment) - output_scale = torch.zeros((scale_k_padded // 4, m_padded), - dtype=torch.int32, - device='cuda') - - # Get 
block/grid/stage/warp - BLOCK_NUM_PER_EXPERT = 64 - - BLOCK = quant_group_size * 4 - num_warps = 1 - NUM_STAGES = 6 - hidden_dim_split_block_num = triton.cdiv(k, BLOCK) - grid = ( - hidden_dim_split_block_num, - BLOCK_NUM_PER_EXPERT, - 1, - ) - _per_token_quant_and_transform_kernel[grid]( + _silu_and_mul_post_quant_kernel[grid]( input, *input.stride(), output, *output.stride(), output_scale, *output_scale.stride(), - m, - k, + masked_m, + size_n, fp8_max, fp8_min, - BLOCK=BLOCK, + BLOCK_N=BLOCK_N, NUM_STAGE=NUM_STAGES, num_warps=num_warps, SCALE_UE8M0=scale_ue8m0, ) - output_scale = output_scale.transpose(0, 1)[:m, :] - check_sf_layout( - output_scale, - m, - k, - (1, 128), - num_groups=None, - tma_stride_check=True, - ) - return output, output_scale + return From 59b3957220865f14fbfac189269d72b70cb53208 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Mon, 28 Jul 2025 18:29:33 +0800 Subject: [PATCH 35/38] Fix CI install error for DeepGEMM. (#28) Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6ea0c9af4d8..f6d7c2e4718 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,4 @@ etcd3 blake3 llguidance==0.7.29 soundfile -deep_gemm @ git+https://github.com/yuxianq/DeepGEMM.git@417c5924b0a2a9410b4a1368f06f63a195081911 +deep_gemm @ git+https://github.com/deepseek-ai/DeepGEMM.git@dd6ed14acbc7445dcef224248a77ab4d22b5f240 From 3c413be89c0c11a98032e96620485860983a1df4 Mon Sep 17 00:00:00 2001 From: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> Date: Mon, 28 Jul 2025 22:27:33 +0800 Subject: [PATCH 36/38] Reapply "Fuse quantize and transform e8m0 scales (#26)" (#27) (#29) * Reapply "Fuse quantize and transform e8m0 scales (#26)" (#27) This reverts commit 9107cfaa3a3af52b755c05ad1b29474520f7cb6d. 
* Remove compile for reducing warnings Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> --------- Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 25 +- tensorrt_llm/_torch/modules/linear.py | 8 +- tensorrt_llm/quantization/utils/fp8_utils.py | 272 ++++++++++++++---- 3 files changed, 223 insertions(+), 82 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index c7bc7b4f2a1..3721a5d2afd 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -280,13 +280,6 @@ def deepgemm_fp8_group_blockwise_gemm( assert d.stride(-1) == 1 # Transform SFA and SFB into compute-required layout - recipe = (1, 128, 128) - sfa = fp8_utils.transform_sf_into_required_layout(sfa, - mn=m, - k=k, - recipe=recipe, - num_groups=num_groups, - is_sfa=True) deep_gemm.fp8_m_grouped_gemm_nt_masked((a, sfa), (b, sfb), d, @@ -453,22 +446,8 @@ def forward_chunk( masked_m=masked_m, expected_m=expected_m, ) - act_input_fp8 = torch.empty(h1.shape[0], - h1.shape[1], - h1.shape[2] // 2, - dtype=torch.float8_e4m3fn, - device='cuda') - act_input_sf = torch.empty(h1.shape[0], - h1.shape[1], - h1.shape[2] // 256, - dtype=torch.float32, - device='cuda') - fp8_utils.silu_and_mul_masked_post_quant_fwd(input=h1, - output=act_input_fp8, - output_scale=act_input_sf, - quant_group_size=128, - masked_m=masked_m, - scale_ue8m0=True) + act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd( + input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True) h3 = deepgemm_fp8_group_blockwise_gemm( a=act_input_fp8, b=self.w2_weight, diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index b4dfb1d6456..1131fe42275 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -574,13 +574,7 @@ def apply(self, module: Linear, input: torch.Tensor, if get_sm_version() == 100: import deep_gemm - a, a_sf = fp8_utils.per_token_cast_to_fp8_e8m0(input) - a_sf = fp8_utils.transform_sf_into_required_layout(a_sf, - mn=a.shape[0], - k=a.shape[1], - recipe=(1, 128, - 128), - is_sfa=True) + a, a_sf = fp8_utils.per_token_quant_and_transform(input) output = torch.empty((input.shape[0], module.weight.shape[0]), device=input.device, dtype=torch.bfloat16) diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 5e6c00b996b..19bd24671dd 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -166,7 +166,6 @@ def check_sf_layout(sf: torch.Tensor, @nvtx_range("[DG] transform_sf_into_required_layout") -@torch.compile(dynamic=True) def transform_sf_into_required_layout(sf: torch.Tensor, mn: int, k: int, @@ -234,10 +233,10 @@ def _silu_and_mul_post_quant_kernel( stride_output_scale_1, stride_output_scale_2, masked_m_ptr, - size_n, + size_k, fp8_max, fp8_min, - BLOCK_N: tl.constexpr, + BLOCK: tl.constexpr, NUM_STAGE: tl.constexpr, SCALE_UE8M0: tl.constexpr, ): @@ -254,109 +253,278 @@ def _silu_and_mul_post_quant_kernel( stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) - offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N) + offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, 
BLOCK // 4) input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 + - hidden_dim_block_index * stride_output_scale_2) + hidden_dim_block_index * stride_output_scale_1) for token_index in tl.range(token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE): - up = tl.load( - input_ptr_offs + token_index * stride_input_1, - mask=offs_in_d < size_n, - other=0.0, - ) - gate = tl.load( - input_ptr_offs + token_index * stride_input_1 + size_n, - mask=offs_in_d < size_n, - other=0.0, - ).to(tl.float32) - gate = gate / (1 + tl.exp(-gate)) - gate = gate.to(input_ptr.dtype.element_ty) - gate_up = up * gate - _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) - output_s = _absmax / fp8_max - if SCALE_UE8M0: - output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) - output_q = tl.clamp(gate_up / output_s, fp8_min, - fp8_max).to(output_ptr.dtype.element_ty) - tl.store( - output_ptr_offs + token_index * stride_output_1, - output_q, - mask=offs_in_d < size_n, - ) + output_s_int32 = 0 + for pack_index in tl.range(4): + local_mask = offs_in_d + pack_index * 128 + up = tl.load( + input_ptr_offs + token_index * stride_input_1 + + pack_index * 128, + mask=local_mask < size_k, + other=0.0, + ) + gate = tl.load( + input_ptr_offs + token_index * stride_input_1 + size_k + + pack_index * 128, + mask=local_mask < size_k, + other=0.0, + ).to(tl.float32) + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + gate_up = up * gate + _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) + output_s = _absmax / fp8_max + if SCALE_UE8M0: + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) + output_q = tl.clamp(gate_up / output_s, fp8_min, + fp8_max).to(output_ptr.dtype.element_ty) + output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) << + (8 * pack_index)) + tl.store( + output_ptr_offs + token_index * stride_output_1 + + pack_index * 128, + output_q, + mask=local_mask < size_k, + ) tl.store( - output_scale_offs + token_index * stride_output_scale_1, - output_s, + output_scale_offs + token_index * stride_output_scale_2, + output_s_int32, ) def silu_and_mul_masked_post_quant_fwd( input: torch.Tensor, - output: torch.Tensor, - output_scale: torch.Tensor, quant_group_size: int, masked_m: torch.Tensor, scale_ue8m0: bool = False, ): """ - input shape [expert_num, token_num_padded, hidden_dim] - output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8 - output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32 - quant_group_size int, - masked_m shape [expert_num], + input shape [g, m, k] + output shape [g, m, k // 2], dtype fp8 + output_scale [g, k // 4, m // 2 // 128], dtype int32 + quant_group_size int + masked_m shape [g] """ assert input.is_contiguous() - assert output.dtype == torch.float8_e4m3fn - assert output.is_contiguous() assert len(input.shape) == 3 assert input.shape[0] == masked_m.shape[0] assert input.shape[-1] % 2 == 0 - size_n = input.shape[-1] // 2 - assert size_n % quant_group_size == 0 + # FP8 quantization parameters + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + fp8_min = finfo.min + + g, m, k = input.shape + k = k // 2 + + # Create output + output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda") + # Create output scale + alignment = 4 + scale_k = ceil_div(k, quant_group_size) + m_padded = align(m, alignment) + 
scale_k_padded = align(scale_k, alignment) + output_scale = torch.zeros((g, scale_k_padded // 4, m_padded), + dtype=torch.int32, + device='cuda') + + # Get block/grid/stage/warp expert_num = len(masked_m) if expert_num < 4: BLOCK_NUM_PER_EXPERT = 64 else: - BLOCK_NUM_PER_EXPERT = 32 + BLOCK_NUM_PER_EXPERT = 128 - BLOCK_N = quant_group_size + BLOCK = quant_group_size * 4 num_warps = 1 NUM_STAGES = 6 - hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N) - assert BLOCK_N % quant_group_size == 0 - + hidden_dim_split_block_num = triton.cdiv(k, BLOCK) grid = ( hidden_dim_split_block_num, BLOCK_NUM_PER_EXPERT, expert_num, ) + _silu_and_mul_post_quant_kernel[grid]( + input, + *input.stride(), + output, + *output.stride(), + output_scale, + *output_scale.stride(), + masked_m, + k, + fp8_max, + fp8_min, + BLOCK=BLOCK, + NUM_STAGE=NUM_STAGES, + num_warps=num_warps, + SCALE_UE8M0=scale_ue8m0, + ) + output_scale = output_scale.transpose(1, 2)[:, :m, :] + check_sf_layout( + output_scale, + m, + k, + (1, 128), + g, + tma_stride_check=True, + ) + return output, output_scale + +@triton.jit +def _per_token_quant_and_transform_kernel( + input_ptr, + stride_input_0, + stride_input_1, + output_ptr, + stride_output_0, + stride_output_1, + output_scale_ptr, + stride_output_scale_0, + stride_output_scale_1, + token_num_cur_expert, + size_k, + fp8_max, + fp8_min, + BLOCK: tl.constexpr, + NUM_STAGE: tl.constexpr, + SCALE_UE8M0: tl.constexpr, +): + tl.program_id(2) + token_id = tl.program_id(1) + hidden_dim_block_index = tl.program_id(0) + + block_num_per_expert = tl.num_programs(1) + + stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64) + stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64) + stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) + stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) + + offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4) + input_ptr_offs = input_ptr + offs_in_d + output_ptr_offs = output_ptr + offs_in_d + output_scale_offs = (output_scale_ptr + + hidden_dim_block_index * stride_output_scale_0) + + for token_index in tl.range(token_id, + token_num_cur_expert, + block_num_per_expert, + num_stages=NUM_STAGE): + output_s_int32 = 0 + for pack_index in tl.range(4): + local_mask = offs_in_d + pack_index * 128 + act = tl.load( + input_ptr_offs + token_index * stride_input_0 + + pack_index * 128, + mask=local_mask < size_k, + other=0.0, + ).to(tl.float32) + _absmax = tl.maximum(tl.max(tl.abs(act)), 1e-10) + output_s = _absmax / fp8_max + if SCALE_UE8M0: + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) + output_q = tl.clamp(act / output_s, fp8_min, + fp8_max).to(output_ptr.dtype.element_ty) + output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) << + (8 * pack_index)) + tl.store( + output_ptr_offs + token_index * stride_output_0 + + pack_index * 128, + output_q, + mask=local_mask < size_k, + ) + tl.store( + output_scale_offs + token_index * stride_output_scale_1, + output_s_int32, + ) + + +def per_token_quant_and_transform( + input: torch.Tensor, + quant_group_size: int = 128, + scale_ue8m0: bool = True, +): + """ + input shape [g, m, k] + output shape [g, m, k // 2], dtype fp8 + output_scale [g, k // 4, m // 2 // 128], dtype int32 + quant_group_size int + masked_m shape [g] + """ + + assert input.is_contiguous() + assert len(input.shape) == 2 + assert input.shape[-1] % 2 == 0 + + # FP8 quantization parameters finfo = torch.finfo(torch.float8_e4m3fn) fp8_max = finfo.max fp8_min = -fp8_max - _silu_and_mul_post_quant_kernel[grid]( + 
m, k = input.shape + + # Create output + output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda") + + # Create output scale + alignment = 4 + scale_k = ceil_div(k, quant_group_size) + m_padded = align(m, alignment) + scale_k_padded = align(scale_k, alignment) + output_scale = torch.zeros((scale_k_padded // 4, m_padded), + dtype=torch.int32, + device='cuda') + + # Get block/grid/stage/warp + BLOCK_NUM_PER_EXPERT = 64 + + BLOCK = quant_group_size * 4 + num_warps = 1 + NUM_STAGES = 6 + hidden_dim_split_block_num = triton.cdiv(k, BLOCK) + grid = ( + hidden_dim_split_block_num, + BLOCK_NUM_PER_EXPERT, + 1, + ) + _per_token_quant_and_transform_kernel[grid]( input, *input.stride(), output, *output.stride(), output_scale, *output_scale.stride(), - masked_m, - size_n, + m, + k, fp8_max, fp8_min, - BLOCK_N=BLOCK_N, + BLOCK=BLOCK, NUM_STAGE=NUM_STAGES, num_warps=num_warps, SCALE_UE8M0=scale_ue8m0, ) - return + output_scale = output_scale.transpose(0, 1)[:m, :] + check_sf_layout( + output_scale, + m, + k, + (1, 128), + num_groups=None, + tma_stride_check=True, + ) + return output, output_scale From 2e9fcbe196aab07f8b2e6d920d091f2cf9892136 Mon Sep 17 00:00:00 2001 From: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> Date: Wed, 30 Jul 2025 06:52:16 -0700 Subject: [PATCH 37/38] Fix UT for DeepGEMM Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> --- tests/unittest/_torch/modules/test_fused_moe.py | 13 +++++++++++-- tests/unittest/test_pip_install.py | 3 +++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 32adc513249..51a7758d281 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -421,6 +421,8 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype, router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda() weights = {} + w3_w1_weight_scales = [] + w2_weight_scales = [] for expert_id in range(NUM_EXPERTS): w1_weight = torch.randn( (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() / HIDDEN_SIZE @@ -451,6 +453,13 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype, weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale + w3_w1_weight_scales.append( + torch.cat([w3_weight_scale, w1_weight_scale], dim=0)) + w2_weight_scales.append(w2_weight_scale) + + w3_w1_weight_scales = torch.stack(w3_w1_weight_scales, dim=0).cuda() + w2_weight_scales = torch.stack(w2_weight_scales, dim=0).cuda() + quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES) fused_moe = DeepGemmFusedMoE( @@ -525,7 +534,7 @@ def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, a=act_input_fp8, b=fused_moe.w3_w1_weight, a_sf=act_input_sf, - b_sf=fused_moe.quant_scales[0], + b_sf=w3_w1_weight_scales, offset_array=expert_first_token_offset_tensor, ) h2 = swiglu_fused_moe(h1) @@ -534,7 +543,7 @@ def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, a=act_input_fp8, b=fused_moe.w2_weight, a_sf=act_input_sf, - b_sf=fused_moe.quant_scales[1], + b_sf=w2_weight_scales, offset_array=expert_first_token_offset_tensor, ) ref_output = torch.zeros_like(x) diff --git a/tests/unittest/test_pip_install.py b/tests/unittest/test_pip_install.py index 11288e09cec..d75bfbaf420 100644 --- a/tests/unittest/test_pip_install.py +++ b/tests/unittest/test_pip_install.py @@ -51,6 +51,9 @@ def 
test_pip_install(): help="The wheel path") args = parser.parse_args() + if not os.environ.get("CUDA_HOME"): + os.environ["CUDA_HOME"] = "/usr/local/cuda" + print("########## Install required system libs ##########") if not os.path.exists("/usr/local/mpi/bin/mpicc"): subprocess.check_call("apt-get -y install libopenmpi-dev", shell=True) From 10bfbb57a9af872c052e503bbed3494a39b7dae5 Mon Sep 17 00:00:00 2001 From: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> Date: Wed, 30 Jul 2025 22:40:37 -0700 Subject: [PATCH 38/38] Fix sanity check for deepgemm Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f6d7c2e4718..b87773c2d46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,4 @@ etcd3 blake3 llguidance==0.7.29 soundfile -deep_gemm @ git+https://github.com/deepseek-ai/DeepGEMM.git@dd6ed14acbc7445dcef224248a77ab4d22b5f240 +deep_gemm @ git+https://github.com/zongfeijing/DeepGEMM.git@a9d538ef4dff0326fe521c6ca0bfde115703b56a
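The kernels introduced, reverted, and re-applied in the patches above (silu_and_mul_masked_post_quant_fwd and per_token_quant_and_transform in fp8_utils.py) share one scale format for DeepGEMM on SM100: every 1x128 activation group gets a power-of-two (UE8M0) FP8 scale, and the 8-bit exponents of four consecutive group scales are packed into a single int32. Below is a minimal, unoptimized PyTorch sketch of that quantize-and-pack step for a 2-D activation tensor. It is illustrative only: the helper name is ours, it returns a plain row-major [m, scale_k/4] scale tensor, and it omits the padded, token-major stride layout that check_sf_layout enforces on the real outputs.

import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def per_token_quant_ue8m0_ref(x: torch.Tensor, group_size: int = 128):
    """Quantize [m, k] activations to FP8 with one UE8M0 scale per 1x128
    group and pack four group-scale exponents into each int32."""
    m, k = x.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale_k = ceil_div(k, group_size)
    scale_k_padded = ceil_div(scale_k, 4) * 4

    # Split the hidden dimension into 1x128 groups (zero-pad the tail group).
    x_pad = torch.nn.functional.pad(x.float(), (0, scale_k * group_size - k))
    groups = x_pad.reshape(m, scale_k, group_size)

    # Per-group absmax, with the scale rounded up to the next power of two
    # so it is exactly representable by an 8-bit exponent (UE8M0).
    absmax = groups.abs().amax(dim=-1).clamp_min(1e-10)
    scale = torch.exp2(torch.ceil(torch.log2(absmax / fp8_max)))

    q = (groups / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    q = q.reshape(m, -1)[:, :k].to(torch.float8_e4m3fn)

    # Bit-cast each power-of-two float32 scale, keep its exponent byte
    # (bits 23..30), and pack 4 consecutive group exponents into one int32,
    # mirroring the `(bitcast >> 23) << (8 * pack_index)` step in the kernel.
    exp = scale.view(torch.int32) >> 23                      # [m, scale_k]
    exp = torch.nn.functional.pad(exp, (0, scale_k_padded - scale_k))
    shifted = exp.reshape(m, -1, 4) << torch.tensor(
        [0, 8, 16, 24], dtype=torch.int32, device=x.device)
    packed = shifted[..., 0] | shifted[..., 1] | shifted[..., 2] | shifted[..., 3]
    return q, packed                                         # [m, scale_k_padded // 4]

The fused SiLU-and-mul variant in the same file first applies up * silu(gate) to each expert slice and then performs the same quantize-and-pack step per masked expert, which is why its output_scale is likewise an int32 tensor holding four packed exponents per 512 activation columns.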