diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index de4c9b23a87..9ee456ce908 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -50,7 +50,10 @@ def add_llm_args(parser): parser.add_argument('--moe_backend', type=str, default='CUTLASS', - choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP']) + choices=[ + 'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP', + 'DEEPGEMM', 'CUTEDSL' + ]) parser.add_argument('--enable_attention_dp', default=False, action='store_true') diff --git a/requirements.txt b/requirements.txt index 1e0584d9b37..59d522503c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,4 +61,5 @@ etcd3 blake3 llguidance==0.7.29 soundfile +deep_gemm @ git+https://github.com/zongfeijing/DeepGEMM.git@a9d538ef4dff0326fe521c6ca0bfde115703b56a triton==3.3.1 diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py index 1277e25b42f..ba4703875e6 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py @@ -12,7 +12,8 @@ BaseWeightLoader from tensorrt_llm._torch.models.modeling_utils import ( register_checkpoint_weight_loader, run_concurrently) -from tensorrt_llm._utils import local_mpi_rank, local_mpi_size +from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank, + local_mpi_size) from tensorrt_llm.logger import logger @@ -38,6 +39,8 @@ def load_weights(self, checkpoint_dir: str) -> dict[str, Any]: f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files." ) self.prefetch_files(weight_files) + # Ensure that all local ranks have finished prefetching before loading weights + local_mpi_barrier() return self._load_weights_in_parallel( weight_files, self._load_safetensors_file, diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 7340b2c73c2..7ba59d2a635 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -38,12 +38,15 @@ from tqdm import tqdm from transformers import PretrainedConfig +from tensorrt_llm import logger from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import PositionEmbeddingType from tensorrt_llm.llmapi.utils import enable_llm_debug from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization.utils.fp8_utils import ( + resmooth_to_fp8_e8m0, transform_sf_into_required_layout) from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams @@ -1244,7 +1247,7 @@ def load_kv_b_proj_and_k_b_proj_trans(module_name: str, weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size local_num_heads = num_heads // weight_divisor - k_nope_weight_trans = k_nope_weight.transpose(2, 1) + k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous() kv_b_proj = torch.concat([ k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim, @@ -1256,11 +1259,6 @@ def load_kv_b_proj_and_k_b_proj_trans(module_name: str, return kv_b_proj, k_nope_weight_trans - def check_weight_dtype(module_name: str, dtype): - weight_name = "weight" - w_dtype = weights[f"{module_name}.{weight_name}"].dtype - return w_dtype == dtype - def 
load_kv_b_proj_and_k_b_proj_trans_dequant( module_name: str) -> torch.Tensor: weight_name = "weight" @@ -1290,7 +1288,7 @@ def load_kv_b_proj_and_k_b_proj_trans_dequant( weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size local_num_heads = num_heads // weight_divisor - k_nope_weight_trans = k_nope_weight.transpose(2, 1) + k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous() kv_b_proj = torch.concat([ k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim, @@ -1333,6 +1331,21 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, params_map = {'gate_up_proj': ['gate_proj', 'up_proj']} all_named_modules = dict(self.named_modules()) + if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( + ) and get_sm_version() == 100: + for name in list(weights.keys()): + # Use ".experts." to exclude shared_experts. + if name.endswith( + "weight_scale_inv") and ".experts." not in name: + weight_name = name.replace("weight_scale_inv", "weight") + logger.debug(f"Resmoothing {weight_name}") + weight = weights[weight_name][:] + scale = weights[name][:] + weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( + weight, scale) + weights[weight_name] = weights[weight_name].cpu() + weights[name] = weights[name].cpu() + for name, module in tqdm(all_named_modules.items(), desc="Loading weights"): if len(module._parameters) > 0: @@ -1384,6 +1397,26 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, attn_module.v_b_proj_scale = nn.Parameter( v_b_proj_scale, requires_grad=False) + if attn_module.k_b_proj_trans_dequant is not None: + attn_module.k_b_proj_trans_dequant.data.copy_( + weight_dequant( + k_b_proj_trans.view( + -1, k_b_proj_trans.shape[-1]).cuda(), + k_b_proj_trans_scale.view( + -1, + k_b_proj_trans_scale.shape[-1]).cuda(), + ).view( + *attn_module.k_b_proj_trans_dequant.shape). 
+ to(attn_module.k_b_proj_trans_dequant.dtype)) + if attn_module.v_b_proj_dequant is not None: + attn_module.v_b_proj_dequant.data.copy_( + weight_dequant( + v_b_proj.view(-1, + v_b_proj.shape[-1]).cuda(), + v_b_proj_scale.view( + -1, v_b_proj_scale.shape[-1]).cuda(), + ).view(*attn_module.v_b_proj_dequant.shape).to( + attn_module.v_b_proj_dequant.dtype)) elif names[-1] == "kv_a_proj_with_mqa": fused_a = weights[ f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:] @@ -1431,6 +1464,18 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, for n, p in module.named_parameters(): p.data.copy_(module_weights[n][:]) + if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales( + ) and get_sm_version() == 100 and hasattr( + module, "weight_scale"): + transfromed_scale = transform_sf_into_required_layout( + module.weight_scale, + mn=module.weight.shape[0], + k=module.weight.shape[1], + recipe=(1, 128, 128), + is_sfa=False) + module.weight_scale = nn.Parameter(transfromed_scale, + requires_grad=False) + for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index f9e04a2b5ad..c24513e25fb 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -357,6 +357,7 @@ def fp8_block_scaling_bmm_out( mat2_fp8: torch.Tensor, mat2_scale: torch.Tensor, out: torch.Tensor, + mat2_dequant: Optional[torch.Tensor] = None, ) -> torch.Tensor: sm_version = get_sm_version() if sm_version == 90 or sm_version == 89: @@ -365,30 +366,33 @@ def fp8_block_scaling_bmm_out( torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8, mat1_scale, mat2_scale, out) elif sm_version == 100: - low_latency = True - use_deep_seek_fp8 = True - tile_size = 8 - epilogue_tile_m = 64 if use_deep_seek_fp8 else 128 - m_size = mat1.shape[0] - if m_size % tile_size != 0: - tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size - mat1 = torch.nn.functional.pad( - mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0) - - mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102( - mat1) - output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen( - mat1_fp8, - mat2_fp8, - tile_size=tile_size, - epilogue_tile_m=epilogue_tile_m, - use_deep_seek_fp8=use_deep_seek_fp8, - low_latency=low_latency, - dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1), - dq_sfs_b=mat2_scale, - out_dtype=out.dtype, - ) - out.copy_(output[:, :m_size]) + output = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2)) + out.copy_(output) + + # low_latency = True + # use_deep_seek_fp8 = True + # tile_size = 8 + # epilogue_tile_m = 64 if use_deep_seek_fp8 else 128 + # m_size = mat1.shape[0] + # if m_size % tile_size != 0: + # tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size + # mat1 = torch.nn.functional.pad( + # mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0) + + # mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102( + # mat1) + # output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen( + # mat1_fp8, + # mat2_fp8, + # tile_size=tile_size, + # epilogue_tile_m=epilogue_tile_m, + # use_deep_seek_fp8=use_deep_seek_fp8, + # low_latency=low_latency, + # dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1), + # dq_sfs_b=mat2_scale, + # out_dtype=out.dtype, + # ) + # out.copy_(output[:, :m_size]) else: raise NotImplementedError(f"SM{sm_version} 
is not supported") @@ -676,6 +680,8 @@ def create_weights(self): requires_grad=False, ) + self.k_b_proj_trans_dequant = None + self.v_b_proj_dequant = None if has_fp8_block_scales: self.k_b_proj_trans_scale = nn.Parameter( torch.empty( @@ -701,6 +707,23 @@ def create_weights(self): ), requires_grad=False, ) + if get_sm_version() == 100: + assert self.dtype == torch.bfloat16 + self.k_b_proj_trans_dequant = nn.Parameter( + torch.empty( + (self.num_heads, self.kv_lora_rank, + self.qk_nope_head_dim), + dtype=self.dtype, + ), + requires_grad=False, + ) + self.v_b_proj_dequant = nn.Parameter( + torch.empty( + (self.num_heads, self.v_head_dim, self.kv_lora_rank), + dtype=self.dtype, + ), + requires_grad=False, + ) else: self.k_b_proj_trans_scale = None self.v_b_proj_scale = None @@ -1197,8 +1220,13 @@ def forward_generation( # [num_heads, num_tokens, self.kv_lora_rank] q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1) - fp8_block_scaling_bmm_out(q_nope, self.k_b_proj_trans, - self.k_b_proj_trans_scale, q_nope_out) + fp8_block_scaling_bmm_out( + q_nope, + self.k_b_proj_trans, + self.k_b_proj_trans_scale, + q_nope_out, + self.k_b_proj_trans_dequant, + ) else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.") @@ -1247,9 +1275,13 @@ def forward_generation( self.v_b_proj.transpose(1, 2), attn_output.transpose(0, 1)) elif self.v_b_proj.dtype == torch.float8_e4m3fn: - fp8_block_scaling_bmm_out(attn_out_latent, self.v_b_proj, - self.v_b_proj_scale, - attn_output.transpose(0, 1)) + fp8_block_scaling_bmm_out( + attn_out_latent, + self.v_b_proj, + self.v_b_proj_scale, + attn_output.transpose(0, 1), + self.v_b_proj_dequant, + ) else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.") diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 17f7e436b17..0b47e18f60d 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -8,6 +8,7 @@ from ...model_config import ModelConfig from .fused_moe_cute_dsl import CuteDslFusedMoE from .fused_moe_cutlass import CutlassFusedMoE +from .fused_moe_deepgemm import DeepGemmFusedMoE from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE from .fused_moe_vanilla import VanillaMoE from .fused_moe_wide_ep import WideEPMoE @@ -31,6 +32,8 @@ def get_moe_cls( return VanillaMoE elif moe_backend.upper() == "CUTEDSL": return CuteDslFusedMoE + elif moe_backend.upper() == "DEEPGEMM": + return DeepGemmFusedMoE elif moe_backend.upper() == "TRTLLM": if quant_config is not None and ( quant_config.quant_mode.has_fp8_block_scales() @@ -139,5 +142,19 @@ def create_moe( apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, ) + elif moe_cls == DeepGemmFusedMoE: + return moe_cls( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + aux_stream=aux_stream, + weight_loading_mode=weight_loading_mode, + apply_router_weight_on_input=apply_router_weight_on_input, + layer_idx=layer_idx, + ) else: raise ValueError(f"Unsupported moe backend: {moe_cls}") diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py new file mode 100644 index 00000000000..3721a5d2afd --- /dev/null +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -0,0 
+1,483 @@
+from typing import List, Optional, Union
+
+import deep_gemm
+import torch
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+
+import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
+from tensorrt_llm._utils import nvtx_range
+
+from ...distributed import allgather
+from ...model_config import ModelConfig
+from ...utils import Fp4QuantizedTensor
+from .fused_moe_cutlass import CutlassFusedMoE
+from .quantization import MoEWeightLoadingMode
+from .routing import BaseMoeRoutingMethod
+
+
+@triton.jit
+def _masked_index_copy_group_quant_fp8(
+    input_ptr,
+    out_q_ptr,
+    out_s_ptr,
+    # mask indices
+    start_offsets_ptr,
+    row_indices_ptr,
+    # dimensions
+    row_size,
+    col_size,
+    dim_size,
+    group_size,
+    # output scale factor size
+    aligned_col,
+    aligned_dim,
+    # quantization parameters
+    eps,
+    fp8_max,
+    # block size
+    BLOCK: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    group_block = tl.program_id(0)
+    token_block = tl.program_id(1)
+    token_block_num = tl.num_programs(1)
+
+    # calculate group and element offsets
+    num_tokens = tl.load(start_offsets_ptr + row_size)
+    elem_offsets = group_block * group_size * 4 + tl.arange(0, BLOCK)
+    output_s_offs = out_s_ptr + group_block * aligned_col
+
+    # process tokens
+    for token_index in tl.range(token_block,
+                                num_tokens,
+                                token_block_num,
+                                num_stages=NUM_STAGE):
+        # load indices
+        row_idx = tl.load(row_indices_ptr + token_index)
+        start_offset = tl.load(start_offsets_ptr + row_idx)
+        idx = row_idx * col_size + token_index - start_offset
+        idx_s = row_idx * aligned_dim * aligned_col + token_index - start_offset
+
+        output_s_int32 = 0
+        for group_index in tl.range(4):
+            # load input data
+            dim_offset = elem_offsets + group_index * group_size
+            valid = dim_offset < dim_size
+            input_data = tl.load(input_ptr + token_index * dim_size +
+                                 dim_offset,
+                                 mask=valid,
+                                 other=0.0)
+            # quantization
+            _absmax = tl.maximum(tl.max(tl.abs(input_data)), eps)
+            output_s = _absmax / fp8_max
+            output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
+            output_q = tl.clamp(input_data / output_s, -fp8_max,
+                                fp8_max).to(out_q_ptr.dtype.element_ty)
+            output_s = output_s.to(tl.int32, bitcast=True) >> 23
+            output_s_int32 += output_s << (group_index * 8)
+
+            # store quantized values and scaling factor
+            tl.store(out_q_ptr + idx * dim_size + dim_offset,
+                     output_q,
+                     mask=valid)
+        tl.store(output_s_offs + idx_s, output_s_int32)
+
+
+def masked_index_copy_group_quant_fp8(
+    output: torch.Tensor,
+    input: torch.Tensor,
+    start_offsets: torch.Tensor,
+    row_indices: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+):
+    assert (
+        input.shape[-1] % group_size == 0
+    ), "the last dimension of `input` must be divisible by `group_size`"
+    assert input.is_contiguous(), "`input` is not contiguous"
+    assert input.ndim == 2, "Input must be a 2D tensor"
+    assert output.ndim == 3, "Output must be a 3D tensor, [row, col, dim]"
+    assert start_offsets.shape[
+        0] == output.shape[0] + 1, "Start offsets must be (num_experts + 1)"
+
+    num_tokens = input.shape[0]
+    row_size = output.shape[0]
+    col_size = output.shape[1]
+    dim_size = output.shape[2]
+
+    # create padded output_s
+    alignment = 4
+    scale_dim = (dim_size + group_size - 1) // group_size
+    padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
+    padded_col_size = (col_size + alignment - 1) // alignment * alignment
+    output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
+                           dtype=torch.int32,
+                           device='cuda')
+
+    # get block/grid/stage/warp
num_groups = (dim_size + group_size - 1) // group_size + BLOCK = group_size + if num_tokens <= 1000 or col_size <= 256: # Small workload + TOKEN_BLOCK_NUM = 256 + NUM_STAGES = 4 + num_warps = 2 + elif num_tokens <= 10000 or col_size <= 2048: # Medium workload + TOKEN_BLOCK_NUM = 1024 + NUM_STAGES = 2 + num_warps = 1 + else: # Large workload + TOKEN_BLOCK_NUM = 2048 + NUM_STAGES = 2 + num_warps = 1 + grid = ( + (num_groups + 3) // 4, + TOKEN_BLOCK_NUM, + ) + + # FP8 quantization parameters + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + + _masked_index_copy_group_quant_fp8[grid]( + input, + output, + output_s, + start_offsets, + row_indices, + row_size, + col_size, + dim_size, + group_size, + padded_col_size, + padded_dim_size // 4, + eps, + fp8_max, + BLOCK=BLOCK, + NUM_STAGE=NUM_STAGES, + num_warps=num_warps, + ) + output_s = output_s.transpose(1, 2)[:, :col_size, :] + return output_s + + +@triton.jit +def masked_index_gather_kernel(output_ptr, input_ptr, start_offsets_ptr, + row_indices_ptr, row_size, col_size, dim_size, + BLOCK_SIZE: tl.constexpr): + # get program id and block offset + pid = tl.program_id(0) + num_tokens = tl.load(start_offsets_ptr + row_size) + + token_idx = pid + valid_token = token_idx < num_tokens + if not valid_token: + return + + row_idx = tl.load(row_indices_ptr + token_idx) + start_offset = tl.load(start_offsets_ptr + row_idx) + col_idx = token_idx - start_offset + + # Process elements in blocks + for hidden_start in tl.range(0, dim_size, BLOCK_SIZE): + hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE) + valid_hidden = hidden_indices < dim_size + + input_offset = row_idx * col_size * dim_size + col_idx * dim_size + hidden_indices + input_val = tl.load(input_ptr + input_offset, + mask=valid_hidden, + other=0.0) + + output_offset = pid * dim_size + hidden_indices + tl.store(output_ptr + output_offset, input_val, mask=valid_hidden) + + +@torch.no_grad() +def triton_masked_index_gather(output, input, start_offsets, row_indices): + assert output.ndim == 2, "Output must be a 2D tensor" + assert input.ndim == 3, "Input must be a 3D tensor, [row, col, dim]" + assert start_offsets.shape[ + 0] == input.shape[0] + 1, "Start offsets must be (num_experts + 1)" + + row_size = input.shape[0] + col_size = input.shape[1] + dim_size = input.shape[2] + num_tokens = output.shape[0] + + grid = (num_tokens, ) + # launch kernel + masked_index_gather_kernel[grid](output, + input, + start_offsets, + row_indices, + row_size, + col_size, + dim_size, + BLOCK_SIZE=1024) + return + + +@nvtx_range("[DG] act") +@torch.compile(dynamic=True) +def swiglu_fused_moe(x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +@nvtx_range("[DG] indexing") +@torch.compile(dynamic=True) +def indexing(x, mask): + return x[mask > 0, :].contiguous() + + +@nvtx_range("[DG] preprocess_after_permute") +def preprocess_after_permute(expert_first_token_offset_tensor, + permuted_data_tensor): + # get tokens per expert + masked_m = expert_first_token_offset_tensor[ + 1:] - expert_first_token_offset_tensor[:-1] + token_to_expert_map = torch.searchsorted( + expert_first_token_offset_tensor[1:], + torch.arange(permuted_data_tensor.shape[0], device='cuda'), + right=True) + return masked_m.to(torch.int32), token_to_expert_map + + +@nvtx_range("[DG]") +def deepgemm_fp8_group_blockwise_gemm( + a: torch.Tensor, + b: torch.Tensor, + sfa: torch.Tensor, + sfb: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, +) -> torch.Tensor: + d = torch.empty((a.shape[0], a.shape[1], b.shape[1]), 
+ device=b.device, + dtype=torch.bfloat16) + + # NOTES: shape must be `[G, M, K] @ [G, N, K].mT` + assert a.stride(-1) == 1 + assert b.stride(-1) == 1 + assert masked_m.is_contiguous() + + num_groups, m, k = a.shape + num_groups_, n, k_ = b.shape + num_groups__, m_, n_ = d.shape + num_groups___ = masked_m.numel() + + # Type and shape checks + assert num_groups == num_groups_ == num_groups__ == num_groups___ + assert m == m_ and n == n_ and k == k_ + assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0 + assert a.dtype == torch.float8_e4m3fn + assert b.dtype == torch.float8_e4m3fn + assert d.dtype == torch.bfloat16 + assert masked_m.dtype == torch.int32 + + # D must be N-major + assert d.stride(-1) == 1 + + # Transform SFA and SFB into compute-required layout + + deep_gemm.fp8_m_grouped_gemm_nt_masked((a, sfa), (b, sfb), + d, + masked_m, + expected_m, + disable_ue8m0_cast=True) + return d + + +class DeepGemmFusedMoE(CutlassFusedMoE): + """ + Python Flow of Fused Mixture of Experts (MoE) Layer. + + Args: + num_experts (int): Number of experts in the MoE layer. + top_k (int): Number of top experts to select for each input token. + hidden_size (int): Size of the hidden state. + intermediate_size (int): Size of the intermediate state. + aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks. + dtype (Optional[torch.dtype]): Data type for the weights. + reduce_results (bool): Whether to reduce the results across devices. + model_config (ModelConfig): Configuration object for the model. + + This backend is composed of multiple custom ops: + 1. moe_permute_op: permute the input tensor and the expert selected tensor. + 2. cute_dsl_fp8_group_blockwise_gemm_ref: a reference implementation of the cute_dsl_fp8_group_blockwise_gemm. + 3. moe_finalize_scale_op: finalize the scale of the output tensor. + """ + + def __init__( + self, + *, + routing_method: BaseMoeRoutingMethod, + num_experts: int, + hidden_size: int, + intermediate_size: int, + dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + model_config: ModelConfig = ModelConfig(), + aux_stream: Optional[torch.cuda.Stream] = None, + weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode. 
+ VANILLA, + apply_router_weight_on_input: bool = False, + layer_idx: Optional[int] = None, + ): + + super().__init__( + routing_method=routing_method, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + dtype=dtype, + reduce_results=reduce_results, + model_config=model_config, + aux_stream=aux_stream, + weight_loading_mode=weight_loading_mode, + apply_router_weight_on_input=apply_router_weight_on_input, + layer_idx=layer_idx, + ) + + @nvtx_range("[DG] forward") + def forward_chunk( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + ) -> torch.Tensor: + if isinstance(x, Fp4QuantizedTensor): + assert output_dtype is not None + output_dtype = output_dtype + else: + output_dtype = x.dtype + + # apply routing + token_selected_experts, token_final_scales = self.routing_method.apply( + router_logits) + assert token_selected_experts.shape[ + 1] == self.routing_method.experts_per_token + assert token_selected_experts.shape == token_final_scales.shape + assert token_selected_experts.shape[0] == router_logits.shape[0] + assert token_final_scales.dtype == torch.float32 + assert token_selected_experts.dtype == torch.int32 + + if self.apply_router_weight_on_input: + assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing" + assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input" + x = x * token_final_scales.to(x.dtype) + # TODO: remove this once we have correct fusedmoe kernel ready + token_final_scales = None + + # quantize inputs + use_deepseek_fp8_block_scale = False + x_sf = None + if self.has_any_quant: + if self.has_deepseek_fp8_block_scales: + use_deepseek_fp8_block_scale = True + else: + raise ValueError( + f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}" + ) + + use_allgather = self.use_dp and self.parallel_size > 1 + if use_allgather: + x, x_sf, token_selected_experts, token_final_scales = allgather( + [x, x_sf, token_selected_experts, token_final_scales], + self.mapping, + dim=0, + sizes=None if use_dp_padding else all_rank_num_tokens) + + ( + permuted_row_to_unpermuted_row_tensor, + permuted_token_selected_experts_tensor, + permuted_data_tensor, + expert_first_token_offset_tensor, + permuted_token_final_scales_tensor, + unpermuted_row_to_permuted_row_tensor, + ) = torch.ops.trtllm.moe_permute_op( + x, + token_selected_experts, + token_final_scales, + None, # w3_w1_weight.view(weight_dtype), + None, # w2_weight.view(weight_dtype), + None, # quant_scales, + input_sf=x_sf, + num_experts_on_rank=self.expert_size_per_partition, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + cluster_size=self.cluster_size, + cluster_rank=self.cluster_rank, + min_latency_mode=False, + use_fp8_block_scaling=use_deepseek_fp8_block_scale, + ) + + if permuted_data_tensor.numel() == 0: + return torch.zeros_like(x) + + masked_m, token_to_expert_map = preprocess_after_permute( + expert_first_token_offset_tensor, permuted_data_tensor) + + m_max = (x.shape[0] + 127) // 128 * 128 + expected_m = (token_selected_experts.numel() + + self.expert_size_per_partition - + 1) // self.expert_size_per_partition + act_input_fp8 = torch.empty( + (self.expert_size_per_partition, m_max, self.hidden_size), + dtype=torch.float8_e4m3fn, + 
device='cuda') + act_input_sf = masked_index_copy_group_quant_fp8( + act_input_fp8, + permuted_data_tensor, + expert_first_token_offset_tensor, + token_to_expert_map, + group_size=128) + + h1 = deepgemm_fp8_group_blockwise_gemm( + a=act_input_fp8, + b=self.w3_w1_weight, + sfa=act_input_sf, + sfb=self.quant_scales[0], + masked_m=masked_m, + expected_m=expected_m, + ) + act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd( + input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True) + h3 = deepgemm_fp8_group_blockwise_gemm( + a=act_input_fp8, + b=self.w2_weight, + sfa=act_input_sf, + sfb=self.quant_scales[1], + masked_m=masked_m, + expected_m=expected_m, + ) + + triton_masked_index_gather(permuted_data_tensor, h3, + expert_first_token_offset_tensor, + token_to_expert_map) + + final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op( + permuted_data_tensor, + None, # biases + token_final_scales, + unpermuted_row_to_permuted_row_tensor, + permuted_row_to_unpermuted_row_tensor, + token_selected_experts, + expert_first_token_offset_tensor, + False, # enable_alltoall + x.shape[0], # num_rows + x.shape[1], # hidden_size + self.routing_method.top_k, + self.expert_size_per_partition, # num_experts_per_node + self.tp_size, + self.tp_rank, + self.ep_size, + self.ep_rank, + ) + + return final_hidden_states diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index f957712e3e5..18e9c7cc98a 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -4,10 +4,13 @@ import torch from torch import nn +from tensorrt_llm import logger from tensorrt_llm._utils import get_sm_version from tensorrt_llm.quantization.utils.fp4_utils import ( float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices, get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices) +from tensorrt_llm.quantization.utils.fp8_utils import ( + resmooth_to_fp8_e8m0, transform_sf_into_required_layout) from ..linear import TensorParallelMode, load_weight_shard from .interface import MoEWeightLoadingMode @@ -463,6 +466,47 @@ def create_weights(self, module: torch.nn.Module): self.setup_quant_scales(module) + def load_weights(self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode): + + if get_sm_version() == 100: + expert_ids = set(module.initial_local_expert_ids) + if self.need_load_shared_weights(module): + expert_ids.update( + module.layer_load_balancer.get_load_expert_ids()) + for name in list(weights.keys()): + if name.endswith("weight_scale_inv"): + if int(name.split(".")[0]) not in expert_ids: + continue + weight_name = name.replace("weight_scale_inv", "weight") + logger.debug(f"Resmoothing {weight_name}") + weight = weights[weight_name][:] + scale = weights[name][:] + weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( + weight, scale) + super().load_weights(module, weights, weight_loading_mode) + + if get_sm_version() == 100: + transfromed_w3_w1_scale = transform_sf_into_required_layout( + module.quant_scales[0], + mn=module.w3_w1_weight.shape[1], + k=module.w3_w1_weight.shape[2], + recipe=(1, 128, 128), + num_groups=module.w3_w1_weight.shape[0], + is_sfa=False) + module.w3_w1_weight_scaling_factor = nn.Parameter( + transfromed_w3_w1_scale, requires_grad=False) + transfromed_w2_scale = transform_sf_into_required_layout( + module.quant_scales[1], + mn=module.w2_weight.shape[1], + 
k=module.w2_weight.shape[2], + recipe=(1, 128, 128), + num_groups=module.w3_w1_weight.shape[0], + is_sfa=False) + module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale, + requires_grad=False) + self.setup_quant_scales(module) + def setup_quant_scales(self, module: torch.nn.Module): module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales( fc_weight_scales=module.w3_w1_weight_scaling_factor, diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 1ef5be24c8b..1131fe42275 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils +import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils from tensorrt_llm._torch.peft.lora.layer import LoraLayer from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, AllReduceStrategy) @@ -20,6 +21,7 @@ preprocess_weights_for_mixed_gemm from tensorrt_llm.quantization.mode import QuantAlgo +from ..._utils import get_sm_version from ...models.modeling_utils import QuantConfig from ..utils import Fp4QuantizedTensor @@ -570,10 +572,22 @@ def apply(self, module: Linear, input: torch.Tensor, input = input.to(torch.bfloat16) * module.input_scale assert input.dtype == torch.bfloat16 - act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(input) + if get_sm_version() == 100: + import deep_gemm + a, a_sf = fp8_utils.per_token_quant_and_transform(input) + output = torch.empty((input.shape[0], module.weight.shape[0]), + device=input.device, + dtype=torch.bfloat16) + deep_gemm.fp8_gemm_nt((a, a_sf), + (module.weight, module.weight_scale), + output, + disable_ue8m0_cast=True) + else: + act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128( + input) - output = torch.ops.trtllm.fp8_block_scaling_gemm( - act_input_fp8, module.weight, act_input_sf, module.weight_scale) + output = torch.ops.trtllm.fp8_block_scaling_gemm( + act_input_fp8, module.weight, act_input_sf, module.weight_scale) if bias is not None: output = output + bias return output @@ -603,7 +617,6 @@ def load_weights_fused_qkv_linear(self, module: Linear, q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( module, weights) fused_weight = torch.cat((q_weight, k_weight, v_weight)) - copy_weight(module.weight, fused_weight) scale_name = self._get_scale_name(weights) q_scale = load_weight_shard(weights[0][scale_name], module.tp_size, @@ -614,6 +627,7 @@ def load_weights_fused_qkv_linear(self, module: Linear, module.tp_rank, module.tp_mode) fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze() + copy_weight(module.weight, fused_weight) copy_weight(module.weight_scale, fused_fp8_block_scale) def load_weights_fused_gate_up_linear(self, module: Linear, @@ -621,7 +635,6 @@ def load_weights_fused_gate_up_linear(self, module: Linear, gate_weight, up_weight = load_weights_fused_gate_up_helper( module, weights) fused_weight = torch.cat((gate_weight, up_weight)) - copy_weight(module.weight, fused_weight) scale_name = self._get_scale_name(weights) left_scale = load_weight_shard(weights[0][scale_name], module.tp_size, @@ -629,6 +642,7 @@ def load_weights_fused_gate_up_linear(self, module: Linear, right_scale = load_weight_shard(weights[1][scale_name], module.tp_size, module.tp_rank, module.tp_mode) fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze() + copy_weight(module.weight, fused_weight) copy_weight(module.weight_scale, 
fused_scale) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 88b86b03f54..260946578a1 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -144,6 +144,8 @@ def _create_dummy_context_requests( end_id=-1) requests.append(request) remaining_tokens -= input_seq_len + if self._mapping.enable_attention_dp: + requests = requests * self._mapping.tp_size return requests def _get_token_num_for_estimation(self) -> int: diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index b07430224af..e733a0331f6 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -470,6 +470,10 @@ def mpi_comm(): local_comm = mpi_comm().Split_type(split_type=OMPI_COMM_TYPE_HOST) +def local_mpi_comm(): + return local_comm + + def mpi_rank(): return mpi_comm().Get_rank() if ENABLE_MULTI_DEVICE else 0 @@ -508,6 +512,11 @@ def mpi_barrier(): mpi_comm().Barrier() +def local_mpi_barrier(): + if ENABLE_MULTI_DEVICE: + local_comm.Barrier() + + def mpi_broadcast(obj, root=0): return mpi_comm().bcast(obj, root) if is_multi_device_enable() else obj diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 44016980fc4..1c836264e22 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -167,7 +167,7 @@ class MoeConfig(StrictBaseModel): """ Configuration for MoE. """ - backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", + backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM", "VANILLA"] = Field(default='CUTLASS', description="MoE backend to use.") diff --git a/tensorrt_llm/quantization/utils/__init__.py b/tensorrt_llm/quantization/utils/__init__.py index a0cc798c42f..a79df9ebcb2 100644 --- a/tensorrt_llm/quantization/utils/__init__.py +++ b/tensorrt_llm/quantization/utils/__init__.py @@ -1,3 +1,3 @@ -from . import fp4_utils +from . import fp4_utils, fp8_utils -__all__ = ['fp4_utils'] +__all__ = ['fp4_utils', 'fp8_utils'] diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py new file mode 100644 index 00000000000..19bd24671dd --- /dev/null +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -0,0 +1,530 @@ +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from tensorrt_llm._utils import nvtx_range + + +def ceil_div(x: int, y: int) -> int: + """ + Perform ceiling division of two integers. + + Args: + x: the dividend. + y: the divisor. + + Returns: + The result of the ceiling division. 
+ """ + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +@nvtx_range("[DG] quantization") +@torch.compile(dynamic=True) +def per_token_cast_to_fp8_e8m0( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if x.dim() == 2: + assert x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, n), sf + else: + assert x.size(2) % 128 == 0 + g, m, n = x.shape + x_view = x.view(g, m, -1, 128) + x_amax = x_view.abs().float().amax(dim=3).view(g, m, -1).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + return (x_view * (1.0 / sf.unsqueeze(3))).to(torch.float8_e4m3fn).view( + g, m, n), sf + + +def per_block_cast_to_fp8_e8m0( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if x.dim() == 2: + m, n = x.shape + x_padded = torch.zeros((align(m, 128), align(n, 128)), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( + x_view.size(0), x_view.size(2)) + else: + g, m, n = x.shape + x_padded = torch.zeros((g, align(m, 128), align(n, 128)), + dtype=x.dtype, + device=x.device) + x_padded[:, :m, :n] = x + x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view( + x_view.size(0), x_view.size(1), x_view.size(3)) + + +def resmooth_to_fp8_e8m0(weight: torch.Tensor, + sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + weight = weight.cuda() + sf = sf.cuda() + if weight.dim() == 2: + x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave( + 128, dim=1)[:weight.shape[0], :weight.shape[1]] + else: + x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave( + 128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]] + return per_block_cast_to_fp8_e8m0(x) + + +def get_m_alignment_for_contiguous_layout(): + return 128 + + +def get_tma_aligned_size(x: int, element_size: int) -> int: + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return align(x, alignment) + + +def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor: + # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + assert x.dtype == torch.float and x.dim() in (2, 3) + + # First, convert into UE8M0 `uint8_t` + ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8) + + # Second, make padded packed tensors + mn, k = x.shape[-2], x.shape[-1] + remove_dim = False + if x.dim() == 2: + x, remove_dim = x.unsqueeze(0), True + b = x.shape[0] + aligned_mn = get_tma_aligned_size(mn, 4) + aligned_k = align(k, 4) + padded = torch.zeros((b, aligned_mn, aligned_k), + device=x.device, + dtype=torch.uint8) + padded[:, :mn, :k] = ue8m0_tensor + padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, + 
aligned_k // 4) + + # Finally, transpose + transposed = torch.transpose( + torch.empty((b, aligned_k // 4, aligned_mn), + device=x.device, + dtype=torch.int), 1, 2) + transposed[:, :, :] = padded + aligned_x = transposed[:, :mn, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x + + +def check_sf_layout(sf: torch.Tensor, + mn: int, + k: int, + gran: Tuple[int, int], + num_groups: Optional[int], + tma_stride_check: bool = False, + type_check: Optional[torch.dtype] = None) -> torch.Tensor: + # Type check + if type_check is not None: + assert sf.dtype == type_check + + # Always do shape checks + assert sf.dtype in (torch.float, torch.int) + assert sf.dim() == int(num_groups is not None) + 2 + if num_groups is not None: + assert sf.size(-3) == num_groups + assert sf.size(-2) == ceil_div(mn, gran[0]) + assert sf.size(-1) == ceil_div( + k, gran[1] * (1 if sf.dtype == torch.float else 4)) + + # TMA stride checks: TMA aligned and MN-major + if tma_stride_check: + if num_groups is not None: + assert sf.stride(-3) == sf.stride(-1) * sf.size(-1) + assert sf.stride(-2) == 1 + assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()) + + return sf + + +@nvtx_range("[DG] transform_sf_into_required_layout") +def transform_sf_into_required_layout(sf: torch.Tensor, + mn: int, + k: int, + recipe: Tuple[int, int, int], + num_groups: Optional[int] = None, + is_sfa: bool = False): + gran = (recipe[0 if is_sfa else 1], recipe[2]) + + should_skip_transform = ((sf.dtype == torch.int and gran == (1, 128)) + or (sf.dtype == torch.int and gran == (128, 128))) + + if not should_skip_transform: + # Pre-transform checks + check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups) + + # (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if sf.dtype == torch.float and gran == (1, 128): + sf = get_col_major_tma_aligned_packed_tensor(sf) + return check_sf_layout(sf, + mn=mn, + k=k, + gran=(1, 128), + num_groups=num_groups, + tma_stride_check=True, + type_check=torch.int) + + # (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if sf.dtype == torch.float and gran == (128, 128): + sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128) + sf = get_col_major_tma_aligned_packed_tensor(sf) + return check_sf_layout(sf, + mn=mn, + k=k, + gran=(1, 128), + num_groups=num_groups, + tma_stride_check=True, + type_check=torch.int) + + if should_skip_transform: + # TODO: add transpose kernel if SF layout is not satisfied + return check_sf_layout(sf, + mn=mn, + k=k, + gran=(1, 128), + num_groups=num_groups, + tma_stride_check=True, + type_check=torch.int) + + assert False, f'Unknown cases: {sf.dtype=}, {gran=}' + + +# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py +@triton.jit +def _silu_and_mul_post_quant_kernel( + input_ptr, + stride_input_0, + stride_input_1, + stride_input_2, + output_ptr, + stride_output_0, + stride_output_1, + stride_output_2, + output_scale_ptr, + stride_output_scale_0, + stride_output_scale_1, + stride_output_scale_2, + masked_m_ptr, + size_k, + fp8_max, + fp8_min, + BLOCK: tl.constexpr, + NUM_STAGE: tl.constexpr, + SCALE_UE8M0: tl.constexpr, +): + expert_id = tl.program_id(2) + token_id = tl.program_id(1) + hidden_dim_block_index = tl.program_id(0) + + block_num_per_expert = tl.num_programs(1) + + token_num_cur_expert = tl.load(masked_m_ptr + expert_id) + + stride_input_0 = tl.cast(stride_input_0, 
dtype=tl.int64) + stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64) + stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) + stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) + + offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4) + input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d + output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d + output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 + + hidden_dim_block_index * stride_output_scale_1) + + for token_index in tl.range(token_id, + token_num_cur_expert, + block_num_per_expert, + num_stages=NUM_STAGE): + output_s_int32 = 0 + for pack_index in tl.range(4): + local_mask = offs_in_d + pack_index * 128 + up = tl.load( + input_ptr_offs + token_index * stride_input_1 + + pack_index * 128, + mask=local_mask < size_k, + other=0.0, + ) + gate = tl.load( + input_ptr_offs + token_index * stride_input_1 + size_k + + pack_index * 128, + mask=local_mask < size_k, + other=0.0, + ).to(tl.float32) + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + gate_up = up * gate + _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) + output_s = _absmax / fp8_max + if SCALE_UE8M0: + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) + output_q = tl.clamp(gate_up / output_s, fp8_min, + fp8_max).to(output_ptr.dtype.element_ty) + output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) << + (8 * pack_index)) + tl.store( + output_ptr_offs + token_index * stride_output_1 + + pack_index * 128, + output_q, + mask=local_mask < size_k, + ) + tl.store( + output_scale_offs + token_index * stride_output_scale_2, + output_s_int32, + ) + + +def silu_and_mul_masked_post_quant_fwd( + input: torch.Tensor, + quant_group_size: int, + masked_m: torch.Tensor, + scale_ue8m0: bool = False, +): + """ + input shape [g, m, k] + output shape [g, m, k // 2], dtype fp8 + output_scale [g, k // 4, m // 2 // 128], dtype int32 + quant_group_size int + masked_m shape [g] + """ + + assert input.is_contiguous() + assert len(input.shape) == 3 + assert input.shape[0] == masked_m.shape[0] + assert input.shape[-1] % 2 == 0 + + # FP8 quantization parameters + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + fp8_min = finfo.min + + g, m, k = input.shape + k = k // 2 + + # Create output + output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda") + + # Create output scale + alignment = 4 + scale_k = ceil_div(k, quant_group_size) + m_padded = align(m, alignment) + scale_k_padded = align(scale_k, alignment) + output_scale = torch.zeros((g, scale_k_padded // 4, m_padded), + dtype=torch.int32, + device='cuda') + + # Get block/grid/stage/warp + expert_num = len(masked_m) + + if expert_num < 4: + BLOCK_NUM_PER_EXPERT = 64 + else: + BLOCK_NUM_PER_EXPERT = 128 + + BLOCK = quant_group_size * 4 + num_warps = 1 + NUM_STAGES = 6 + hidden_dim_split_block_num = triton.cdiv(k, BLOCK) + grid = ( + hidden_dim_split_block_num, + BLOCK_NUM_PER_EXPERT, + expert_num, + ) + _silu_and_mul_post_quant_kernel[grid]( + input, + *input.stride(), + output, + *output.stride(), + output_scale, + *output_scale.stride(), + masked_m, + k, + fp8_max, + fp8_min, + BLOCK=BLOCK, + NUM_STAGE=NUM_STAGES, + num_warps=num_warps, + SCALE_UE8M0=scale_ue8m0, + ) + output_scale = output_scale.transpose(1, 2)[:, :m, :] + check_sf_layout( + output_scale, + m, + k, + (1, 128), + g, + tma_stride_check=True, + ) + return output, output_scale + + +@triton.jit +def 
_per_token_quant_and_transform_kernel( + input_ptr, + stride_input_0, + stride_input_1, + output_ptr, + stride_output_0, + stride_output_1, + output_scale_ptr, + stride_output_scale_0, + stride_output_scale_1, + token_num_cur_expert, + size_k, + fp8_max, + fp8_min, + BLOCK: tl.constexpr, + NUM_STAGE: tl.constexpr, + SCALE_UE8M0: tl.constexpr, +): + tl.program_id(2) + token_id = tl.program_id(1) + hidden_dim_block_index = tl.program_id(0) + + block_num_per_expert = tl.num_programs(1) + + stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64) + stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64) + stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) + stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) + + offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4) + input_ptr_offs = input_ptr + offs_in_d + output_ptr_offs = output_ptr + offs_in_d + output_scale_offs = (output_scale_ptr + + hidden_dim_block_index * stride_output_scale_0) + + for token_index in tl.range(token_id, + token_num_cur_expert, + block_num_per_expert, + num_stages=NUM_STAGE): + output_s_int32 = 0 + for pack_index in tl.range(4): + local_mask = offs_in_d + pack_index * 128 + act = tl.load( + input_ptr_offs + token_index * stride_input_0 + + pack_index * 128, + mask=local_mask < size_k, + other=0.0, + ).to(tl.float32) + _absmax = tl.maximum(tl.max(tl.abs(act)), 1e-10) + output_s = _absmax / fp8_max + if SCALE_UE8M0: + output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s)))) + output_q = tl.clamp(act / output_s, fp8_min, + fp8_max).to(output_ptr.dtype.element_ty) + output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) << + (8 * pack_index)) + tl.store( + output_ptr_offs + token_index * stride_output_0 + + pack_index * 128, + output_q, + mask=local_mask < size_k, + ) + tl.store( + output_scale_offs + token_index * stride_output_scale_1, + output_s_int32, + ) + + +def per_token_quant_and_transform( + input: torch.Tensor, + quant_group_size: int = 128, + scale_ue8m0: bool = True, +): + """ + input shape [g, m, k] + output shape [g, m, k // 2], dtype fp8 + output_scale [g, k // 4, m // 2 // 128], dtype int32 + quant_group_size int + masked_m shape [g] + """ + + assert input.is_contiguous() + assert len(input.shape) == 2 + assert input.shape[-1] % 2 == 0 + + # FP8 quantization parameters + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + fp8_min = -fp8_max + + m, k = input.shape + + # Create output + output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda") + + # Create output scale + alignment = 4 + scale_k = ceil_div(k, quant_group_size) + m_padded = align(m, alignment) + scale_k_padded = align(scale_k, alignment) + output_scale = torch.zeros((scale_k_padded // 4, m_padded), + dtype=torch.int32, + device='cuda') + + # Get block/grid/stage/warp + BLOCK_NUM_PER_EXPERT = 64 + + BLOCK = quant_group_size * 4 + num_warps = 1 + NUM_STAGES = 6 + hidden_dim_split_block_num = triton.cdiv(k, BLOCK) + grid = ( + hidden_dim_split_block_num, + BLOCK_NUM_PER_EXPERT, + 1, + ) + _per_token_quant_and_transform_kernel[grid]( + input, + *input.stride(), + output, + *output.stride(), + output_scale, + *output_scale.stride(), + m, + k, + fp8_max, + fp8_min, + BLOCK=BLOCK, + NUM_STAGE=NUM_STAGES, + num_warps=num_warps, + SCALE_UE8M0=scale_ue8m0, + ) + output_scale = output_scale.transpose(0, 1)[:m, :] + check_sf_layout( + output_scale, + m, + k, + (1, 128), + num_groups=None, + tma_stride_check=True, + ) + return output, output_scale diff --git a/tests/unittest/_torch/helpers.py 
b/tests/unittest/_torch/helpers.py index 8f4e2459f1b..5e9f2ba1a26 100644 --- a/tests/unittest/_torch/helpers.py +++ b/tests/unittest/_torch/helpers.py @@ -8,6 +8,14 @@ def ceil_div(x: int, y: int) -> int: return (x + y - 1) // y +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 m, n = x.shape @@ -33,6 +41,33 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x_view.size(2)) +def per_token_cast_to_fp8_e8m0( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, n), sf + + +def per_block_cast_to_fp8_e8m0( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((align(m, 128), align(n, 128)), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( + x_view.size(0), x_view.size(2)) + + def calc_diff(x, y): x, y = x.double(), y.double() denominator = (x * x + y * y).sum() diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py index 367f7300b09..51a7758d281 100644 --- a/tests/unittest/_torch/modules/test_fused_moe.py +++ b/tests/unittest/_torch/modules/test_fused_moe.py @@ -9,7 +9,8 @@ import pytest import torch import torch.nn as nn -from _torch.helpers import per_block_cast_to_fp8 +from _torch.helpers import (per_block_cast_to_fp8, per_block_cast_to_fp8_e8m0, + per_token_cast_to_fp8_e8m0) from mpi4py import MPI from mpi4py.futures import MPIPoolExecutor from utils.util import (skip_neither_ada_nor_hopper_unittest, @@ -25,6 +26,8 @@ VanillaMoE, WideEPMoE) from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \ CuteDslFusedMoE +from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \ + DeepGemmFusedMoE from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \ AlltoallMethodType from tensorrt_llm._torch.modules.gated_mlp import GatedMLP @@ -379,6 +382,183 @@ def set_tensor_value_4(x, num_row, num_cols): x.copy_(repeated) +@skip_pre_blackwell +@pytest.mark.parametrize( + "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls", + product( + [torch.bfloat16], + [72], + [128, 256, 384, 512, 1024, 2048, 4096, 8192], + [2560], + [DefaultMoeRoutingMethod], + ), +) +def test_fused_moe_fp8_blockwise_deepgemm(dtype, + num_experts, + seq_len, + hidden_size, + RoutingMethodCls, + mapping=None): + SEQ_LEN = seq_len + HIDDEN_SIZE = hidden_size + INTERMEDIATE_SIZE = 256 + NUM_EXPERTS = num_experts + TOP_K = 2 + + routing_method = RoutingMethodCls(top_k=TOP_K) + + mapping = mapping or Mapping() + mapping.rank = mpi_rank() + torch.cuda.set_device(mapping.rank) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + + x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda() + # Note: we use some special 
values init x and weight, otherwise the test will false positive failed. + set_tensor_value_2(x, SEQ_LEN, HIDDEN_SIZE) + + x = x.cuda() + router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda() + + weights = {} + w3_w1_weight_scales = [] + w2_weight_scales = [] + for expert_id in range(NUM_EXPERTS): + w1_weight = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() / HIDDEN_SIZE + w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE), + dtype=dtype).cuda() + w3_weight = torch.randn( + (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() / HIDDEN_SIZE + set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE) + set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE) + + w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8_e8m0(w1_weight) + w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8_e8m0(w2_weight) + w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda() + + w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8_e8m0(w3_weight) + w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda() + + weights[f"{expert_id}.w1.weight"] = w1_weight_fp8 + weights[f"{expert_id}.w2.weight"] = w2_weight_fp8 + weights[f"{expert_id}.w3.weight"] = w3_weight_fp8 + weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale + weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale + weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale + weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale + + w3_w1_weight_scales.append( + torch.cat([w3_weight_scale, w1_weight_scale], dim=0)) + w2_weight_scales.append(w2_weight_scale) + + w3_w1_weight_scales = torch.stack(w3_w1_weight_scales, dim=0).cuda() + w2_weight_scales = torch.stack(w2_weight_scales, dim=0).cuda() + + quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES) + + fused_moe = DeepGemmFusedMoE( + num_experts=NUM_EXPERTS, + routing_method=routing_method, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + dtype=dtype, + reduce_results=True, + model_config=ModelConfig(quant_config=quant_config, mapping=mapping), + ) + fused_moe.cuda() + fused_moe.load_weights([weights]) + + def swiglu_fused_moe(x): + x, gate = x.chunk(2, dim=-1) + return torch.nn.functional.silu(gate) * x + + def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor, + b_sf: torch.Tensor, + offset_array: torch.Tensor) -> torch.Tensor: + d = torch.empty((a.shape[0], b.shape[1]), + device=b.device, + dtype=torch.bfloat16) + m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32) + for idx in range(offset_array.numel() - 1): + m_indices[offset_array[idx]:offset_array[idx + 1]] = idx + + num_groups, n, k_ = b.shape + d = torch.empty((a.shape[0], b.shape[1]), + device=b.device, + dtype=torch.bfloat16) + m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32) + for idx in range(offset_array.numel() - 1): + m_indices[offset_array[idx]:offset_array[idx + 1]] = idx + + for g in range(num_groups): + aa = a[offset_array[g]:offset_array[g + 1], :].to(torch.bfloat16) + aa_sf = a_sf[offset_array[g]:offset_array[g + 1], :] + aa_dq = aa * aa_sf.repeat_interleave( + 128, dim=1)[:aa.shape[0], :aa.shape[1]] + bb = b[g, :, :].to(torch.bfloat16) + bb_sf = b_sf[g, :, :] + bb_dq = bb * 
bb_sf.repeat_interleave(128, dim=0).repeat_interleave( + 128, dim=1)[:bb.shape[0], :bb.shape[1]] + d[offset_array[g]:offset_array[g + 1], :] = (aa_dq @ bb_dq.t()) + return d + + token_selected_experts, token_final_scales = routing_method.apply( + router_logits) + t_idx = 0 + permuted_data_tensor = torch.empty((x.shape[0] * TOP_K, x.shape[1]), + device=x.device, + dtype=torch.bfloat16) + expert_first_token_offset_tensor = torch.zeros(NUM_EXPERTS + 1, + dtype=torch.int32) + unpermute_map = [] + scales = [] + for e_idx in range(NUM_EXPERTS): + for idx, token in enumerate(x): + for i, selected_expert in enumerate(token_selected_experts[idx]): + if e_idx == selected_expert: + permuted_data_tensor[t_idx, :] = token + unpermute_map.append(idx) + scales.append(token_final_scales[idx, i]) + t_idx += 1 + expert_first_token_offset_tensor[e_idx + 1] = t_idx + + act_input_fp8, act_input_sf = per_token_cast_to_fp8_e8m0( + permuted_data_tensor) + h1 = grouped_gemm( + a=act_input_fp8, + b=fused_moe.w3_w1_weight, + a_sf=act_input_sf, + b_sf=w3_w1_weight_scales, + offset_array=expert_first_token_offset_tensor, + ) + h2 = swiglu_fused_moe(h1) + act_input_fp8, act_input_sf = per_token_cast_to_fp8_e8m0(h2) + h3 = grouped_gemm( + a=act_input_fp8, + b=fused_moe.w2_weight, + a_sf=act_input_sf, + b_sf=w2_weight_scales, + offset_array=expert_first_token_offset_tensor, + ) + ref_output = torch.zeros_like(x) + for token_idx, h3_token in enumerate(h3): + original_idx = unpermute_map[token_idx] + ref_output[original_idx, :] += h3_token * scales[token_idx] + + with torch.inference_mode(): + output = fused_moe.forward(x, router_logits) + + # compare + torch.cuda.synchronize() + torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1) + + @skip_non_hopper_unittest @pytest.mark.parametrize( "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls", diff --git a/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py b/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py index 44662f648a2..0fb78c01f33 100644 --- a/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py +++ b/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py @@ -18,10 +18,48 @@ import pytest import torch -from _torch.helpers import calc_diff, per_block_cast_to_fp8 +from _torch.helpers import (calc_diff, per_block_cast_to_fp8, + per_block_cast_to_fp8_e8m0, + per_token_cast_to_fp8_e8m0) from utils.util import getSMVersion +@pytest.mark.skipif( + getSMVersion() != 100, + reason="The test is for Blackwell only. Current SM is %d." % getSMVersion(), +) +@pytest.mark.parametrize( + "k, n", + [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), + (2048, 7168), (1024, 1024)], +) +@pytest.mark.parametrize( + "m", + [7, 64, 128, 4096], +) +@pytest.mark.parametrize( + "dtype", + [torch.bfloat16], +) +def test_fp8_block_scale_deep_gemm(dtype, m, k, n): + torch.random.manual_seed(0) + a = torch.randn((m, k), device='cuda', dtype=dtype) + b = torch.randn((n, k), device='cuda', dtype=dtype) + + act_a_fp8, act_a_sf = per_token_cast_to_fp8_e8m0(a) + act_b_fp8, act_b_sf = per_block_cast_to_fp8_e8m0(b) + + output_expected = a @ b.t() + import deep_gemm + output = torch.empty((act_a_fp8.shape[0], act_b_fp8.shape[0]), + device=act_a_fp8.device, + dtype=torch.bfloat16) + + deep_gemm.fp8_gemm_nt((act_a_fp8, act_a_sf), (act_b_fp8, act_b_sf), output) + diff = calc_diff(output, output_expected) + assert diff < 1e-2 + + @pytest.mark.skipif( getSMVersion() != 100 and getSMVersion() != 89, reason="The test is for Blackwell and Ada only. 
Current SM is %d." %
diff --git a/tests/unittest/test_pip_install.py b/tests/unittest/test_pip_install.py
index 11288e09cec..d75bfbaf420 100644
--- a/tests/unittest/test_pip_install.py
+++ b/tests/unittest/test_pip_install.py
@@ -51,6 +51,9 @@ def test_pip_install():
                         help="The wheel path")
     args = parser.parse_args()
 
+    if not os.environ.get("CUDA_HOME"):
+        os.environ["CUDA_HOME"] = "/usr/local/cuda"
+
     print("########## Install required system libs ##########")
     if not os.path.exists("/usr/local/mpi/bin/mpicc"):
         subprocess.check_call("apt-get -y install libopenmpi-dev", shell=True)
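
For reviewers who want to exercise the new backend end to end, a minimal usage sketch follows. It is not part of the patch: the `MoeConfig(backend="DEEPGEMM")` value and the `--moe_backend DEEPGEMM` choice come from this change, while the model path and the exact `LLM` keyword arguments are assumptions that mirror the pattern in `examples/llm-api/quickstart_advanced.py` (which can also be run directly with `--moe_backend DEEPGEMM`). An SM100 (Blackwell) GPU and an FP8 block-scale MoE checkpoint are assumed.

# Hedged sketch, not part of the patch: model path and generation settings are placeholders.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi.llm_args import MoeConfig

llm = LLM(
    model="<path-to-fp8-block-scale-moe-checkpoint>",  # assumption: e.g. a DeepSeek-V3 FP8 checkpoint
    moe_config=MoeConfig(backend="DEEPGEMM"),  # backend choice added by this change
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)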