Fix
apsonawane committed Aug 1, 2025
commit 9fdb2ff9c38fbe2abc9cf7cc7f2074a1456e2338
18 changes: 9 additions & 9 deletions onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
@@ -30,39 +30,39 @@ float ApplyActivation(float x, ActivationType activation_type) {
 void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format) {
   constexpr float swiglu_alpha = 1.702f;
   // Create a temporary buffer for the result
-  auto result_buffer = std::make_unique<float[]>(inter_size);
+  auto result_buffer = std::make_unique<float[]>(static_cast<size_t>(inter_size));

   if (is_interleaved_format) {
     // For interleaved format [linear, gate, linear, gate, ...], process directly
     for (int64_t i = 0; i < inter_size; ++i) {
-      float linear_val = data[2 * i];    // Interleaved: even index
-      float gate_val = data[2 * i + 1];  // Interleaved: odd index
+      float linear_val = data[2 * static_cast<size_t>(i)];    // Interleaved: even index
+      float gate_val = data[2 * static_cast<size_t>(i) + 1];  // Interleaved: odd index

       // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1)
       float sigmoid_arg = swiglu_alpha * gate_val;
       float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
       float swish_out = gate_val * sigmoid_out;
-      result_buffer[i] = swish_out * (linear_val + 1.0f);
+      result_buffer[static_cast<size_t>(i)] = swish_out * (linear_val + 1.0f);
     }
   } else {
     // For chunked layout [linear..., gate...], handle separately
     float* linear_part = data;
-    float* gate_part = data + inter_size;
+    float* gate_part = data + static_cast<size_t>(inter_size);

     for (int64_t i = 0; i < inter_size; ++i) {
-      float linear_val = linear_part[i];
-      float gate_val = gate_part[i];
+      float linear_val = linear_part[static_cast<size_t>(i)];
+      float gate_val = gate_part[static_cast<size_t>(i)];

       // SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1)
       float sigmoid_arg = swiglu_alpha * gate_val;
       float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
       float swish_out = gate_val * sigmoid_out;
-      result_buffer[i] = swish_out * (linear_val + 1.0f);
+      result_buffer[static_cast<size_t>(i)] = swish_out * (linear_val + 1.0f);
     }
   }

   // Copy result back to data (first inter_size elements only - rest is overwritten by GEMM)
-  std::memcpy(data, result_buffer.get(), inter_size * sizeof(float));
+  std::memcpy(data, result_buffer.get(), static_cast<size_t>(inter_size) * sizeof(float));
 }

 }  // namespace contrib
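The C++ change above only adds size_t casts to silence sign-conversion warnings; the SwiGLU math is unchanged: out = gate * sigmoid(1.702 * gate) * (linear + 1), with linear/gate either interleaved ([linear, gate, linear, gate, ...]) or chunked ([linear..., gate...]). As a cross-check, a minimal NumPy sketch of the same computation (illustrative only, not part of this PR; the function name is hypothetical):

import numpy as np

def swiglu_reference(data: np.ndarray, inter_size: int, interleaved: bool) -> np.ndarray:
    # Mirrors ApplySwiGLUActivation in moe_utils.cc: swish(gate) * (linear + 1)
    alpha = 1.702
    if interleaved:
        linear = data[0:2 * inter_size:2]  # even indices
        gate = data[1:2 * inter_size:2]    # odd indices
    else:
        linear = data[:inter_size]
        gate = data[inter_size:2 * inter_size]
    swish = gate / (1.0 + np.exp(-alpha * gate))  # gate * sigmoid(alpha * gate)
    return swish * (linear + 1.0)

The first inter_size floats written back by the C++ routine should match this output element-wise, up to float32 rounding.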
6 changes: 2 additions & 4 deletions onnxruntime/test/python/transformers/test_qmoe_cpu.py
@@ -466,13 +466,11 @@ def __init__(self, quant_bits=0, onnx_dtype=None):
         self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32

     def create_ort_session(self, moe_onnx_graph):
-        from onnxruntime import InferenceSession, SessionOptions  # noqa: PLC0415
-
-        sess_options = SessionOptions()
+        sess_options = onnxruntime.SessionOptions()
         sess_options.log_severity_level = 2

         try:
-            ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider)
+            ort_session = onnxruntime.InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider)
         except Exception as e:
             print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}")
             print("Skipping ONNX Runtime execution for this test case.")
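This change drops the function-level import in favor of the module-scope onnxruntime import the test file already has. A condensed sketch of the resulting session-creation pattern (the model path and provider list below are placeholders):

import onnxruntime

ort_provider = ["CPUExecutionProvider"]  # placeholder; the test selects its providers elsewhere

sess_options = onnxruntime.SessionOptions()
sess_options.log_severity_level = 2  # 2 = warnings and above
session = onnxruntime.InferenceSession("model.onnx", sess_options, providers=ort_provider)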