Commit 8654241: Update MoE and qMoE spec (#25619)
### Weight Shape Update

Make sure the shape reflects the actual memory layout. The weight is stored in column-major order.

### Add support for SwiGLU activation attributes

Add spec for the new activation type SwiGLU (Swish-Gated Linear Unit) by introducing a few new attributes. For reference, see the [Triton kernel implementation](https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/swiglu.py).

#### New Attributes for SwiGLU

* **`swiglu_fusion`**:
  * `0`: Not fused: two separate GEMMs (FC1 and FC3).
  * `1`: Fused GEMMs using **interleaved** format (`g` and `l` are interleaved per row).
  * `2`: Fused GEMMs using **non-interleaved** (concatenated) format.
* **`swiglu_limit`**: Clamp threshold applied to `g` and `l`.
* **`activation_alpha`**: Scalar multiplier applied to `g` before the sigmoid.
* **`activation_beta`**: Added to `l` before the final output computation.

---

### SwiGLU Activation Function

The SwiGLU function is defined as:

```
g = xW + b
l = xV + c
G = min(g, limit)
L = max(min(l, limit), -limit)
swiglu = G * sigmoid(alpha * G) * (L + beta)
```

* `x`: Input
* `W`, `V`: Weight matrices
* `b`, `c`: Bias vectors
* `alpha`, `beta`, `limit`: Float constants

---

### Fusion Behavior

* When `swiglu_fusion = 0`:
  * Two GEMMs are computed independently.
  * FC1 computes `g`; FC3 computes `l`.
* When `swiglu_fusion = 1`:
  * `g` and `l` are computed in a **single fused GEMM** (FC1).
  * Output is **interleaved** per row as `gate, linear, gate, linear, ...`.
* When `swiglu_fusion = 2`:
  * `g` and `l` are computed in a single GEMM (FC1).
  * Output is **concatenated** per row: `[g | l]`.

### Implement swiglu_limit for CUDA

Update the CUDA kernel to use the default swiglu limit, and update test_moe_cuda.py so the reference implementation uses the same logic.

### Remaining Work

The main purpose of this PR is to update the spec rather than to implement it. Note that the MoE/qMoE ops and tests still use hard-coded parameters; they will be changed later to read these attributes. Column-wise symmetric quantization is used for qMoE. More quantization details will be added when block-wise quantization support is added.
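As a sanity check on the semantics above, here is a minimal NumPy sketch of the SwiGLU reference computation for the two fused layouts (the function name `swiglu_ref` and the parameter values are illustrative, not part of the ORT code):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swiglu_ref(fc1_out, swiglu_fusion, alpha, beta, limit):
    """SwiGLU over a fused FC1 output of shape (num_rows, 2 * inter_size).

    swiglu_fusion=1: gate/linear values interleaved per row (g, l, g, l, ...).
    swiglu_fusion=2: gate/linear values concatenated per row ([g | l]).
    """
    if swiglu_fusion == 1:
        g, l = fc1_out[:, 0::2], fc1_out[:, 1::2]
    elif swiglu_fusion == 2:
        half = fc1_out.shape[1] // 2
        g, l = fc1_out[:, :half], fc1_out[:, half:]
    else:
        raise ValueError("swiglu_fusion=0 computes g and l in separate GEMMs (FC1 and FC3)")
    G = np.minimum(g, limit)        # gate is clamped from above only
    L = np.clip(l, -limit, limit)   # linear part is clamped on both sides
    return G * sigmoid(alpha * G) * (L + beta)
```

For the same logical `g` and `l`, fusion modes 1 and 2 produce identical results; only the column layout of the fused FC1 output differs.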
Parent: d83904b

14 files changed (+337, -488 lines)

**docs/ContribOperators.md** (35 additions, 8 deletions)

```diff
@@ -3079,6 +3079,17 @@ This version of the operator has been available since version 1 of the 'com.micr
 Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1,
 GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
 usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral).
+
+The SwiGLU (Swish-Gated Linear Unit) activation function is like:
+  g = xW + b
+  l = xV + c
+  G = clamp(g, max=limit)
+  L = clamp(l, min=-limit, max=limit)
+  swiglu = G * sigmoid(alpha * G) * (L + beta)
+where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters.
+When swiglu_fusion=0, two GEMMs are not fused, and they are FC1 and FC3 in the inputs.
+When swiglu_fusion=1, two GEMMs are fused so that g and l are computed in a single GEMM (FC1), and g and l are interleaved on each row of size 2 * inter_size.
+When swiglu_fusion=2, two GEMMs are fused, and g and l are concatenated on each row.
 
 
 #### Version
@@ -3088,12 +3099,20 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Attributes
 
 <dl>
+<dt><tt>activation_alpha</tt> : float</dt>
+<dd>Alpha parameter used in activation function.</dd>
+<dt><tt>activation_beta</tt> : float</dt>
+<dd>Beta parameter used in activation function.</dd>
 <dt><tt>activation_type</tt> : string</dt>
 <dd>Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu</dd>
 <dt><tt>k</tt> : int</dt>
 <dd>Number of top experts to select from expert pool</dd>
 <dt><tt>normalize_routing_weights</tt> : int</dt>
 <dd>Whether to normalize routing weights</dd>
+<dt><tt>swiglu_fusion</tt> : int</dt>
+<dd>0: not fused, 1: fused and interleaved. 2: fused and not interleaved.</dd>
+<dt><tt>swiglu_limit</tt> : float</dt>
+<dd>The limit used to clamp in SwiGLU. No clamp when limit is not provided.</dd>
 <dt><tt>use_sparse_mixer</tt> : int</dt>
 <dd>Whether to use sparse mixer</dd>
 </dl>
@@ -3106,15 +3125,15 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>router_probs</tt> : T</dt>
 <dd>2D input tensor with shape (num_rows, num_experts)</dd>
 <dt><tt>fc1_experts_weights</tt> : T</dt>
-<dd>3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu</dd>
+<dd>3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu</dd>
 <dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
 <dt><tt>fc2_experts_weights</tt> : T</dt>
-<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
+<dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
 <dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
 <dt><tt>fc3_experts_weights</tt> (optional) : T</dt>
-<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size)</dd>
+<dd>3D optional input tensor with shape (num_experts, inter_size, hidden_size)</dd>
 <dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
 </dl>
@@ -4522,6 +4541,10 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Attributes
 
 <dl>
+<dt><tt>activation_alpha</tt> : float</dt>
+<dd>Alpha parameter used in activation function.</dd>
+<dt><tt>activation_beta</tt> : float</dt>
+<dd>Beta parameter used in activation function.</dd>
 <dt><tt>activation_type</tt> : string</dt>
 <dd>Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu</dd>
 <dt><tt>expert_weight_bits</tt> : int</dt>
@@ -4530,6 +4553,10 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Number of top experts to select from expert pool</dd>
 <dt><tt>normalize_routing_weights</tt> : int</dt>
 <dd>Whether to normalize routing weights</dd>
+<dt><tt>swiglu_fusion</tt> : int</dt>
+<dd>0: not fused, 1: fused and interleaved. 2: fused and not interleaved.</dd>
+<dt><tt>swiglu_limit</tt> : float</dt>
+<dd>The limit used to clamp inputs in SwiGLU. It is infinite when limit is not provided.</dd>
 <dt><tt>use_sparse_mixer</tt> : int</dt>
 <dd>Whether to use sparse mixer</dd>
 </dl>
@@ -4542,19 +4569,19 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>router_probs</tt> : T</dt>
 <dd>2D input tensor with shape (num_rows, num_experts)</dd>
 <dt><tt>fc1_experts_weights</tt> : T1</dt>
-<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).</dd>
+<dd>3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, inter_size, hidden_size / 2) for 4 bits. For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.</dd>
 <dt><tt>fc1_scales</tt> : T2</dt>
 <dd>2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
 <dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
 <dt><tt>fc2_experts_weights</tt> : T1</dt>
-<dd>3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
+<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2) for 4 bits</dd>
 <dt><tt>fc2_scales</tt> : T2</dt>
 <dd>2D input tensor with shape (num_experts, hidden_size)</dd>
 <dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
 <dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
-<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
+<dd>3D optional input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
 <dt><tt>fc3_scales</tt> (optional) : T2</dt>
 <dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
 <dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
@@ -4575,8 +4602,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Constrain input and output types to float tensors.</dd>
 <dt><tt>T1</tt> : tensor(uint8)</dt>
 <dd>Constrain weights type to uint8 tensors.</dd>
-<dt><tt>T2</tt> : tensor(float), tensor(float16)</dt>
-<dd>Constrain scales type to float or float16 tensors.</dd>
+<dt><tt>T2</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
+<dd>Constrain scales type to float tensors.</dd>
 </dl>
```
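The flipped weight shapes in the diff above follow from storing the weights in column-major order: the stored trailing axis is the GEMM's reduction axis, so each expert's GEMM multiplies by the transpose of the stored 2D slice. A NumPy shape sketch with made-up sizes (illustrative only, not ORT code):

```python
import numpy as np

num_experts, hidden_size, inter_size, num_rows = 4, 8, 16, 3

# Per the updated spec, FC1 weights are (num_experts, 2 * inter_size, hidden_size)
# for SwiGLU, and FC2 weights are (num_experts, hidden_size, inter_size).
fc1_w = np.random.randn(num_experts, 2 * inter_size, hidden_size).astype(np.float32)
fc2_w = np.random.randn(num_experts, hidden_size, inter_size).astype(np.float32)

x = np.random.randn(num_rows, hidden_size).astype(np.float32)
e = 0  # tokens routed to expert 0, for illustration

# The trailing axis of each stored weight slice is the reduction axis,
# so the GEMM multiplies by the transpose of the stored slice.
fc1_out = x @ fc1_w[e].T          # (num_rows, 2 * inter_size)
h = fc1_out[:, :inter_size]       # stand-in for the activation output (SwiGLU halves the width)
fc2_out = h @ fc2_w[e].T          # (num_rows, hidden_size)

assert fc1_out.shape == (num_rows, 2 * inter_size)
assert fc2_out.shape == (num_rows, hidden_size)
```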

**onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h** (1 addition, 189 deletions)

```diff
@@ -6,23 +6,11 @@
 #include "core/common/common.h"
 #include "core/framework/tensor_shape.h"
 #include "core/framework/op_kernel.h"
+#include "contrib_ops/cpu/quantization/moe_helper.h"
 
 namespace onnxruntime {
 namespace contrib {
 
-enum class MoEParallelType {
-  None = 0,
-  EP = 1,
-  TP = 2,
-  EPAndTP = 3,
-};
-
-enum class MoEQuantType {
-  None = 0,
-  UINT4 = 1,
-  UINT8 = 2,
-};
-
 enum class ActivationType {
   Relu = 0,
   Gelu = 1,
@@ -31,183 +19,7 @@ enum class ActivationType {
   SwiGLU = 4,
 };
 
-struct MoEParameters {
-  MoEParameters() {}
-  explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {}
-  int64_t num_rows;
-  int64_t num_experts;
-  int64_t local_num_experts;
-  int64_t hidden_size;
-  int64_t inter_size;
-
-  MoEParallelType parallel_type;
-  int64_t tensor_shards{1};
-};
-
 class MoEBaseCPU {
- public:
-  Status CheckInputs(MoEParameters& parameters, MoEQuantType& quant_type, const Tensor* input,
-                     const Tensor* router_probs, const Tensor* fc1_experts_weights,
-                     const Tensor* fc1_experts_bias_optional, const Tensor* fc2_experts_weights,
-                     const Tensor* fc2_experts_bias_optional, const Tensor* fc3_experts_weights_optional,
-                     const Tensor* fc3_experts_bias_optional) const {
-    ORT_UNUSED_PARAMETER(fc3_experts_bias_optional);
-    const auto& input_dims = input->Shape().GetDims();
-    const auto& router_probs_dims = router_probs->Shape().GetDims();
-    const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
-    const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();
-
-    int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1];
-    int64_t hidden_size = input_dims[input_dims.size() - 1];
-    int64_t local_num_experts = fc1_experts_weights_dims[0];
-    int64_t num_experts = router_probs_dims[1];
-    int64_t inter_size = fc2_experts_weights_dims[1];
-
-    if (fc1_experts_weights_dims.size() != 3) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ",
-                             fc1_experts_weights_dims.size());
-    }
-    if (fc2_experts_weights_dims.size() != 3) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ",
-                             fc2_experts_weights_dims.size());
-    }
-    if (fc1_experts_weights_dims[1] != hidden_size) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "fc1_experts_weights_dims[1] must be equal to hidden_size, got ",
-                             fc1_experts_weights_dims[1], " and ", hidden_size);
-    }
-    if (fc2_experts_weights_dims[1] != inter_size) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "fc2_experts_weights_dims[1] must be equal to inter_size, got ",
-                             fc2_experts_weights_dims[1], " and ", inter_size);
-    }
-
-    const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1;
-    const int64_t act = activation_type_ == ActivationType::SwiGLU ? 2 : 1;  // SwiGLU requires 2x weights for gate
-
-    if (fc1_experts_weights_dims[2] != act * inter_size / coe) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "fc1_experts_weights_dims[2] is ", fc1_experts_weights_dims[2],
-                             " expected ", act * inter_size / coe);
-    }
-    if (fc2_experts_weights_dims[2] != hidden_size / coe) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "fc2_experts_weights_dims[2] must be equal to hidden_size, got ",
-                             fc2_experts_weights_dims[2], " and ", hidden_size);
-    }
-
-    if (router_probs_dims.size() != 2) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ",
-                             router_probs_dims.size());
-    }
-    if (router_probs_dims[0] != num_rows) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ",
-                             router_probs_dims[0], " and ", num_rows);
-    }
-
-    // Optional bias validation
-    if (fc1_experts_bias_optional != nullptr) {
-      const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims();
-      if (fc1_experts_bias_dims.size() != 2) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ",
-                               fc1_experts_bias_dims.size());
-      }
-      if (fc1_experts_bias_dims[0] != local_num_experts) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims[0] must be equal to local_num_experts, got ",
-                               fc1_experts_bias_dims[0], " and ", local_num_experts);
-      }
-      int64_t expected_fc1_bias_dim1 = activation_type_ == ActivationType::SwiGLU ? 2 * inter_size : inter_size;
-      if (fc1_experts_bias_dims[1] != expected_fc1_bias_dim1) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims[1] must be equal to ", expected_fc1_bias_dim1, ", got ",
-                               fc1_experts_bias_dims[1], " and inter_size=", inter_size, ". Activation type: ", static_cast<int>(activation_type_));
-      }
-    }
-    if (fc2_experts_bias_optional != nullptr) {
-      const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims();
-      if (fc2_experts_bias_dims.size() != 2) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ",
-                               fc2_experts_bias_dims.size());
-      }
-      if (fc2_experts_bias_dims[0] != local_num_experts) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims[0] must be equal to local_num_experts, got ",
-                               fc2_experts_bias_dims[0], " and ", local_num_experts);
-      }
-      if (fc2_experts_bias_dims[1] != hidden_size) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims[1] must be equal to hidden_size, got ",
-                               fc2_experts_bias_dims[1], " and ", hidden_size);
-      }
-    }
-
-    // FC3 validation - match CUDA FasterTransformer behavior
-    if (activation_type_ == ActivationType::SwiGLU && fc3_experts_weights_optional != nullptr) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
-                             "SwiGLU activation is not supported with fc3.");
-    }
-    if (fc3_experts_weights_optional != nullptr && activation_type_ != ActivationType::SwiGLU) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
-                             "FC3 gating is not yet implemented on CPU.");
-    }
-
-    // Set output parameters
-    parameters.num_rows = num_rows;
-    parameters.num_experts = num_experts;
-    parameters.local_num_experts = local_num_experts;
-    parameters.hidden_size = hidden_size;
-    parameters.inter_size = inter_size;
-    parameters.parallel_type = MoEParallelType::None;
-
-    return Status::OK();
-  }
-
-  Status CheckInputScales(const Tensor* fc1_experts_scales, const Tensor* fc2_experts_scales, const Tensor* fc3_experts_scales_optional,
-                          int64_t num_experts, int64_t hidden_size, int64_t inter_size) const {
-    if (fc1_experts_scales == nullptr || fc2_experts_scales == nullptr) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales and fc2_experts_scales cannot be null for quantized MoE");
-    }
-
-    // SwiGLU should not use separate FC3 scales - weights are concatenated in FC1
-    if (activation_type_ == ActivationType::SwiGLU && fc3_experts_scales_optional != nullptr) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
-                             "SwiGLU activation is not supported with fc3.");
-    }
-    if (activation_type_ != ActivationType::SwiGLU && fc3_experts_scales_optional != nullptr) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
-                             "FC3 gating is not yet implemented on CPU.");
-    }
-
-    const auto& fc1_experts_scales_dims = fc1_experts_scales->Shape().GetDims();
-    const auto& fc2_experts_scales_dims = fc2_experts_scales->Shape().GetDims();
-
-    if (fc1_experts_scales_dims.size() != 2) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales must be 2D, got ",
-                             fc1_experts_scales_dims.size());
-    }
-    if (fc2_experts_scales_dims.size() != 2) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ",
-                             fc2_experts_scales_dims.size());
-    }
-    if (fc1_experts_scales_dims[0] != num_experts) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ",
-                             fc1_experts_scales_dims[0], " and ", num_experts);
-    }
-
-    const int64_t act = activation_type_ == ActivationType::SwiGLU ? 2 : 1;  // SwiGLU requires 2x scales
-    if (fc1_experts_scales_dims[1] != act * inter_size) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] is ", fc1_experts_scales_dims[1],
-                             " expected ", act * inter_size);
-    }
-    if (fc2_experts_scales_dims[0] != num_experts) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[0] must be equal to num_experts, got ",
-                             fc2_experts_scales_dims[0], " and ", num_experts);
-    }
-    if (fc2_experts_scales_dims[1] != hidden_size) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[1] must be equal to hidden_size, got ",
-                             fc2_experts_scales_dims[1], " and ", hidden_size);
-    }
-
-    return Status::OK();
-  }
-
  protected:
   MoEBaseCPU(const OpKernelInfo& op_kernel_info) {
     ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK());
```
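The `hidden_size / 2` and `inter_size / 2` shapes for 4-bit qMoE come from packing two uint4 values into one uint8 along the reduction axis. Below is an illustrative sketch of column-wise symmetric quantization and dequantization under that packing; the low-nibble-first order and the offset-by-8 convention are assumptions for illustration, not a statement of the exact ORT kernel layout:

```python
import numpy as np

def dequant_uint4(packed, scales):
    """Unpack uint4 pairs from uint8 and apply one symmetric scale per weight row.

    packed: (out_features, in_features // 2) uint8, low nibble first.
    scales: (out_features,) float, one scale per output feature (column-wise).
    Stored values are offset by 8 so the quantized range is symmetric around zero.
    """
    low = (packed & 0x0F).astype(np.int32) - 8
    high = (packed >> 4).astype(np.int32) - 8
    q = np.empty((packed.shape[0], packed.shape[1] * 2), dtype=np.int32)
    q[:, 0::2] = low
    q[:, 1::2] = high
    return q.astype(np.float32) * scales[:, None]

# Round-trip a tiny weight row: quantize symmetrically, pack, dequantize.
w = np.array([[0.5, -0.25, 0.75, -1.0]], dtype=np.float32)
scale = np.abs(w).max(axis=1) / 7.0  # symmetric: map max |w| to quantized value 7
q = np.clip(np.round(w / scale[:, None]), -8, 7).astype(np.int32) + 8
packed = (q[:, 0::2] | (q[:, 1::2] << 4)).astype(np.uint8)
w_hat = dequant_uint4(packed, scale)
assert np.allclose(w, w_hat, atol=float(scale.max()) / 2 + 1e-6)
```

The packed tensor has half the trailing extent of the logical weight, which is exactly the `/ 2` that appears in the 4-bit shapes of the spec.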
