improve backward compatibility
tianleiwu committed Aug 2, 2025
commit 1b72088f4086d7bfeef929acf756ab01e4ba368b
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cpu/quantization/moe_helper.h
@@ -64,13 +64,15 @@ Status CheckInputs(MoEParameters& parameters,
int64_t local_num_experts = fc1_experts_weights_dims[0];
int64_t num_experts = router_probs_dims[1];
int64_t inter_size = (fc2_experts_weights_dims[1] * fc2_experts_weights_dims[2] * pack_size) / hidden_size;
const bool legacy_shape = hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size;

const bool legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) ||
(hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size);

// Fused swiglu doubles the output dimension of FC1 since it fuses two GEMMs into one.
const int64_t fc1_inter_size = is_fused_swiglu ? (inter_size + inter_size) : inter_size;

if (legacy_shape) {
// legacy shape does not match the memory layout. This is for backward compatible
// legacy shape does not match the column-major memory layout. This is for backward compatibility.
CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size);
CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size);
CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size);
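For reference, the two weight layouts this check distinguishes can be summarized as follows. This is a minimal sketch derived from the CHECK_TENSOR_SHAPE calls above; the helper name ExpectedMoEWeightShapes is hypothetical and not part of the ONNX Runtime API:

// Sketch of the two quantized-MoE weight layouts validated in moe_helper.h (hypothetical helper).
#include <array>
#include <cstdint>

struct ExpectedMoEWeightShapes {
  std::array<int64_t, 3> fc1;  // fc1_experts_weights
  std::array<int64_t, 3> fc2;  // fc2_experts_weights
};

inline ExpectedMoEWeightShapes GetExpectedShapes(int64_t num_experts, int64_t hidden_size, int64_t inter_size,
                                                 int64_t pack_size, bool is_fused_swiglu, bool legacy_shape) {
  // Fused SwiGLU packs the gate and linear projections into one GEMM, so FC1's output dimension doubles.
  const int64_t fc1_inter_size = is_fused_swiglu ? 2 * inter_size : inter_size;
  if (legacy_shape) {
    // Legacy layout: [num_experts, input_dim, output_dim / pack_size].
    return {{num_experts, hidden_size, fc1_inter_size / pack_size},
            {num_experts, inter_size, hidden_size / pack_size}};
  }
  // Current column-major layout: [num_experts, output_dim, input_dim / pack_size].
  return {{num_experts, fc1_inter_size, hidden_size / pack_size},
          {num_experts, hidden_size, inter_size / pack_size}};
}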
16 changes: 8 additions & 8 deletions onnxruntime/test/contrib_ops/moe_test.cc
@@ -1358,7 +1358,7 @@ TEST(MoETest, QMoETest_CPU_Int4_MLAS) {

std::vector<int64_t> input_dims = {num_rows, hidden_size};
std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; // /2 for 4-bit
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; // legacy shape
std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2};
std::vector<int64_t> fc1_scales_dims = {num_experts, inter_size};
std::vector<int64_t> fc2_scales_dims = {num_experts, hidden_size};
@@ -1422,7 +1422,7 @@ TEST(MoETest, QMoETest_CPU_Int8_MLAS) {

std::vector<int64_t> input_dims = {num_rows, hidden_size};
std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; // No /2 for 8-bit
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; // legacy shape
std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size};
std::vector<int64_t> fc1_scales_dims = {num_experts, inter_size};
std::vector<int64_t> fc2_scales_dims = {num_experts, hidden_size};
@@ -1474,7 +1474,7 @@ TEST(MoETest, QMoETest_CPU_FC3_Error) {

std::vector<int64_t> input_dims = {num_rows, hidden_size};
std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2};
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; // legacy shape
std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2};
std::vector<int64_t> fc3_experts_weights_dims = {num_experts, hidden_size, inter_size / 2};
std::vector<int64_t> fc1_scales_dims = {num_experts, inter_size};
@@ -1543,8 +1543,8 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) {

std::vector<int64_t> input_dims = {num_rows, hidden_size};
std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; // 4-bit SwiGLU: stored as hidden x inter, but contains 2*inter data
std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2};
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, 2 * inter_size, hidden_size / 2};
std::vector<int64_t> fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / 2};
std::vector<int64_t> fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU (linear + gate)
std::vector<int64_t> fc2_scales_dims = {num_experts, hidden_size};
std::vector<int64_t> output_dims = {num_rows, hidden_size};
@@ -1602,8 +1602,8 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) {

std::vector<int64_t> input_dims = {num_rows, hidden_size};
std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size * 2}; // 8-bit SwiGLU: explicit 2x
std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size};
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, inter_size * 2, hidden_size}; // 8-bit SwiGLU: explicit 2x
std::vector<int64_t> fc2_experts_weights_dims = {num_experts, hidden_size, inter_size};
std::vector<int64_t> fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU
std::vector<int64_t> fc2_scales_dims = {num_experts, hidden_size};
std::vector<int64_t> output_dims = {num_rows, hidden_size};
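The two SwiGLU tests above now use the same column-major convention; only the pack size differs (2 for int4, 1 for int8). A small illustrative check with placeholder values, not the constants actually used in these tests:

// Placeholder values for illustration only; the real test constants are defined earlier in moe_test.cc.
#include <cstdint>

constexpr int64_t hidden_size = 32;
constexpr int64_t inter_size = 16;
constexpr int64_t pack_size_int4 = 2;  // two 4-bit weights per byte
constexpr int64_t pack_size_int8 = 1;  // one 8-bit weight per byte
// Fused SwiGLU: FC1 packs gate + linear, so its output dimension is 2 * inter_size.
static_assert(2 * inter_size == 32, "fc1 output dim");
static_assert(hidden_size / pack_size_int4 == 16, "int4: packed input dim is hidden_size / 2");
static_assert(hidden_size / pack_size_int8 == 32, "int8: packed input dim equals hidden_size");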
@@ -1660,7 +1660,7 @@ TEST(MoETest, QMoETest_CPU_Float32) {

std::vector<int64_t> input_dims = {num_rows, hidden_size};
std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size};
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; // legacy shape
std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size};
std::vector<int64_t> fc1_scales_dims = {num_experts, inter_size};
std::vector<int64_t> fc2_scales_dims = {num_experts, hidden_size};