1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -547,6 +547,7 @@ Do not modify directly.*
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(float16), tensor(uint8)<br/> **T4** = tensor(int32)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -108,6 +108,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MoE);
// ******** End: Quantization ******************* //

#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -275,6 +276,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MoE)>,
};

for (auto& function_table_entry : function_table) {
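For context on the registration pattern in this file: each kernel is forward-declared via the `ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME` macro, then a `BuildKernelCreateInfo` entry is added to a function table that the loop above walks to register kernel factories. A generic toy sketch of that table-plus-loop idea (not ORT's actual macros, types, or signatures):

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Toy registry illustrating the function-table pattern; all names here are
// made up. ORT's real KernelRegistry additionally keys on domain, opset
// version, type constraints, and execution provider.
struct Kernel {
  virtual ~Kernel() = default;
};

template <typename T>
struct MoEKernel : Kernel {};  // stand-in for the typed MoE kernel class

using Factory = std::function<std::unique_ptr<Kernel>()>;

int main() {
  std::unordered_map<std::string, Factory> registry;
  // Analogue of the BuildKernelCreateInfo entries plus the registration loop:
  const std::pair<const char*, Factory> function_table[] = {
      {"QMoE<float>", [] { return std::make_unique<MoEKernel<float>>(); }},
      {"MoE<float>", [] { return std::make_unique<MoEKernel<float>>(); }},
  };
  for (const auto& entry : function_table) {
    registry.emplace(entry.first, entry.second);
  }
}
```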
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
@@ -6,7 +6,7 @@
#include "core/common/common.h"
#include "core/framework/tensor_shape.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/moe/moe_helper.h"
#include "moe_helper.h"

[cpplint] onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h:9: Include the directory when naming header files [build/include_subdir] [4] (GitHub Actions / Optional Lint C++)
#include <limits>

namespace onnxruntime {
605 changes: 605 additions & 0 deletions onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc

Large diffs are not rendered by default.

53 changes: 53 additions & 0 deletions onnxruntime/contrib_ops/cpu/moe/moe_cpu.h
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/moe/moe_base_cpu.h"

namespace onnxruntime {
namespace contrib {

template <typename T>
class MoE final : public OpKernel, public MoEBaseCPU {
public:
explicit MoE(const OpKernelInfo& op_kernel_info);
Status Compute(OpKernelContext* context) const override;

private:
Status ComputeMoE(const OpKernelContext* context,
const Tensor* input,
const Tensor* router_probs,
const Tensor* fc1_experts_weights,
const Tensor* fc1_experts_bias,
const Tensor* fc2_experts_weights,
const Tensor* fc2_experts_bias,
Tensor* output) const;

Status ProcessExpertBatch(const T* input_tokens,
const int64_t* token_expert_ids,
const float* token_weights,
int64_t num_tokens,
int64_t expert_id,
const T* fc1_weights,
const T* fc1_bias,
const T* fc2_weights,
const T* fc2_bias,
T* output_buffer,
int64_t hidden_size,
int64_t inter_size,
T* fc1_output_buffer,
T* activation_output_buffer) const;

Status ComputeGEMM(const T* A, const T* B, T* C,
int64_t M, int64_t K, int64_t N,
bool transpose_B = false) const;

void ApplyActivationVectorized(T* data, int64_t size) const;
void ApplySwiGLUVectorized(const T* input, T* output, int64_t size) const;
};

} // namespace contrib
} // namespace onnxruntime
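Since the 605-line moe_cpu.cc is not rendered above, here is a minimal single-token, top-1 reference sketch of the flow the header implies: route on the router logits, then FC1 GEMM → activation → FC2 GEMM, scaled by the routing weight. Weight layouts follow the shapes used by the test at the bottom of this diff (fc1 row-major [hidden, inter], fc2 row-major [inter, hidden]). The real kernel instead batches tokens per expert (`ProcessExpertBatch`) and uses optimized GEMMs; none of the names below are from the PR.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical reference for one token routed to its top-1 expert.
std::vector<float> MoeTop1Reference(const std::vector<float>& x,             // [hidden]
                                    const std::vector<float>& router_logits,  // [num_experts]
                                    const std::vector<std::vector<float>>& fc1,
                                    const std::vector<std::vector<float>>& fc2,
                                    int64_t hidden, int64_t inter) {
  // 1. Top-1 routing: the expert with the largest logit; with k == 1 and
  //    normalize_routing_weights set, its normalized weight is 1.0f.
  const int64_t e = std::max_element(router_logits.begin(), router_logits.end()) -
                    router_logits.begin();
  // 2. FC1: y1 = x * W1, with W1 row-major [hidden, inter].
  std::vector<float> y1(inter, 0.0f);
  for (int64_t k = 0; k < hidden; ++k)
    for (int64_t j = 0; j < inter; ++j)
      y1[j] += x[k] * fc1[e][k * inter + j];
  // 3. Activation (ReLU for brevity; the op dispatches on activation_type,
  //    e.g. the SwiGLU path touched in moe_utils.cc below, where fc1
  //    produces 2 * inter outputs split into gate and linear halves).
  for (float& v : y1) v = std::max(v, 0.0f);
  // 4. FC2: out = y1 * W2, with W2 row-major [inter, hidden], weight 1.0f.
  std::vector<float> out(hidden, 0.0f);
  for (int64_t j = 0; j < inter; ++j)
    for (int64_t k = 0; k < hidden; ++k)
      out[k] += y1[j] * fc2[e][j * hidden + k];
  return out;
}
```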
12 changes: 10 additions & 2 deletions onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
@@ -37,10 +37,18 @@ void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t
gate_val = std::min(gate_val, clamp_limit);
linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit);

+// Use numerically stable sigmoid computation (matches CUDA kernel behavior)
float sigmoid_arg = activation_alpha * gate_val;
-float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
-float swish_out = gate_val * sigmoid_out;
+float sigmoid_out;
+if (sigmoid_arg > 0) {
+  float exp_neg = std::exp(-sigmoid_arg);
+  sigmoid_out = 1.0f / (1.0f + exp_neg);
+} else {
+  float exp_pos = std::exp(sigmoid_arg);
+  sigmoid_out = exp_pos / (1.0f + exp_pos);
+}

+float swish_out = gate_val * sigmoid_out;
output_data[i] = swish_out * (linear_val + activation_beta);
}
} else {
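Why the branch helps: the naive form `1/(1+exp(-x))` overflows `exp` when `x` is large and negative (note `gate_val` is only clamped from above), while the branched form only ever exponentiates a non-positive argument, keeping `exp` in (0, 1]. A standalone sketch of the same branching, with a hypothetical input not taken from the PR's tests:

```cpp
#include <cmath>
#include <cstdio>

// Numerically stable sigmoid, same branching as the kernel change above:
// the exponent argument is always <= 0, so std::exp never overflows.
float stable_sigmoid(float x) {
  if (x > 0) {
    return 1.0f / (1.0f + std::exp(-x));  // -x < 0, exp in (0, 1)
  }
  const float e = std::exp(x);            // x <= 0, exp in (0, 1]
  return e / (1.0f + e);                  // algebraically equal to 1/(1+exp(-x))
}

int main() {
  const float x = -90.0f;
  // Naive form: exp(90) overflows float to +inf, so the result collapses to 0.
  std::printf("naive:  %g\n", 1.0f / (1.0f + std::exp(-x)));
  // Stable form: exp(-90) is a tiny but representable subnormal (~8e-40).
  std::printf("stable: %g\n", stable_sigmoid(x));
}
```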
91 changes: 91 additions & 0 deletions onnxruntime/test/contrib_ops/moe_test.cc
@@ -1690,6 +1690,97 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) {
#endif
}

// Test for CPU MoE implementation
static void RunMoECpuTest(const std::vector<float>& input, const std::vector<float>& router_probs,
const std::vector<float>& fc1_experts_weights, const std::vector<float>& fc2_experts_weights,
const std::vector<float>& fc3_experts_weights, const std::vector<float>& fc1_experts_bias,
const std::vector<float>& fc2_experts_bias, const std::vector<float>& output_data, int num_rows,
int num_experts, int hidden_size, int inter_size, std::string activation_type,
int normalize_routing_weights = 1, int top_k = 1) {
OpTester tester("MoE", 1, onnxruntime::kMSDomain);
tester.AddAttribute<int64_t>("k", static_cast<int64_t>(top_k));
tester.AddAttribute<std::string>("activation_type", activation_type);
tester.AddAttribute<int64_t>("normalize_routing_weights", static_cast<int64_t>(normalize_routing_weights));

bool is_swiglu = (activation_type == "swiglu");

if (is_swiglu) {
tester.AddAttribute<int64_t>("swiglu_fusion", static_cast<int64_t>(1));
tester.AddAttribute<float>("activation_beta", 1.0f);
}

std::vector<int64_t> input_dims = {num_rows, hidden_size};
std::vector<int64_t> router_probs_dims = {num_rows, num_experts};

int64_t fc1_output_size = is_swiglu ? (2 * inter_size) : inter_size;

std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, fc1_output_size};
std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size};
std::vector<int64_t> fc3_experts_weights_dims = fc1_experts_weights_dims;
std::vector<int64_t> fc1_experts_bias_dims = {num_experts, fc1_output_size};
std::vector<int64_t> fc2_experts_bias_dims = {num_experts, hidden_size};
std::vector<int64_t> output_dims = {num_rows, hidden_size};

tester.AddInput<float>("input", input_dims, input);
tester.AddInput<float>("router_probs", router_probs_dims, router_probs);
tester.AddInput<float>("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights);
if (!fc1_experts_bias.empty()) {
tester.AddInput<float>("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias);
} else {
tester.AddOptionalInputEdge<float>();
}
tester.AddInput<float>("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights);
if (!fc2_experts_bias.empty()) {
tester.AddInput<float>("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias);
} else {
tester.AddOptionalInputEdge<float>();
}
if (!fc3_experts_weights.empty()) {
tester.AddInput<float>("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights);
} else {
tester.AddOptionalInputEdge<float>();
}
tester.AddOptionalInputEdge<float>(); // fc3_experts_bias

tester.AddOutput<float>("output", output_dims, output_data);
tester.SetOutputTolerance(0.05f);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

TEST(MoETest, MoECpuTest_BasicSwiGLU) {
int num_rows = 2;
int num_experts = 2;
int hidden_size = 4;
int inter_size = 8;

// Simple test data
const std::vector<float> input = {
1.0f, 2.0f, 3.0f, 4.0f,
5.0f, 6.0f, 7.0f, 8.0f};

const std::vector<float> router_probs = {
0.8f, 0.2f,
0.3f, 0.7f};

const std::vector<float> fc1_experts_weights(num_experts * hidden_size * (2 * inter_size), 0.1f);

const std::vector<float> fc2_experts_weights(num_experts * inter_size * hidden_size, 0.1f);

const std::vector<float> fc3_experts_weights = {}; // No FC3
const std::vector<float> fc1_experts_bias = {}; // No bias
const std::vector<float> fc2_experts_bias = {}; // No bias

const std::vector<float> output_data = {
1.169694f, 1.169694f, 1.169694f, 1.169694f,
6.970291f, 6.970291f, 6.970291f, 6.970291f};

RunMoECpuTest(input, router_probs, fc1_experts_weights, fc2_experts_weights,
fc3_experts_weights, fc1_experts_bias, fc2_experts_bias, output_data,
num_rows, num_experts, hidden_size, inter_size, "swiglu");
}
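The expected values can be reproduced by hand: with every FC1 weight equal to 0.1, both SwiGLU halves of each row collapse to 0.1 × Σ(input). The check below additionally assumes `activation_alpha` defaults to 1.0f and the normalized top-1 routing weight is 1.0f; neither appears in the test above, but both assumptions reproduce `output_data` exactly:

```cpp
#include <cmath>
#include <cstdio>
#include <initializer_list>

// Hand-derivation of the MoECpuTest_BasicSwiGLU expected outputs, assuming
// activation_alpha == 1.0f and a normalized top-1 routing weight of 1.0f.
int main() {
  const float beta = 1.0f;  // activation_beta, set in RunMoECpuTest
  for (float sum : {1.f + 2.f + 3.f + 4.f, 5.f + 6.f + 7.f + 8.f}) {
    const float x = 0.1f * sum;                     // every FC1 output (all weights 0.1)
    const float swish = x / (1.0f + std::exp(-x));  // gate half: x * sigmoid(x)
    const float act = swish * (x + beta);           // SwiGLU: swish * (linear + beta)
    std::printf("%f\n", 8 * 0.1f * act);            // FC2: inter_size = 8, weights 0.1
  }
  // Prints 1.169694 and 6.970291, matching output_data above.
}
```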
#endif

} // namespace test