Merged
4 changes: 2 additions & 2 deletions docs/ContribOperators.md
@@ -4571,12 +4571,12 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Type Constraints

<dl>
- <dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
+ <dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>T1</tt> : tensor(uint8)</dt>
<dd>Constrain weights type to uint8 tensors.</dd>
<dt><tt>T2</tt> : tensor(float), tensor(float16)</dt>
- <dd>Constrain scales type to float tensors.</dd>
+ <dd>Constrain scales type to float or float16 tensors.</dd>
</dl>


1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -562,6 +562,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**<br> *in* X:**T**<br> *in* x_scale:**TF**<br> *in* x_zero_point:**T**<br> *in* Y:**T**<br> *in* y_scale:**TF**<br> *in* y_zero_point:**T**<br> *in* z_scale:**TF**<br> *in* z_zero_point:**T**<br> *out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
+ |QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -106,6 +106,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
+ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QMoE);
// ******** End: Quantization ******************* //

#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -271,6 +272,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QMoE)>,
};

for (auto& function_table_entry : function_table) {
246 changes: 246 additions & 0 deletions onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
@@ -0,0 +1,246 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/tensor_shape.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace contrib {

enum class MoEParallelType {
None = 0,
EP = 1,
TP = 2,
EPAndTP = 3,
};

enum class MoEQuantType {
None = 0,
UINT4 = 1,
UINT8 = 2,
};

enum class ActivationType {
Relu = 0,
Gelu = 1,
Silu = 2,
Identity = 3,
SwiGLU = 4,
};

struct MoEParameters {
MoEParameters() {}
explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {}
int64_t num_rows;
int64_t num_experts;
int64_t local_num_experts;
int64_t hidden_size;
int64_t inter_size;

MoEParallelType parallel_type;
int64_t tensor_shards{1};
};

class MoEBaseCPU {
public:
Status CheckInputs(MoEParameters& parameters, MoEQuantType& quant_type, const Tensor* input,
const Tensor* router_probs, const Tensor* fc1_experts_weights,
const Tensor* fc1_experts_bias_optional, const Tensor* fc2_experts_weights,
const Tensor* fc2_experts_bias_optional, const Tensor* fc3_experts_weights_optional,
const Tensor* fc3_experts_bias_optional) const {
ORT_UNUSED_PARAMETER(fc3_experts_bias_optional);
const auto& input_dims = input->Shape().GetDims();
const auto& router_probs_dims = router_probs->Shape().GetDims();
const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();

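// A 3D input [batch, seq, hidden] is flattened to num_rows = batch * seq.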
int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1];
int64_t hidden_size = input_dims[input_dims.size() - 1];
int64_t local_num_experts = fc1_experts_weights_dims[0];
int64_t num_experts = router_probs_dims[1];
int64_t inter_size = fc2_experts_weights_dims[1];

if (fc1_experts_weights_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ",
fc1_experts_weights_dims.size());
}
if (fc2_experts_weights_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ",
fc2_experts_weights_dims.size());
}
if (fc1_experts_weights_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc1_experts_weights_dims[1] must be equal to hidden_size, got ",
fc1_experts_weights_dims[1], " and ", hidden_size);
}
if (fc2_experts_weights_dims[1] != inter_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc2_experts_weights_dims[1] must be equal to inter_size, got ",
fc2_experts_weights_dims[1], " and ", inter_size);
}

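// coe: weight packing factor; uint4 quantization stores two weight values per byte.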
const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1;
const int64_t act = activation_type_ == ActivationType::SwiGLU ? 2 : 1; // SwiGLU requires 2x weights for gate

if (fc1_experts_weights_dims[2] != act * inter_size / coe) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc1_experts_weights_dims[2] is ", fc1_experts_weights_dims[2],
" expected ", act * inter_size / coe);
}
if (fc2_experts_weights_dims[2] != hidden_size / coe) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc2_experts_weights_dims[2] is ", fc2_experts_weights_dims[2],
" expected ", hidden_size / coe);
}

if (router_probs_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ",
router_probs_dims.size());
}
if (router_probs_dims[0] != num_rows) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ",
router_probs_dims[0], " and ", num_rows);
}

// Optional bias validation
if (fc1_experts_bias_optional != nullptr) {
const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims();
if (fc1_experts_bias_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ",
fc1_experts_bias_dims.size());
}
if (fc1_experts_bias_dims[0] != local_num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims[0] must be equal to local_num_experts, got ",
fc1_experts_bias_dims[0], " and ", local_num_experts);
}
int64_t expected_fc1_bias_dim1 = activation_type_ == ActivationType::SwiGLU ? 2 * inter_size : inter_size;
if (fc1_experts_bias_dims[1] != expected_fc1_bias_dim1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims[1] must be equal to ", expected_fc1_bias_dim1, ", got ",
fc1_experts_bias_dims[1], " and inter_size=", inter_size, ". Activation type: ", static_cast<int>(activation_type_));
}
}
if (fc2_experts_bias_optional != nullptr) {
const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims();
if (fc2_experts_bias_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ",
fc2_experts_bias_dims.size());
}
if (fc2_experts_bias_dims[0] != local_num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims[0] must be equal to local_num_experts, got ",
fc2_experts_bias_dims[0], " and ", local_num_experts);
}
if (fc2_experts_bias_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims[1] must be equal to hidden_size, got ",
fc2_experts_bias_dims[1], " and ", hidden_size);
}
}

// FC3 validation - match CUDA FasterTransformer behavior
if (activation_type_ == ActivationType::SwiGLU && fc3_experts_weights_optional != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"SwiGLU activation is not supported with fc3.");
}
if (fc3_experts_weights_optional != nullptr && activation_type_ != ActivationType::SwiGLU) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"FC3 gating is not yet implemented on CPU.");
}

// Set output parameters
parameters.num_rows = num_rows;
parameters.num_experts = num_experts;
parameters.local_num_experts = local_num_experts;
parameters.hidden_size = hidden_size;
parameters.inter_size = inter_size;
parameters.parallel_type = MoEParallelType::None;

return Status::OK();
}

Status CheckInputScales(const Tensor* fc1_experts_scales, const Tensor* fc2_experts_scales, const Tensor* fc3_experts_scales_optional,
int64_t num_experts, int64_t hidden_size, int64_t inter_size) const {
if (fc1_experts_scales == nullptr || fc2_experts_scales == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales and fc2_experts_scales cannot be null for quantized MoE");
}

// SwiGLU should not use separate FC3 scales - weights are concatenated in FC1
if (activation_type_ == ActivationType::SwiGLU && fc3_experts_scales_optional != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"SwiGLU activation is not supported with fc3.");
}
if (activation_type_ != ActivationType::SwiGLU && fc3_experts_scales_optional != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"FC3 gating is not yet implemented on CPU.");
}

const auto& fc1_experts_scales_dims = fc1_experts_scales->Shape().GetDims();
const auto& fc2_experts_scales_dims = fc2_experts_scales->Shape().GetDims();

if (fc1_experts_scales_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales must be 2D, got ",
fc1_experts_scales_dims.size());
}
if (fc2_experts_scales_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ",
fc2_experts_scales_dims.size());
}
if (fc1_experts_scales_dims[0] != num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ",
fc1_experts_scales_dims[0], " and ", num_experts);
}

const int64_t act = activation_type_ == ActivationType::SwiGLU ? 2 : 1; // SwiGLU requires 2x scales
if (fc1_experts_scales_dims[1] != act * inter_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] is ", fc1_experts_scales_dims[1],
" expected ", act * inter_size);
}
if (fc2_experts_scales_dims[0] != num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[0] must be equal to num_experts, got ",
fc2_experts_scales_dims[0], " and ", num_experts);
}
if (fc2_experts_scales_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[1] must be equal to hidden_size, got ",
fc2_experts_scales_dims[1], " and ", hidden_size);
}

return Status::OK();
}

protected:
explicit MoEBaseCPU(const OpKernelInfo& op_kernel_info) {

ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK());

std::string activation_type_str;
ORT_ENFORCE(op_kernel_info.GetAttr<std::string>("activation_type", &activation_type_str).IsOK());

if (activation_type_str == "relu") {
activation_type_ = ActivationType::Relu;
} else if (activation_type_str == "gelu") {
activation_type_ = ActivationType::Gelu;
} else if (activation_type_str == "silu") {
activation_type_ = ActivationType::Silu;
} else if (activation_type_str == "identity") {
activation_type_ = ActivationType::Identity;
} else if (activation_type_str == "swiglu") {
activation_type_ = ActivationType::SwiGLU;
} else {
ORT_THROW("Unsupported MoE activation type: ", activation_type_str);
}

normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault<int64_t>("normalize_routing_weights", 0) == 1;

use_sparse_mixer_ = op_kernel_info.GetAttrOrDefault<int64_t>("use_sparse_mixer", 0) == 1;
if (use_sparse_mixer_) {
ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2");
}
}

bool normalize_routing_weights_;
bool use_sparse_mixer_;
int64_t k_;
ActivationType activation_type_;
};

} // namespace contrib
} // namespace onnxruntime
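For reference, a minimal standalone sketch (not part of this PR) of the shape contract that CheckInputs enforces; the expert/hidden/inter sizes below are illustrative assumptions:

#include <cstdint>
#include <iostream>

// Mirrors the packing math in MoEBaseCPU::CheckInputs: uint4 quantization
// packs two weights per byte (coe = 2), and SwiGLU doubles the FC1 output
// width to hold both the linear and gate projections (act = 2).
int main() {
  const int64_t num_experts = 8, hidden_size = 4096, inter_size = 14336;
  const bool is_uint4 = true;
  const bool is_swiglu = true;

  const int64_t coe = is_uint4 ? 2 : 1;
  const int64_t act = is_swiglu ? 2 : 1;

  // fc1_experts_weights: [num_experts, hidden_size, act * inter_size / coe]
  // fc2_experts_weights: [num_experts, inter_size, hidden_size / coe]
  std::cout << "fc1: [" << num_experts << ", " << hidden_size << ", "
            << act * inter_size / coe << "]\n";
  std::cout << "fc2: [" << num_experts << ", " << inter_size << ", "
            << hidden_size / coe << "]\n";
  // Scales stay unpacked: fc1 scales [num_experts, act * inter_size],
  // fc2 scales [num_experts, hidden_size].
  return 0;
}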
94 changes: 94 additions & 0 deletions onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/moe/moe_utils.h"
#include <cmath>
#include <algorithm>

namespace onnxruntime {
namespace contrib {

float ApplyActivation(float x, ActivationType activation_type) {
switch (activation_type) {
case ActivationType::Relu:
return std::max(0.0f, x);
case ActivationType::Gelu:
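// tanh approximation of GELU; 0.7978845608f ≈ sqrt(2 / pi)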
return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
case ActivationType::Silu:
return x * (1.0f / (1.0f + std::exp(-x)));
case ActivationType::Identity:
return x;
case ActivationType::SwiGLU:
// SwiGLU requires paired gate/linear values; it is handled by ApplySwiGLUActivation, not here
return x;
default:
return x; // Default to identity
}
}

// Helper method for applying SwiGLU activation with different memory layouts
void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format) {
constexpr float swiglu_alpha = 1.702f;
constexpr float clamp_limit = 7.0f; // Clamp gate (max only) and linear (min and max) values to this limit

if (is_interleaved_format) {
// For interleaved format [linear, gate, linear, gate, ...], process directly
// Make a temporary copy of each pair of values before modifying them
for (int64_t i = 0; i < inter_size; ++i) {
const size_t idx = static_cast<size_t>(i);
const size_t linear_idx = 2 * idx;
const size_t gate_idx = linear_idx + 1;

// Store original values
float linear_val = data[linear_idx]; // Interleaved: even index
float gate_val = data[gate_idx]; // Interleaved: odd index

// Apply clamping to the values
if (gate_val > clamp_limit) gate_val = clamp_limit; // Clamp gate max only
if (linear_val > clamp_limit) linear_val = clamp_limit; // Clamp linear min/max
if (linear_val < -clamp_limit) linear_val = -clamp_limit;

// SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1)
float sigmoid_arg = swiglu_alpha * gate_val;
float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
float swish_out = gate_val * sigmoid_out;
float result = swish_out * (linear_val + 1.0f);

// Compact the result into the first inter_size elements of the buffer
data[idx] = result;
}
} else {
// For the chunked layout [linear..., gate...], work on the data in-place:
// precompute all gate activations first, since they depend on the original gate values
std::vector<float> computed_gates(static_cast<size_t>(inter_size));

for (int64_t i = 0; i < inter_size; ++i) {
const size_t idx = static_cast<size_t>(i);
float gate_val = data[idx + static_cast<size_t>(inter_size)];

// Apply clamping to the gate value (max only)
if (gate_val > clamp_limit) gate_val = clamp_limit;

// Compute the gate part of SwiGLU
float sigmoid_arg = swiglu_alpha * gate_val;
float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
computed_gates[idx] = gate_val * sigmoid_out;
}

// Now apply the full activation with the precomputed gate values
for (int64_t i = 0; i < inter_size; ++i) {
const size_t idx = static_cast<size_t>(i);
float linear_val = data[idx];

// Apply clamping to the linear value (min/max)
if (linear_val > clamp_limit) linear_val = clamp_limit;
if (linear_val < -clamp_limit) linear_val = -clamp_limit;

data[idx] = computed_gates[idx] * (linear_val + 1.0f);
}
}
}

} // namespace contrib
} // namespace onnxruntime
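A minimal usage sketch (not part of this PR) of the interleaved SwiGLU contract above: the input values are arbitrary and small enough that the ±7 clamp is a no-op, and the expected outputs are recomputed inline with the kernel's formula:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// ApplySwiGLUActivation with is_interleaved_format == true reads
// [linear, gate] pairs and compacts gate * sigmoid(1.702f * gate) * (linear + 1)
// into the first inter_size entries of the buffer.
int main() {
  const int64_t inter_size = 2;
  std::vector<float> data = {0.5f, 1.0f, -0.25f, 2.0f};  // [l0, g0, l1, g1]
  for (int64_t i = 0; i < inter_size; ++i) {
    float linear = data[2 * i];
    float gate = data[2 * i + 1];
    float swish = gate / (1.0f + std::exp(-1.702f * gate));  // gate * sigmoid(alpha * gate)
    std::printf("expected data[%d] = %f\n", static_cast<int>(i), swish * (linear + 1.0f));
  }
  return 0;
}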
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/cpu/moe/moe_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cstdint>
#include "contrib_ops/cpu/moe/moe_base_cpu.h"

namespace onnxruntime {
namespace contrib {

float ApplyActivation(float x, ActivationType activation_type);
void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format);

} // namespace contrib
} // namespace onnxruntime