43 changes: 35 additions & 8 deletions docs/ContribOperators.md
@@ -3079,6 +3079,17 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
Mixture of experts. Examples: Switch Transformer (https://arxiv.org/pdf/2101.03961.pdf) uses top-1 routing,
GLaM (https://arxiv.org/abs/2112.06905) activates the top-2 FFNs, Vision MoE (https://arxiv.org/pdf/2106.05974.pdf)
usually uses top-32 experts, and Mixtral (https://huggingface.co/blog/mixtral) uses top-2 experts.

The SwiGLU (Swish-Gated Linear Unit) activation function is computed as:
g = xW + b
l = xV + c
G = clamp(g, max=limit)
L = clamp(l, min=-limit, max=limit)
swiglu = G * sigmoid(alpha * G) * (L + beta)
where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters.
When swiglu_fusion=0, the two GEMMs are not fused; they are provided as the FC1 and FC3 inputs.
When swiglu_fusion=1, the two GEMMs are fused so that g and l are computed by a single GEMM (FC1), with g and l interleaved on each row of size 2 * inter_size.
When swiglu_fusion=2, the two GEMMs are fused, and g and l are concatenated on each row.
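
Below is a minimal sketch of the activation above, assuming swiglu_fusion=1 (g and l interleaved in the fused FC1 output row); the function name and the use of std::vector are illustrative only and are not taken from the kernel.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Applies the SwiGLU activation described above to one row of the fused FC1
// output (swiglu_fusion=1): the row has length 2 * inter_size with g and l
// interleaved as [g0, l0, g1, l1, ...]. Returns a row of length inter_size.
std::vector<float> ApplySwiGLUInterleaved(const std::vector<float>& fc1_row,
                                          float alpha, float beta, float limit) {
  const size_t inter_size = fc1_row.size() / 2;
  std::vector<float> out(inter_size);
  for (size_t i = 0; i < inter_size; ++i) {
    const float g = std::min(fc1_row[2 * i], limit);                // clamp(g, max=limit)
    const float l = std::clamp(fc1_row[2 * i + 1], -limit, limit);  // clamp(l, min=-limit, max=limit)
    const float sig = 1.0f / (1.0f + std::exp(-alpha * g));         // sigmoid(alpha * G)
    out[i] = g * sig * (l + beta);                                  // G * sigmoid(alpha * G) * (L + beta)
  }
  return out;
}
```

With swiglu_fusion=2 the same formula applies, but g comes from the first inter_size entries of the row and l from the last inter_size entries; with swiglu_fusion=0, g and l are produced by the separate FC1 and FC3 GEMMs.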


#### Version
@@ -3088,12 +3099,20 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes

<dl>
<dt><tt>activation_alpha</tt> : float</dt>
<dd>Alpha parameter used in activation function.</dd>
<dt><tt>activation_beta</tt> : float</dt>
<dd>Beta parameter used in activation function.</dd>
<dt><tt>activation_type</tt> : string</dt>
<dd>Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu</dd>
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
<dt><tt>swiglu_fusion</tt> : int</dt>
<dd>0: not fused; 1: fused and interleaved; 2: fused and not interleaved.</dd>
<dt><tt>swiglu_limit</tt> : float</dt>
<dd>The limit used to clamp values in SwiGLU. No clamping is applied when the limit is not provided.</dd>
<dt><tt>use_sparse_mixer</tt> : int</dt>
<dd>Whether to use sparse mixer</dd>
</dl>
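
The k, normalize_routing_weights and use_sparse_mixer attributes control routing. As a rough, hedged illustration (the exact softmax placement and the sparse-mixer path are defined by the kernel and are not shown here), top-k routing for one token's row of router_probs could look like the sketch below; all names are illustrative.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Hedged sketch of top-k routing for one token: softmax over the router row,
// select the k largest probabilities, and optionally renormalize the selected
// weights so they sum to 1 (normalize_routing_weights=1). Assumes k <= number
// of experts; the sparse mixer variant is omitted.
std::vector<std::pair<int64_t, float>> RouteTopK(const std::vector<float>& router_row,
                                                 int64_t k, bool normalize_routing_weights) {
  // Softmax over the experts.
  const float max_logit = *std::max_element(router_row.begin(), router_row.end());
  std::vector<float> probs(router_row.size());
  float sum = 0.0f;
  for (size_t e = 0; e < router_row.size(); ++e) {
    probs[e] = std::exp(router_row[e] - max_logit);
    sum += probs[e];
  }
  for (float& p : probs) p /= sum;

  // Pick the k experts with the highest probability.
  std::vector<int64_t> order(probs.size());
  std::iota(order.begin(), order.end(), 0);
  std::partial_sort(order.begin(), order.begin() + k, order.end(),
                    [&](int64_t a, int64_t b) { return probs[a] > probs[b]; });

  std::vector<std::pair<int64_t, float>> selected;
  float selected_sum = 0.0f;
  for (int64_t i = 0; i < k; ++i) {
    selected.emplace_back(order[i], probs[order[i]]);
    selected_sum += probs[order[i]];
  }
  if (normalize_routing_weights) {
    for (auto& expert_weight : selected) expert_weight.second /= selected_sum;
  }
  return selected;
}
```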
@@ -3106,15 +3125,15 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu</dd>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
<dt><tt>fc2_experts_weights</tt> : T</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc3_experts_weights</tt> (optional) : T</dt>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size)</dd>
<dd>3D optional input tensor with shape (num_experts, inter_size, hidden_size)</dd>
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
</dl>
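
To make the (num_experts, out_features, in_features) weight layout above concrete, here is a hedged sketch of one expert's FC1 applied to a single token. It assumes plain float weights stored row-major with the last dimension contiguous; the names are illustrative and not taken from the implementation.

```cpp
#include <cstdint>
#include <vector>

// One expert's FC1 for a single token, with fc1_experts_weights laid out as
// (num_experts, fc1_out, hidden_size): each weight row is one output feature,
// so y[i] = sum_j W[e][i][j] * x[j] + b[e][i]. For SwiGLU, fc1_out = 2 * inter_size.
std::vector<float> ExpertFC1(const std::vector<float>& fc1_weights,  // num_experts * fc1_out * hidden_size
                             const std::vector<float>& fc1_bias,     // num_experts * fc1_out
                             const std::vector<float>& token,        // hidden_size
                             int64_t expert, int64_t fc1_out, int64_t hidden_size) {
  std::vector<float> y(static_cast<size_t>(fc1_out));
  const float* w = fc1_weights.data() + expert * fc1_out * hidden_size;
  const float* b = fc1_bias.data() + expert * fc1_out;
  for (int64_t i = 0; i < fc1_out; ++i) {
    float acc = b[i];
    for (int64_t j = 0; j < hidden_size; ++j) {
      acc += w[i * hidden_size + j] * token[static_cast<size_t>(j)];
    }
    y[static_cast<size_t>(i)] = acc;
  }
  return y;
}
```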
@@ -4522,6 +4541,10 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes

<dl>
<dt><tt>activation_alpha</tt> : float</dt>
<dd>Alpha parameter used in activation function.</dd>
<dt><tt>activation_beta</tt> : float</dt>
<dd>Beta parameter used in activation function.</dd>
<dt><tt>activation_type</tt> : string</dt>
<dd>Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu</dd>
<dt><tt>expert_weight_bits</tt> : int</dt>
@@ -4530,6 +4553,10 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
<dt><tt>swiglu_fusion</tt> : int</dt>
<dd>0: not fused; 1: fused and interleaved; 2: fused and not interleaved.</dd>
<dt><tt>swiglu_limit</tt> : float</dt>
<dd>The limit used to clamp inputs in SwiGLU. It is treated as infinite when not provided.</dd>
<dt><tt>use_sparse_mixer</tt> : int</dt>
<dd>Whether to use sparse mixer</dd>
</dl>
@@ -4542,19 +4569,19 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).</dd>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, inter_size, hidden_size / 2) for 4 bits. For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.</dd>
<dt><tt>fc1_scales</tt> : T2</dt>
<dd>2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
<dt><tt>fc2_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2) for 4 bits</dd>
<dt><tt>fc2_scales</tt> : T2</dt>
<dd>2D input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
<dd>3D optional input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
<dt><tt>fc3_scales</tt> (optional) : T2</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
@@ -4575,8 +4602,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>T1</tt> : tensor(uint8)</dt>
<dd>Constrain weights type to uint8 tensors.</dd>
<dt><tt>T2</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain scales type to float or float16 tensors.</dd>
<dt><tt>T2</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain scales type to float, float16 or bfloat16 tensors.</dd>
</dl>
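
For the uint8 weight tensors constrained above, the sketch below shows one way a 4-bit packed weight row could be expanded using its per-output-feature scale. The low-nibble-first packing and the fixed zero-point of 8 are assumptions made only for illustration; the actual packing is defined by the QMoE kernel, not by this doc.

```cpp
#include <cstdint>
#include <vector>

// Hedged illustration only: expands one 4-bit packed weight row into floats.
// Assumptions (not taken from the schema): two 4-bit values per uint8 with the
// low nibble first, and a fixed zero-point of 8 so that w = (q - 8) * scale.
// 'scale' is the per-output-feature entry from the corresponding *_scales input.
std::vector<float> DequantizeRow4Bit(const std::vector<uint8_t>& packed_row,  // in_features / 2 bytes
                                     float scale) {
  std::vector<float> row;
  row.reserve(packed_row.size() * 2);
  for (uint8_t byte : packed_row) {
    const int lo = byte & 0x0F;         // first weight in the pair (assumed order)
    const int hi = (byte >> 4) & 0x0F;  // second weight in the pair (assumed order)
    row.push_back(static_cast<float>(lo - 8) * scale);
    row.push_back(static_cast<float>(hi - 8) * scale);
  }
  return row;
}
```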


190 changes: 1 addition & 189 deletions onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
@@ -6,23 +6,11 @@
#include "core/common/common.h"
#include "core/framework/tensor_shape.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/quantization/moe_helper.h"

namespace onnxruntime {
namespace contrib {

enum class MoEParallelType {
None = 0,
EP = 1,
TP = 2,
EPAndTP = 3,
};

enum class MoEQuantType {
None = 0,
UINT4 = 1,
UINT8 = 2,
};

enum class ActivationType {
Relu = 0,
Gelu = 1,
@@ -31,183 +19,7 @@ enum class ActivationType {
SwiGLU = 4,
};

struct MoEParameters {
MoEParameters() {}
explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {}
int64_t num_rows;
int64_t num_experts;
int64_t local_num_experts;
int64_t hidden_size;
int64_t inter_size;

MoEParallelType parallel_type;
int64_t tensor_shards{1};
};

class MoEBaseCPU {
public:
Status CheckInputs(MoEParameters& parameters, MoEQuantType& quant_type, const Tensor* input,
const Tensor* router_probs, const Tensor* fc1_experts_weights,
const Tensor* fc1_experts_bias_optional, const Tensor* fc2_experts_weights,
const Tensor* fc2_experts_bias_optional, const Tensor* fc3_experts_weights_optional,
const Tensor* fc3_experts_bias_optional) const {
ORT_UNUSED_PARAMETER(fc3_experts_bias_optional);
const auto& input_dims = input->Shape().GetDims();
const auto& router_probs_dims = router_probs->Shape().GetDims();
const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();

int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1];
int64_t hidden_size = input_dims[input_dims.size() - 1];
int64_t local_num_experts = fc1_experts_weights_dims[0];
int64_t num_experts = router_probs_dims[1];
int64_t inter_size = fc2_experts_weights_dims[1];

if (fc1_experts_weights_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ",
fc1_experts_weights_dims.size());
}
if (fc2_experts_weights_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ",
fc2_experts_weights_dims.size());
}
if (fc1_experts_weights_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc1_experts_weights_dims[1] must be equal to hidden_size, got ",
fc1_experts_weights_dims[1], " and ", hidden_size);
}
if (fc2_experts_weights_dims[1] != inter_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc2_experts_weights_dims[1] must be equal to inter_size, got ",
fc2_experts_weights_dims[1], " and ", inter_size);
}

const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1;
const int64_t act = activation_type_ == ActivationType::SwiGLU ? 2 : 1; // SwiGLU requires 2x weights for gate

if (fc1_experts_weights_dims[2] != act * inter_size / coe) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc1_experts_weights_dims[2] is ", fc1_experts_weights_dims[2],
" expected ", act * inter_size / coe);
}
if (fc2_experts_weights_dims[2] != hidden_size / coe) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc2_experts_weights_dims[2] must be equal to hidden_size, got ",
fc2_experts_weights_dims[2], " and ", hidden_size);
}

if (router_probs_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ",
router_probs_dims.size());
}
if (router_probs_dims[0] != num_rows) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ",
router_probs_dims[0], " and ", num_rows);
}

// Optional bias validation
if (fc1_experts_bias_optional != nullptr) {
const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims();
if (fc1_experts_bias_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ",
fc1_experts_bias_dims.size());
}
if (fc1_experts_bias_dims[0] != local_num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims[0] must be equal to local_num_experts, got ",
fc1_experts_bias_dims[0], " and ", local_num_experts);
}
int64_t expected_fc1_bias_dim1 = activation_type_ == ActivationType::SwiGLU ? 2 * inter_size : inter_size;
if (fc1_experts_bias_dims[1] != expected_fc1_bias_dim1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims[1] must be equal to ", expected_fc1_bias_dim1, ", got ",
fc1_experts_bias_dims[1], " and inter_size=", inter_size, ". Activation type: ", static_cast<int>(activation_type_));
}
}
if (fc2_experts_bias_optional != nullptr) {
const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims();
if (fc2_experts_bias_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ",
fc2_experts_bias_dims.size());
}
if (fc2_experts_bias_dims[0] != local_num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims[0] must be equal to local_num_experts, got ",
fc2_experts_bias_dims[0], " and ", local_num_experts);
}
if (fc2_experts_bias_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims[1] must be equal to hidden_size, got ",
fc2_experts_bias_dims[1], " and ", hidden_size);
}
}

// FC3 validation - match CUDA FasterTransformer behavior
if (activation_type_ == ActivationType::SwiGLU && fc3_experts_weights_optional != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"SwiGLU activation is not supported with fc3.");
}
if (fc3_experts_weights_optional != nullptr && activation_type_ != ActivationType::SwiGLU) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"FC3 gating is not yet implemented on CPU.");
}

// Set output parameters
parameters.num_rows = num_rows;
parameters.num_experts = num_experts;
parameters.local_num_experts = local_num_experts;
parameters.hidden_size = hidden_size;
parameters.inter_size = inter_size;
parameters.parallel_type = MoEParallelType::None;

return Status::OK();
}

Status CheckInputScales(const Tensor* fc1_experts_scales, const Tensor* fc2_experts_scales, const Tensor* fc3_experts_scales_optional,
int64_t num_experts, int64_t hidden_size, int64_t inter_size) const {
if (fc1_experts_scales == nullptr || fc2_experts_scales == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales and fc2_experts_scales cannot be null for quantized MoE");
}

// SwiGLU should not use separate FC3 scales - weights are concatenated in FC1
if (activation_type_ == ActivationType::SwiGLU && fc3_experts_scales_optional != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"SwiGLU activation is not supported with fc3.");
}
if (activation_type_ != ActivationType::SwiGLU && fc3_experts_scales_optional != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"FC3 gating is not yet implemented on CPU.");
}

const auto& fc1_experts_scales_dims = fc1_experts_scales->Shape().GetDims();
const auto& fc2_experts_scales_dims = fc2_experts_scales->Shape().GetDims();

if (fc1_experts_scales_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales must be 2D, got ",
fc1_experts_scales_dims.size());
}
if (fc2_experts_scales_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ",
fc2_experts_scales_dims.size());
}
if (fc1_experts_scales_dims[0] != num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ",
fc1_experts_scales_dims[0], " and ", num_experts);
}

const int64_t act = activation_type_ == ActivationType::SwiGLU ? 2 : 1; // SwiGLU requires 2x scales
if (fc1_experts_scales_dims[1] != act * inter_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] is ", fc1_experts_scales_dims[1],
" expected ", act * inter_size);
}
if (fc2_experts_scales_dims[0] != num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[0] must be equal to num_experts, got ",
fc2_experts_scales_dims[0], " and ", num_experts);
}
if (fc2_experts_scales_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[1] must be equal to hidden_size, got ",
fc2_experts_scales_dims[1], " and ", hidden_size);
}

return Status::OK();
}

protected:
MoEBaseCPU(const OpKernelInfo& op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK());