-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Add support for QMoE in CPU #25558
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add support for QMoE in CPU #25558
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
3504f14
Initial QMoE CPU support
apsonawane 12aa6c3
Fix Lint error
apsonawane de5e7c5
Fix pipelines
apsonawane 2f9192e
Add SwiGLU support for CPU QMoE
apsonawane ddde845
Fix pipelines
apsonawane f55d780
Address comments
apsonawane a1a1f7c
Update contrib ops doc
kunal-vaishnavi 337d56a
Update emsdk
kunal-vaishnavi 0d16252
Revert "Update emsdk"
kunal-vaishnavi 67b4b1f
Address comments
apsonawane 0fcdc72
Address comments
apsonawane 9fdb2ff
Fix
apsonawane 6e60f09
Comments
apsonawane 2814dcd
Update to symmetric quantization
apsonawane 1958814
Fix build errors and update the tests
apsonawane 61ab80f
Add SwiGLU tests in python
apsonawane bd721db
Add SwiGLU clamping and fix docs pipeline
apsonawane ce1309f
Add back ep_graph tests
apsonawane 0729e44
Revert cuda test changes
apsonawane 271cba4
update doc
tianleiwu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,246 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/common/common.h" | ||
| #include "core/framework/tensor_shape.h" | ||
| #include "core/framework/op_kernel.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
|
|
||
| enum class MoEParallelType { | ||
| None = 0, | ||
| EP = 1, | ||
| TP = 2, | ||
| EPAndTP = 3, | ||
| }; | ||
|
|
||
| enum class MoEQuantType { | ||
| None = 0, | ||
| UINT4 = 1, | ||
| UINT8 = 2, | ||
| }; | ||
|
|
||
| enum class ActivationType { | ||
| Relu = 0, | ||
| Gelu = 1, | ||
| Silu = 2, | ||
| Identity = 3, | ||
| SwiGLU = 4, | ||
| }; | ||
|
|
||
| struct MoEParameters { | ||
| MoEParameters() {} | ||
| explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {} | ||
| int64_t num_rows; | ||
| int64_t num_experts; | ||
| int64_t local_num_experts; | ||
| int64_t hidden_size; | ||
| int64_t inter_size; | ||
|
|
||
| MoEParallelType parallel_type; | ||
| int64_t tensor_shards{1}; | ||
| }; | ||
|
|
||
| class MoEBaseCPU { | ||
| public: | ||
| Status CheckInputs(MoEParameters& parameters, MoEQuantType& quant_type, const Tensor* input, | ||
apsonawane marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| const Tensor* router_probs, const Tensor* fc1_experts_weights, | ||
| const Tensor* fc1_experts_bias_optional, const Tensor* fc2_experts_weights, | ||
| const Tensor* fc2_experts_bias_optional, const Tensor* fc3_experts_weights_optional, | ||
| const Tensor* fc3_experts_bias_optional) const { | ||
| ORT_UNUSED_PARAMETER(fc3_experts_bias_optional); | ||
| const auto& input_dims = input->Shape().GetDims(); | ||
| const auto& router_probs_dims = router_probs->Shape().GetDims(); | ||
| const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); | ||
| const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); | ||
|
|
||
| int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; | ||
| int64_t hidden_size = input_dims[input_dims.size() - 1]; | ||
| int64_t local_num_experts = fc1_experts_weights_dims[0]; | ||
| int64_t num_experts = router_probs_dims[1]; | ||
| int64_t inter_size = fc2_experts_weights_dims[1]; | ||
|
|
||
| if (fc1_experts_weights_dims.size() != 3) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", | ||
| fc1_experts_weights_dims.size()); | ||
| } | ||
| if (fc2_experts_weights_dims.size() != 3) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ", | ||
| fc2_experts_weights_dims.size()); | ||
| } | ||
| if (fc1_experts_weights_dims[1] != hidden_size) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "fc1_experts_weights_dims[1] must be equal to hidden_size, got ", | ||
| fc1_experts_weights_dims[1], " and ", hidden_size); | ||
| } | ||
| if (fc2_experts_weights_dims[1] != inter_size) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "fc2_experts_weights_dims[1] must be equal to inter_size, got ", | ||
| fc2_experts_weights_dims[1], " and ", inter_size); | ||
| } | ||
|
|
||
| const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1; | ||
| const int64_t act = activation_type_ == ActivationType::SwiGLU ? 2 : 1; // SwiGLU requires 2x weights for gate | ||
|
|
||
| if (fc1_experts_weights_dims[2] != act * inter_size / coe) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "fc1_experts_weights_dims[2] is ", fc1_experts_weights_dims[2], | ||
| " expected ", act * inter_size / coe); | ||
| } | ||
| if (fc2_experts_weights_dims[2] != hidden_size / coe) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", | ||
| fc2_experts_weights_dims[2], " and ", hidden_size); | ||
| } | ||
|
|
||
| if (router_probs_dims.size() != 2) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", | ||
| router_probs_dims.size()); | ||
| } | ||
| if (router_probs_dims[0] != num_rows) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ", | ||
| router_probs_dims[0], " and ", num_rows); | ||
| } | ||
|
|
||
| // Optional bias validation | ||
| if (fc1_experts_bias_optional != nullptr) { | ||
| const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims(); | ||
| if (fc1_experts_bias_dims.size() != 2) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ", | ||
| fc1_experts_bias_dims.size()); | ||
| } | ||
| if (fc1_experts_bias_dims[0] != local_num_experts) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims[0] must be equal to local_num_experts, got ", | ||
| fc1_experts_bias_dims[0], " and ", local_num_experts); | ||
| } | ||
| int64_t expected_fc1_bias_dim1 = activation_type_ == ActivationType::SwiGLU ? 2 * inter_size : inter_size; | ||
| if (fc1_experts_bias_dims[1] != expected_fc1_bias_dim1) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims[1] must be equal to ", expected_fc1_bias_dim1, ", got ", | ||
| fc1_experts_bias_dims[1], " and inter_size=", inter_size, ". Activation type: ", static_cast<int>(activation_type_)); | ||
| } | ||
| } | ||
| if (fc2_experts_bias_optional != nullptr) { | ||
| const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims(); | ||
| if (fc2_experts_bias_dims.size() != 2) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ", | ||
| fc2_experts_bias_dims.size()); | ||
| } | ||
| if (fc2_experts_bias_dims[0] != local_num_experts) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims[0] must be equal to local_num_experts, got ", | ||
| fc2_experts_bias_dims[0], " and ", local_num_experts); | ||
| } | ||
| if (fc2_experts_bias_dims[1] != hidden_size) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", | ||
| fc2_experts_bias_dims[1], " and ", hidden_size); | ||
| } | ||
| } | ||
|
|
||
| // FC3 validation - match CUDA FasterTransformer behavior | ||
| if (activation_type_ == ActivationType::SwiGLU && fc3_experts_weights_optional != nullptr) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, | ||
| "SwiGLU activation is not supported with fc3."); | ||
| } | ||
| if (fc3_experts_weights_optional != nullptr && activation_type_ != ActivationType::SwiGLU) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, | ||
| "FC3 gating is not yet implemented on CPU."); | ||
| } | ||
|
|
||
| // Set output parameters | ||
| parameters.num_rows = num_rows; | ||
| parameters.num_experts = num_experts; | ||
| parameters.local_num_experts = local_num_experts; | ||
| parameters.hidden_size = hidden_size; | ||
| parameters.inter_size = inter_size; | ||
| parameters.parallel_type = MoEParallelType::None; | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
| Status CheckInputScales(const Tensor* fc1_experts_scales, const Tensor* fc2_experts_scales, const Tensor* fc3_experts_scales_optional, | ||
| int64_t num_experts, int64_t hidden_size, int64_t inter_size) const { | ||
| if (fc1_experts_scales == nullptr || fc2_experts_scales == nullptr) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales and fc2_experts_scales cannot be null for quantized MoE"); | ||
| } | ||
|
|
||
| // SwiGLU should not use separate FC3 scales - weights are concatenated in FC1 | ||
| if (activation_type_ == ActivationType::SwiGLU && fc3_experts_scales_optional != nullptr) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, | ||
| "SwiGLU activation is not supported with fc3."); | ||
| } | ||
| if (activation_type_ != ActivationType::SwiGLU && fc3_experts_scales_optional != nullptr) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, | ||
| "FC3 gating is not yet implemented on CPU."); | ||
| } | ||
|
|
||
| const auto& fc1_experts_scales_dims = fc1_experts_scales->Shape().GetDims(); | ||
| const auto& fc2_experts_scales_dims = fc2_experts_scales->Shape().GetDims(); | ||
|
|
||
| if (fc1_experts_scales_dims.size() != 2) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales must be 2D, got ", | ||
| fc1_experts_scales_dims.size()); | ||
| } | ||
| if (fc2_experts_scales_dims.size() != 2) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ", | ||
| fc2_experts_scales_dims.size()); | ||
| } | ||
| if (fc1_experts_scales_dims[0] != num_experts) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ", | ||
| fc1_experts_scales_dims[0], " and ", num_experts); | ||
| } | ||
|
|
||
| const int64_t act = activation_type_ == ActivationType::SwiGLU ? 2 : 1; // SwiGLU requires 2x scales | ||
| if (fc1_experts_scales_dims[1] != act * inter_size) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] is ", fc1_experts_scales_dims[1], | ||
| " expected ", act * inter_size); | ||
| } | ||
| if (fc2_experts_scales_dims[0] != num_experts) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[0] must be equal to num_experts, got ", | ||
| fc2_experts_scales_dims[0], " and ", num_experts); | ||
| } | ||
| if (fc2_experts_scales_dims[1] != hidden_size) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[1] must be equal to hidden_size, got ", | ||
| fc2_experts_scales_dims[1], " and ", hidden_size); | ||
| } | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
| protected: | ||
| MoEBaseCPU(const OpKernelInfo& op_kernel_info) { | ||
|
Check warning on line 212 in onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
|
||
| ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK()); | ||
|
|
||
| std::string activation_type_str; | ||
| ORT_ENFORCE(op_kernel_info.GetAttr<std::string>("activation_type", &activation_type_str).IsOK()); | ||
|
Check warning on line 216 in onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
|
||
| if (activation_type_str == "relu") { | ||
| activation_type_ = ActivationType::Relu; | ||
| } else if (activation_type_str == "gelu") { | ||
| activation_type_ = ActivationType::Gelu; | ||
| } else if (activation_type_str == "silu") { | ||
| activation_type_ = ActivationType::Silu; | ||
| } else if (activation_type_str == "identity") { | ||
| activation_type_ = ActivationType::Identity; | ||
| } else if (activation_type_str == "swiglu") { | ||
| activation_type_ = ActivationType::SwiGLU; | ||
| } else { | ||
| ORT_THROW("Unsupported MoE activation type: ", activation_type_str); | ||
| } | ||
|
|
||
| normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault<int64_t>("normalize_routing_weights", 0) == 1; | ||
|
|
||
| use_sparse_mixer_ = op_kernel_info.GetAttrOrDefault<int64_t>("use_sparse_mixer", 0) == 1; | ||
| if (use_sparse_mixer_) { | ||
| ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2"); | ||
| } | ||
| } | ||
|
|
||
| bool normalize_routing_weights_; | ||
| bool use_sparse_mixer_; | ||
| int64_t k_; | ||
| ActivationType activation_type_; | ||
| }; | ||
|
|
||
| } // namespace contrib | ||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "contrib_ops/cpu/moe/moe_utils.h" | ||
| #include <cmath> | ||
| #include <algorithm> | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
|
|
||
| float ApplyActivation(float x, ActivationType activation_type) { | ||
| switch (activation_type) { | ||
| case ActivationType::Relu: | ||
| return std::max(0.0f, x); | ||
| case ActivationType::Gelu: | ||
| return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x))); | ||
| case ActivationType::Silu: | ||
| return x * (1.0f / (1.0f + std::exp(-x))); | ||
| case ActivationType::Identity: | ||
| return x; | ||
| case ActivationType::SwiGLU: | ||
| // SwiGLU: This is handled specially as it requires gating, not applied here | ||
| return x; | ||
| default: | ||
| return x; // Default to identity | ||
| } | ||
| } | ||
|
|
||
| // Helper method for applying SwiGLU activation with different memory layouts | ||
| void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format) { | ||
| constexpr float swiglu_alpha = 1.702f; | ||
| constexpr float clamp_limit = 7.0f; // Clamping limit as specified | ||
|
|
||
| if (is_interleaved_format) { | ||
| // For interleaved format [linear, gate, linear, gate, ...], process directly | ||
| // Make a temporary copy of each pair of values before modifying them | ||
| for (int64_t i = 0; i < inter_size; ++i) { | ||
| const size_t idx = static_cast<size_t>(i); | ||
| const size_t linear_idx = 2 * idx; | ||
| const size_t gate_idx = linear_idx + 1; | ||
|
|
||
| // Store original values | ||
| float linear_val = data[linear_idx]; // Interleaved: even index | ||
| float gate_val = data[gate_idx]; // Interleaved: odd index | ||
|
|
||
| // Apply clamping to the values | ||
| if (gate_val > clamp_limit) gate_val = clamp_limit; // Clamp gate max only | ||
| if (linear_val > clamp_limit) linear_val = clamp_limit; // Clamp linear min/max | ||
| if (linear_val < -clamp_limit) linear_val = -clamp_limit; | ||
|
|
||
| // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1) | ||
| float sigmoid_arg = swiglu_alpha * gate_val; | ||
| float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg)); | ||
| float swish_out = gate_val * sigmoid_out; | ||
| float result = swish_out * (linear_val + 1.0f); | ||
|
|
||
| // Store result in first element (linear position) | ||
| data[idx] = result; | ||
| } | ||
| } else { | ||
| // For chunked layout [linear..., gate...], handle separately | ||
| // Need to work with original data in-place | ||
| // First, store all the gate computations since they depend on original gate values | ||
| std::vector<float> computed_gates(static_cast<size_t>(inter_size)); | ||
|
Check warning on line 64 in onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
|
||
|
|
||
| for (int64_t i = 0; i < inter_size; ++i) { | ||
| const size_t idx = static_cast<size_t>(i); | ||
| float gate_val = data[idx + static_cast<size_t>(inter_size)]; | ||
|
|
||
| // Apply clamping to the gate value (max only) | ||
| if (gate_val > clamp_limit) gate_val = clamp_limit; | ||
|
|
||
| // Compute the gate part of SwiGLU | ||
| float sigmoid_arg = swiglu_alpha * gate_val; | ||
| float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg)); | ||
| computed_gates[idx] = gate_val * sigmoid_out; | ||
| } | ||
|
|
||
| // Now apply the full activation with the precomputed gate values | ||
| for (int64_t i = 0; i < inter_size; ++i) { | ||
| const size_t idx = static_cast<size_t>(i); | ||
| float linear_val = data[idx]; | ||
|
|
||
| // Apply clamping to the linear value (min/max) | ||
| if (linear_val > clamp_limit) linear_val = clamp_limit; | ||
| if (linear_val < -clamp_limit) linear_val = -clamp_limit; | ||
|
|
||
| data[idx] = computed_gates[idx] * (linear_val + 1.0f); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| } // namespace contrib | ||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
| #include <cstdint> | ||
| #include "contrib_ops/cpu/moe/moe_base_cpu.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
|
|
||
| float ApplyActivation(float x, ActivationType activation_type); | ||
| void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format); | ||
|
|
||
| } // namespace contrib | ||
| } // namespace onnxruntime |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.