Merged
Changes from 1 commit
Commits (20)
b3664f8
[CUDA] Support SwiGlu in MoE and qMoE (#25530)
tianleiwu Jul 28, 2025
a8e1186
[CUDA] BF16 MoE and qMoE (#25572)
tianleiwu Jul 31, 2025
a9f74a0
Add CUDA implementation of GatherBlockQuantized operator (#25575)
xiaomsft Aug 1, 2025
d83904b
Add support for QMoE in CPU (#25558)
apsonawane Aug 2, 2025
8654241
Update MoE and qMoE spec (#25619)
tianleiwu Aug 2, 2025
6ca2047
[CPU] Improve QMoE kernel (#25822)
apsonawane Aug 26, 2025
dd32daf
Fix MoE CPP tests (#25877)
apsonawane Aug 28, 2025
581b8e7
Add custom ops library_path to EP metadata (#25830)
psakhamoori Aug 29, 2025
a9308a1
[Fix] illegal memory access in GetInputIndices with optional inputs (…
mingyueliuh Aug 29, 2025
6c7f150
[TRT RTX EP] Add sync method (#25898)
gedoensmax Sep 2, 2025
535fcc6
[TRT RTX EP] Memory map the engine buffer (#25909)
gedoensmax Sep 3, 2025
1f4e581
[TRT RTX EP] Add support for RTX runtime caches (#25917)
gedoensmax Sep 3, 2025
9732a3e
Compile API: disable optimizations by default (#25474)
adrianlizarraga Sep 3, 2025
df25f45
[CXX] Introduce C++ API for new C entry points (#25897)
yuslepukhin Sep 3, 2025
8f587b1
Migrate model tests to ONNX Model ZOO only (#25888)
kobby-kobbs Sep 3, 2025
ab71f1e
Remove std::string::data() non-const usage from public headers (#25943)
yuslepukhin Sep 4, 2025
2d36f04
Compile API: output model and initializer stream write functions (#25…
adrianlizarraga Sep 4, 2025
c5096d9
[TRT RTX EP] Fixing the stream parameter in CopyTensors API and passi…
praneshgo Sep 4, 2025
5ee309e
[MLAS] Add 8-bit weights ARM64 Gemm implementation (#25110)
hariharans29 Sep 4, 2025
157df9c
[NV TensorRT RTX] Handle unsupported data types (#25953)
ishwar-raut1 Sep 4, 2025
[CPU] Improve QMoE kernel (#25822)
This pull request introduces several improvements and refactorings to
the quantized Mixture-of-Experts (QMoE) operator in ONNX Runtime,
focusing on support for an FP32 (non-quantized) mode, improved SwiGLU
activation handling, and better test coverage. The most important
changes are grouped below by theme.

### Operator Registration and Type Support

- Added explicit registration and support for the `QMoE` operator with
both `MLFloat16` and `float` data types, enabling an FP32
(non-quantized) mode in addition to the quantized modes. This includes
updates to kernel registration and schema/type constraints.
[[1]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L109-R110)
[[2]](diffhunk://#diff-fd949b2a9885f634c37c2048da9e35d227ed20adf1d7baf5de488f304a78bde9L275-R277)
[[3]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1467-R1467)
[[4]](diffhunk://#diff-81f57d9adc2cce94f85a2949a895b7ff82efcc13d05e23ee6567661f0fecb7c0L1548-R1548)

### SwiGLU Activation Improvements

- Refactored `ApplySwiGLUActivation` to accept configurable
`activation_alpha` and `activation_beta` parameters, matching the CUDA
behavior and allowing the activation to be tuned via attributes.
Support for non-interleaved memory layouts was dropped; that path now
raises a not-implemented error.
[[1]](diffhunk://#diff-4e4afb8dcdade0abe18bd8bea68b148b4090cd86d60a1b1422c049960231737dR49-R60)
[[2]](diffhunk://#diff-edb344a38502bba9a0083ab98e274ec1b5b2606639a61df7be474a600a7b99d2L29-R61)
[[3]](diffhunk://#diff-f85806c745243652a0336da094126687a6c0d14b19fe760abe73df1d940dc4cbL12-R13)
- The kernel now reads the `activation_alpha` and `activation_beta`
attributes from the operator, falling back to 1.0 and 0.0 when the
attributes are absent (a standalone sketch of the interleaved
computation follows below).
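
For reference, here is a minimal standalone sketch of the interleaved SwiGLU computation matching the refactored CPU path. The function name `SwiGLUInterleaved`, the `main` driver, and the sample values are illustrative only; alpha = 1.702, beta = 1.0, clamp = 7.0 are the constants the previous hard-coded path used, not the attribute defaults.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Interleaved layout as used by the refactored kernel: gate values at even
// indices, linear values at odd indices, producing one output per pair.
void SwiGLUInterleaved(const float* input, float* output, int64_t inter_size,
                       float alpha, float beta, float clamp_limit) {
  for (int64_t i = 0; i < inter_size; ++i) {
    float gate = std::min(input[2 * i], clamp_limit);                        // gate clamped from above only
    float linear = std::clamp(input[2 * i + 1], -clamp_limit, clamp_limit);  // linear clamped on both sides
    float swish = gate / (1.0f + std::exp(-alpha * gate));                   // gate * sigmoid(alpha * gate)
    output[i] = swish * (linear + beta);
  }
}

int main() {
  std::vector<float> in = {0.5f, 1.0f, -2.0f, 0.25f};  // [gate0, linear0, gate1, linear1]
  std::vector<float> out(2);
  SwiGLUInterleaved(in.data(), out.data(), 2, 1.702f, 1.0f, 7.0f);
  return 0;
}
```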

### QMoE Operator Implementation Refactor

- Refactored the QMoE operator to separate the quantized and FP32
implementations more clearly, and restructured internal methods for
better maintainability. Added template parameterization over the data
type and improved handling of expert weights and biases (see the
dequantization sketch below).
[[1]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5R13-R35)
[[2]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L38-R55)
[[3]](diffhunk://#diff-e54124baa488af74400fae0f0dbd5cf7d4f1e307c0a5ba0e9dc79622e1315cd5L58-L59)
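
The new `moe_quantization_cpu.cc` is not rendered in this view, so purely as a hedged illustration of the on-the-fly block dequantization described above, a hypothetical int4 helper might look like the sketch below. The zero point of 8 matches the test values discussed later; the low-nibble-first packing order and the helper itself are assumptions, not taken from the actual kernel.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not the actual kernel code): expand packed int4 weights,
// two values per byte with an assumed low-nibble-first packing and zero point 8,
// into floats one element at a time using a per-block scale.
void DequantizeInt4(const uint8_t* packed, const float* scales,
                    int64_t num_weights, int64_t block_size, float* out) {
  for (int64_t i = 0; i < num_weights; ++i) {
    const uint8_t byte = packed[i / 2];
    const int q = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    out[i] = (static_cast<float>(q) - 8.0f) * scales[i / block_size];
  }
}

int main() {
  std::vector<uint8_t> packed = {0x88, 0x9F};  // four 4-bit values: 8, 8, 15, 9
  std::vector<float> scales = {0.1f};          // one scale for the whole block (block_size = 4)
  std::vector<float> out(4);
  DequantizeInt4(packed.data(), scales.data(), 4, 4, out.data());
  return 0;
}
```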

### Shape Checking and Layout

- Removed legacy shape/layout support in QMoE input validation,
enforcing only the new memory layout for expert weights and improving
consistency and forward compatibility.

### Test and Documentation Updates

- Updated the QMoE unit tests to use the correct zero-point values for
quantized weights (0x88 for int4, 128 for int8) so that all-zero
weights genuinely dequantize to zero and the tests exercise the
expected zero-output behavior. Also clarified comments and expected
outputs for the SwiGLU and quantized scenarios (a worked example of the
zero-point arithmetic follows below).
[[1]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1340-R1349)
[[2]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1379-R1380)
[[3]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1404-R1413)
[[4]](diffhunk://#diff-27ea1ef8d40401d116e653d6b935304a7ad68ee8300d04ea98e814c585abee75L1525-R1538)
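
As a quick arithmetic check of why those values behave as zero points (assuming the usual `(q - zero_point) * scale` dequantization), a minimal sketch:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const float scale = 0.05f;  // arbitrary example scale

  // int4: 0x88 packs two nibbles of value 8; with zero point 8 both dequantize to 0.
  const uint8_t packed = 0x88;
  const float w_lo = (static_cast<float>(packed & 0x0F) - 8.0f) * scale;
  const float w_hi = (static_cast<float>(packed >> 4) - 8.0f) * scale;
  assert(w_lo == 0.0f && w_hi == 0.0f);

  // int8: the quantized value 128 with zero point 128 also dequantizes to 0.
  const uint8_t q8 = 128;
  const float w8 = (static_cast<float>(q8) - 128.0f) * scale;
  assert(w8 == 0.0f);
  return 0;
}
```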

These changes collectively improve the flexibility, correctness, and
maintainability of the QMoE operator in ONNX Runtime.


Unit test result
```
sRunning test: batch_size=1, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True
Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000372
.Running test: batch_size=1, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True
Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000392
.Running test: batch_size=1, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True
Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470
.Running test: batch_size=1, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True
Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442
.Running test: batch_size=4, sequence_length=8, quant_bits=4, use_swiglu=True, swiglu_interleaved=True
Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000470
.Running test: batch_size=4, sequence_length=8, quant_bits=8, use_swiglu=True, swiglu_interleaved=True
Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000442
.Running test: batch_size=4, sequence_length=32, quant_bits=4, use_swiglu=True, swiglu_interleaved=True
Parity check - SwiGLU(interleaved=True) 4-bit: max_diff = 0.000609
.Running test: batch_size=4, sequence_length=32, quant_bits=8, use_swiglu=True, swiglu_interleaved=True
Parity check - SwiGLU(interleaved=True) 8-bit: max_diff = 0.000702
.
----------------------------------------------------------------------
Ran 9 tests in 46.754s

OK (skipped=1)
```

---------

Co-authored-by: Tianlei Wu <[email protected]>
apsonawane and tianleiwu committed Sep 4, 2025
commit 6ca2047b44e70c06cfc2820bc29bc9c2bc0628ee
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -106,7 +106,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QMoE);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE);
// ******** End: Quantization ******************* //

#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -272,7 +273,8 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QMoE)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE)>,
};

for (auto& function_table_entry : function_table) {
12 changes: 11 additions & 1 deletion onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
@@ -6,7 +6,8 @@
#include "core/common/common.h"
#include "core/framework/tensor_shape.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/quantization/moe_helper.h"
#include "contrib_ops/cpu/moe/moe_helper.h"
#include <limits>

namespace onnxruntime {
namespace contrib {
@@ -46,12 +47,21 @@ class MoEBaseCPU {
if (use_sparse_mixer_) {
ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2");
}

+    swiglu_fusion_ = op_kernel_info.GetAttrOrDefault<int64_t>("swiglu_fusion", 0);
+    swiglu_limit_ = op_kernel_info.GetAttrOrDefault<float>("swiglu_limit", std::numeric_limits<float>::infinity());
+    activation_alpha_ = op_kernel_info.GetAttrOrDefault<float>("activation_alpha", 1.0f);
+    activation_beta_ = op_kernel_info.GetAttrOrDefault<float>("activation_beta", 0.0f);
}

bool normalize_routing_weights_;
bool use_sparse_mixer_;
int64_t k_;
ActivationType activation_type_;
+  float activation_alpha_;
+  float activation_beta_;
+  float swiglu_limit_;
+  int64_t swiglu_fusion_;
};

} // namespace contrib
393 changes: 393 additions & 0 deletions onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/moe/moe_base_cpu.h"

namespace onnxruntime {
namespace contrib {

/**
* @brief QMoE is the templated CPU implementation of the Quantized Mixture of Experts operator.
*
* This kernel supports both float and MLFloat16 data types for activations, scales, and outputs.
* It parallelizes expert computation using the ONNX Runtime thread pool and minimizes memory
* usage through on-the-fly block dequantization of weights.
*
* @tparam T The data type for the kernel (float or MLFloat16).
*/
template <typename T>
class QMoECPU final : public OpKernel, public MoEBaseCPU {
public:
explicit QMoECPU(const OpKernelInfo& op_kernel_info);
Status Compute(OpKernelContext* context) const override;

private:
int64_t expert_weight_bits_;
int64_t block_size_;
};

} // namespace contrib
} // namespace onnxruntime
66 changes: 12 additions & 54 deletions onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
@@ -4,6 +4,7 @@
#include "contrib_ops/cpu/moe/moe_utils.h"
#include <cmath>
#include <algorithm>
#include "core/common/common.h"

namespace onnxruntime {
namespace contrib {
@@ -19,74 +20,31 @@ float ApplyActivation(float x, ActivationType activation_type) {
    case ActivationType::Identity:
      return x;
    case ActivationType::SwiGLU:
-      // SwiGLU: This is handled specially as it requires gating, not applied here
+      // SwiGLU is a special case handled by ApplySwiGLUActivation, this is just a placeholder
      return x;
    default:
-      return x;  // Default to identity
+      return x;
  }
}

-// Helper method for applying SwiGLU activation with different memory layouts
-void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format) {
-  constexpr float swiglu_alpha = 1.702f;
-  constexpr float clamp_limit = 7.0f;  // Clamping limit as specified
-
+void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format,
+                           float activation_alpha, float activation_beta, float clamp_limit) {
  if (is_interleaved_format) {
-    // For interleaved format [linear, gate, linear, gate, ...], process directly
-    // Make a temporary copy of each pair of values before modifying them
    for (int64_t i = 0; i < inter_size; ++i) {
-      const size_t idx = static_cast<size_t>(i);
-      const size_t linear_idx = 2 * idx;
-      const size_t gate_idx = linear_idx + 1;
+      float gate_val = input_data[2 * i];
+      float linear_val = input_data[2 * i + 1];

-      // Store original values
-      float linear_val = data[linear_idx];  // Interleaved: even index
-      float gate_val = data[gate_idx];      // Interleaved: odd index
+      gate_val = std::min(gate_val, clamp_limit);
+      linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit);

-      // Apply clamping to the values
-      if (gate_val > clamp_limit) gate_val = clamp_limit;      // Clamp gate max only
-      if (linear_val > clamp_limit) linear_val = clamp_limit;  // Clamp linear min/max
-      if (linear_val < -clamp_limit) linear_val = -clamp_limit;

      // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1)
-      float sigmoid_arg = swiglu_alpha * gate_val;
+      float sigmoid_arg = activation_alpha * gate_val;
      float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
      float swish_out = gate_val * sigmoid_out;
-      float result = swish_out * (linear_val + 1.0f);
-
-      // Store result in first element (linear position)
-      data[idx] = result;
+      output_data[i] = swish_out * (linear_val + activation_beta);
    }
  } else {
-    // For chunked layout [linear..., gate...], handle separately
-    // Need to work with original data in-place
-    // First, store all the gate computations since they depend on original gate values
-    std::vector<float> computed_gates(static_cast<size_t>(inter_size));
-
-    for (int64_t i = 0; i < inter_size; ++i) {
-      const size_t idx = static_cast<size_t>(i);
-      float gate_val = data[idx + static_cast<size_t>(inter_size)];
-
-      // Apply clamping to the gate value (max only)
-      if (gate_val > clamp_limit) gate_val = clamp_limit;
-
-      // Compute the gate part of SwiGLU
-      float sigmoid_arg = swiglu_alpha * gate_val;
-      float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
-      computed_gates[idx] = gate_val * sigmoid_out;
-    }
-
-    // Now apply the full activation with the precomputed gate values
-    for (int64_t i = 0; i < inter_size; ++i) {
-      const size_t idx = static_cast<size_t>(i);
-      float linear_val = data[idx];
-
-      // Apply clamping to the linear value (min/max)
-      if (linear_val > clamp_limit) linear_val = clamp_limit;
-      if (linear_val < -clamp_limit) linear_val = -clamp_limit;
-
-      data[idx] = computed_gates[idx] * (linear_val + 1.0f);
-    }
+    ORT_NOT_IMPLEMENTED("Non-interleaved format not supported for SwiGLU activation");
  }
}

4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cpu/moe/moe_utils.h
@@ -9,7 +9,9 @@ namespace onnxruntime {
namespace contrib {

float ApplyActivation(float x, ActivationType activation_type);
-void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format);
+
+void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format,
+                           float activation_alpha, float activation_beta, float clamp_limit);

} // namespace contrib
} // namespace onnxruntime