14 changes: 7 additions & 7 deletions docs/ContribOperators.md

@@ -3089,7 +3089,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

<dl>
<dt><tt>activation_type</tt> : string</dt>
-<dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
+<dd>Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu</dd>
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
@@ -3106,9 +3106,9 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T</dt>
-<dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
+<dd>3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
-<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
+<dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
<dt><tt>fc2_experts_weights</tt> : T</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
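For the swiglu variant above, fc1 widens to 2 * inter_size because the activation consumes a (gate, linear) pair per output element. A minimal sketch of the row-wise math, assuming a [gate | linear] split of the fc1 output (the kernel's actual memory layout is not specified here):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative only: swiglu turns 2 * inter_size fc1 outputs back into
// inter_size activations that feed fc2. Assumes [gate | linear] halves.
std::vector<float> SwigluRow(const std::vector<float>& fc1_out, size_t inter_size) {
  std::vector<float> out(inter_size);
  for (size_t i = 0; i < inter_size; ++i) {
    const float gate = fc1_out[i];
    const float linear = fc1_out[i + inter_size];
    const float silu = gate / (1.0f + std::exp(-gate));  // silu(x) = x * sigmoid(x)
    out[i] = silu * linear;                              // swiglu = silu(gate) * linear
  }
  return out;
}
```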
@@ -4523,7 +4523,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

<dl>
<dt><tt>activation_type</tt> : string</dt>
-<dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
+<dd>Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu</dd>
<dt><tt>expert_weight_bits</tt> : int</dt>
<dd>Number of bits used in quantized weights. Default is 4 bits</dd>
<dt><tt>k</tt> : int</dt>
@@ -4542,11 +4542,11 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T1</dt>
-<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
+<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).</dd>
<dt><tt>fc1_scales</tt> : T</dt>
-<dd>2D input tensor with shape (num_experts, inter_size)</dd>
+<dd>2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
-<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
+<dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
<dt><tt>fc2_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
<dt><tt>fc2_scales</tt> : T</dt>
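The quantized fc1 shapes above follow from bit packing: with expert_weight_bits = 4, two weights share one byte along the last dimension, so a swiglu fc1 that logically holds 2 * inter_size columns packs into inter_size bytes. A sketch of that arithmetic (hypothetical helper, not the kernel's code):

```cpp
#include <cstdint>

// Packed size of fc1_experts_weights' last dimension, assuming weights are
// packed along that axis (expert_weight_bits is 4 or 8 per the docs above).
int64_t Fc1PackedCols(int64_t inter_size, int expert_weight_bits, bool is_swiglu) {
  const int64_t logical_cols = is_swiglu ? 2 * inter_size : inter_size;
  const int64_t weights_per_byte = 8 / expert_weight_bits;  // 2 for 4-bit, 1 for 8-bit
  return logical_cols / weights_per_byte;
}
// 4-bit, swiglu: 2 * inter_size / 2 == inter_size  (matches the swiglu shape above)
// 4-bit, other activations: inter_size / 2         (matches "inter_size / 2" above)
```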
11 changes: 6 additions & 5 deletions docs/OperatorKernels.md

@@ -43,6 +43,7 @@ Do not modify directly.*
|||[7, 21]|**T** = tensor(float)|
|Atanh|*in* input:**T**<br> *out* output:**T**|22+|**T** = tensor(float)|
|||[9, 21]|**T** = tensor(float)|
+|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|23+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float), tensor(float16)<br/> **U** = tensor(bool), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(float)|
|||[19, 21]|**T** = tensor(float)|
|||[11, 18]|**T** = tensor(float)|
@@ -58,11 +59,11 @@ Do not modify directly.*
|BitwiseOr|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BlackmanWindow|*in* size:**T1**<br> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Cast|*in* input:**T1**<br> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Cast|*in* input:**T1**<br> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|Ceil|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
|Celu|*in* X:**T**<br> *out* Y:**T**|12+|**T** = tensor(float)|
7 changes: 5 additions & 2 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc

@@ -78,8 +78,11 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {

ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");

-  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
-                                                                     normalize_routing_weights_, use_sparse_mixer_);
+  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm,
+                                                                     activation_type_,
+                                                                     fc3_experts_weights_optional != nullptr,
+                                                                     normalize_routing_weights_,
+                                                                     use_sparse_mixer_);

size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
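One reason the runner now takes activation_type_ at construction: workspace sizing depends on it, since a swiglu fc1 intermediate holds twice as many columns per routed row. Illustrative arithmetic only, not the runner's actual bookkeeping:

```cpp
#include <cstddef>

// Rough fc1 intermediate size: each of num_rows tokens is routed to k experts,
// each producing inter_size (or 2 * inter_size for swiglu) values. Sketch only.
size_t Fc1IntermediateElems(size_t num_rows, size_t k, size_t inter_size, bool is_swiglu) {
  const size_t cols = is_swiglu ? 2 * inter_size : inter_size;
  return num_rows * k * cols;
}
```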
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h

@@ -52,6 +52,7 @@ enum class ActivationType { Gelu,
GeGLU,
ReGLU,
SiGLU,
+SwiGLU,
Identity,
InvalidType };
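For orientation, the operator's activation_type strings map onto this enum roughly as below. This is an illustrative sketch, not ORT's parsing code, and it assumes Relu and Silu members exist outside the visible hunk:

```cpp
#include <string>

// Hypothetical string-to-enum mapping for the MoE activation_type attribute.
ActivationType ParseActivationType(const std::string& name) {
  if (name == "relu") return ActivationType::Relu;
  if (name == "gelu") return ActivationType::Gelu;
  if (name == "silu") return ActivationType::Silu;
  if (name == "swiglu") return ActivationType::SwiGLU;
  if (name == "identity") return ActivationType::Identity;
  return ActivationType::InvalidType;
}
```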

@@ -391,12 +391,10 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(const T* A, con
dispatch_moe_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75, EpilogueTag>(
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
sm_, multi_processor_count_, stream, occupancy);
-  } else if (sm_ >= 80 && sm_ < 90) {
+  } else if (sm_ >= 80) {  // Hopper and Blackwell fall back to the Ampere kernels.
dispatch_moe_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(
A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config,
sm_, multi_processor_count_, stream, occupancy);
-  } else {
-    ORT_THROW("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
}
}

@@ -478,6 +476,7 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(const T* A, const WeightTyp
int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
int num_experts, ActivationType activation_type,
cudaStream_t stream) {
+  // Swiglu callers use Identity when invoking this function, so it does not need to be handled here.
switch (activation_type) {
case ActivationType::Relu:
run_gemm<EpilogueOpDefaultReLU>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
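Per the new comment, swiglu never reaches this switch: the GEMM runs with the Identity epilogue over 2 * inter_size columns and the gating is applied in a separate elementwise pass. A host-side sketch of that second pass, again assuming a [gate | linear] split (the real pass would be a CUDA kernel):

```cpp
#include <cmath>
#include <cstdint>

// Elementwise gating after the Identity-epilogue GEMM for swiglu:
// reads a (rows, 2 * inter_size) buffer, writes (rows, inter_size).
template <typename T>
void ApplySwiglu(const T* gemm_out, T* activated, int64_t rows, int64_t inter_size) {
  for (int64_t r = 0; r < rows; ++r) {
    const T* row = gemm_out + r * 2 * inter_size;  // assumed [gate | linear] layout
    for (int64_t c = 0; c < inter_size; ++c) {
      const float g = static_cast<float>(row[c]);
      const float l = static_cast<float>(row[c + inter_size]);
      activated[r * inter_size + c] = static_cast<T>(l * g / (1.0f + std::exp(-g)));
    }
  }
}
```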