Address comments
apsonawane committed Aug 1, 2025
commit 0fcdc721382c44ec40b7abdeeabd1bf8c6bd5f9a
44 changes: 35 additions & 9 deletions onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
@@ -26,17 +26,43 @@ float ApplyActivation(float x, ActivationType activation_type) {
   }
 }
 
-void ApplySwiGLU(const float* fc1_output, float* result, int64_t inter_size) {
+// Helper method for applying SwiGLU activation with different memory layouts
+void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format) {
   constexpr float swiglu_alpha = 1.702f;
-  for (int64_t i = 0; i < inter_size; ++i) {
-    float linear_val = fc1_output[2 * i];    // Interleaved: even index
-    float gate_val = fc1_output[2 * i + 1];  // Interleaved: odd index
-    // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1)
-    float sigmoid_arg = swiglu_alpha * gate_val;
-    float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
-    float swish_out = gate_val * sigmoid_out;
-    result[i] = swish_out * (linear_val + 1.0f);
-  }
+
+  // Create a temporary buffer for the result
+  auto result_buffer = std::make_unique<float[]>(inter_size);
+
+  if (is_interleaved_format) {
+    // For interleaved format [linear, gate, linear, gate, ...], process directly
+    for (int64_t i = 0; i < inter_size; ++i) {
+      float linear_val = data[2 * i];    // Interleaved: even index
+      float gate_val = data[2 * i + 1];  // Interleaved: odd index
+
+      // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1)
+      float sigmoid_arg = swiglu_alpha * gate_val;
+      float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
+      float swish_out = gate_val * sigmoid_out;
+      result_buffer[i] = swish_out * (linear_val + 1.0f);
+    }
+  } else {
+    // For chunked layout [linear..., gate...], handle separately
+    float* linear_part = data;
+    float* gate_part = data + inter_size;
+
+    for (int64_t i = 0; i < inter_size; ++i) {
+      float linear_val = linear_part[i];
+      float gate_val = gate_part[i];
+
+      // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1)
+      float sigmoid_arg = swiglu_alpha * gate_val;
+      float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
+      float swish_out = gate_val * sigmoid_out;
+      result_buffer[i] = swish_out * (linear_val + 1.0f);
+    }
+  }
+
+  // Copy result back to data (first inter_size elements only - rest is overwritten by GEMM)
+  std::memcpy(data, result_buffer.get(), inter_size * sizeof(float));
 }
 
 }  // namespace contrib
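Note for reviewers: both layouts feed the same formula, out[i] = gate[i] * sigmoid(1.702 * gate[i]) * (linear[i] + 1). The standalone sketch below is not part of this PR (all names in it are illustrative); it builds the interleaved and chunked layouts from the same linear/gate values and checks that the two paths agree.

// Standalone sketch (not from this PR): verifies that the interleaved
// [linear, gate, linear, gate, ...] and chunked [linear..., gate...] layouts
// yield identical SwiGLU outputs. All names here are illustrative.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  constexpr float swiglu_alpha = 1.702f;
  const std::vector<float> linear = {0.5f, -1.0f, 2.0f};
  const std::vector<float> gate = {1.0f, 0.25f, -0.5f};
  const size_t inter_size = linear.size();

  // Build both layouts from the same underlying values.
  std::vector<float> interleaved;
  std::vector<float> chunked(linear);
  for (size_t i = 0; i < inter_size; ++i) {
    interleaved.push_back(linear[i]);
    interleaved.push_back(gate[i]);
  }
  chunked.insert(chunked.end(), gate.begin(), gate.end());

  // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1)
  auto swiglu = [&](float l, float g) {
    return g / (1.0f + std::exp(-swiglu_alpha * g)) * (l + 1.0f);
  };

  for (size_t i = 0; i < inter_size; ++i) {
    float from_interleaved = swiglu(interleaved[2 * i], interleaved[2 * i + 1]);
    float from_chunked = swiglu(chunked[i], chunked[inter_size + i]);
    assert(from_interleaved == from_chunked);
    std::printf("i=%zu swiglu=%f\n", i, from_interleaved);
  }
  return 0;
}

The in-place design is consistent with the comment in the new helper: the caller's fc1 output buffer holds 2 * inter_size floats per row, only the first inter_size are read back after activation, and the rest is overwritten by the following GEMM.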
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/moe/moe_utils.h
@@ -9,7 +9,7 @@ namespace onnxruntime {
 namespace contrib {
 
 float ApplyActivation(float x, ActivationType activation_type);
-void ApplySwiGLU(const float* fc1_output, float* result, int64_t inter_size);
+void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format);
 
 }  // namespace contrib
 }  // namespace onnxruntime
335 changes: 153 additions & 182 deletions onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.cc

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cpu/quantization/moe_quantization_cpu.h
@@ -18,12 +18,12 @@ class QMoE final : public OpKernel, public MoEBaseCPU {
  private:
   template <bool UseUInt4x2>
   Status PrepackAndDequantizeWeights(OpKernelContext* context,
-                                    MoEParameters& moe_params,
-                                    const Tensor* fc1_experts_weights,
-                                    const Tensor* fc2_experts_weights,
-                                    const Tensor* fc1_scales,
-                                    const Tensor* fc2_scales,
-                                    bool is_swiglu);
+                                     MoEParameters& moe_params,
+                                     const Tensor* fc1_experts_weights,
+                                     const Tensor* fc2_experts_weights,
+                                     const Tensor* fc1_scales,
+                                     const Tensor* fc2_scales,
+                                     bool is_swiglu);
 
   template <bool UseUInt4x2, typename T>
   Status QuantizedMoEImpl(OpKernelContext* context,
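Background on the `UseUInt4x2` template parameter: it indicates 4-bit weights packed two per byte. The PR's actual unpacking lives in the un-rendered moe_quantization_cpu.cc, so the sketch below is an illustration only; the nibble order, the fixed zero-point of 8, and the function name are assumptions, not taken from this PR.

// Illustrative sketch only (not from this PR): dequantize weights packed as
// two 4-bit values per byte (low nibble first is assumed) with one scale and
// an assumed fixed zero-point of 8.
#include <cstdint>
#include <vector>

std::vector<float> DequantizeUInt4x2(const std::vector<uint8_t>& packed,
                                     float scale) {
  std::vector<float> out;
  out.reserve(packed.size() * 2);
  for (uint8_t byte : packed) {
    uint8_t lo = byte & 0x0F;         // first 4-bit weight
    uint8_t hi = (byte >> 4) & 0x0F;  // second 4-bit weight
    out.push_back((static_cast<int>(lo) - 8) * scale);
    out.push_back((static_cast<int>(hi) - 8) * scale);
  }
  return out;
}

int main() {
  // 0x83 packs 3 (low nibble) and 8 (high nibble).
  std::vector<float> w = DequantizeUInt4x2({0x83}, 0.5f);
  // w == {-2.5f, 0.0f}: (3 - 8) * 0.5 and (8 - 8) * 0.5
  return w.size() == 2 ? 0 : 1;
}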
2 changes: 1 addition & 1 deletion onnxruntime/test/contrib_ops/moe_test.cc
@@ -1668,7 +1668,7 @@ TEST(MoETest, QMoETest_CPU_Float32) {
   cpu_tester.AddOptionalInputEdge<float>();   // fc1_experts_bias
   cpu_tester.AddInput<uint8_t>("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights);
   cpu_tester.AddInput<float>("fc2_scales", fc2_scales_dims, fc2_scales);
-  cpu_tester.AddOptionalInputEdge<float>();  // fc2_experts_bias
+  cpu_tester.AddOptionalInputEdge<float>();   // fc2_experts_bias
   cpu_tester.AddOptionalInputEdge<uint8_t>();  // fc3_experts_weights
   cpu_tester.AddOptionalInputEdge<float>();    // fc3_scales
   cpu_tester.AddOptionalInputEdge<float>();    // fc3_experts_bias
1 change: 0 additions & 1 deletion onnxruntime/test/python/transformers/test_moe_cuda.py
@@ -1445,6 +1445,5 @@ def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits):
         moe.benchmark_ort()
 
 
-
 if __name__ == "__main__":
     unittest.main()