Some fixes
apsonawane committed Sep 3, 2025
commit 63f595adaa3a6990b36a22c40653b3a93eff6ecb
115 changes: 54 additions & 61 deletions onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc
@@ -108,6 +108,7 @@ Status MoE::MoEImpl(OpKernelContext* context,

// Find top-k experts for this row
std::vector<std::pair<float, int64_t>> expert_scores;
expert_scores.reserve(num_experts); // Pre-allocate to avoid reallocations
for (int64_t expert = 0; expert < num_experts; ++expert) {
expert_scores.emplace_back(current_router[expert], expert);
}
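
For illustration, a self-contained sketch of the top-k routing selection that this scoring loop feeds (assuming std::partial_sort over (score, expert) pairs, which matches the reserve/emplace_back pattern above; the kernel's actual selection code sits outside this hunk and may differ):

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper (illustration only): pick the k highest-scoring experts for one row.
std::vector<std::pair<float, int64_t>> TopKExperts(const float* router_row,
                                                   int64_t num_experts,
                                                   int64_t k) {
  std::vector<std::pair<float, int64_t>> scores;
  scores.reserve(num_experts);  // pre-allocate, as in the kernel loop above
  for (int64_t e = 0; e < num_experts; ++e) {
    scores.emplace_back(router_row[e], e);
  }
  // Keep the k largest scores at the front, ordered by descending score.
  std::partial_sort(scores.begin(), scores.begin() + k, scores.end(),
                    [](const std::pair<float, int64_t>& a,
                       const std::pair<float, int64_t>& b) { return a.first > b.first; });
  scores.resize(k);
  return scores;
}
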
@@ -122,10 +123,16 @@ Status MoE::MoEImpl(OpKernelContext* context,
for (int64_t i = 0; i < k_; ++i) {
total_weight += expert_scores[i].first;
}
if (total_weight > 0.0f) {
// Check for numerical stability - avoid division by very small numbers
if (total_weight > 1e-8f) {
for (int64_t i = 0; i < k_; ++i) {
expert_scores[i].first /= total_weight;
}
} else {
// If total weight is too small, set uniform weights
for (int64_t i = 0; i < k_; ++i) {
expert_scores[i].first = 1.0f / static_cast<float>(k_);
}
}
}

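A minimal standalone sketch of the normalization introduced above: renormalize the top-k routing weights to sum to 1, and fall back to uniform weights when the sum is too small to divide by safely. The helper name is illustrative only:

#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper mirroring the kernel's normalization step.
void NormalizeTopKWeights(std::vector<std::pair<float, int64_t>>& topk, int64_t k) {
  float total_weight = 0.0f;
  for (int64_t i = 0; i < k; ++i) {
    total_weight += topk[i].first;
  }
  if (total_weight > 1e-8f) {  // avoid dividing by a near-zero sum
    for (int64_t i = 0; i < k; ++i) {
      topk[i].first /= total_weight;
    }
  } else {  // degenerate case: fall back to uniform routing weights
    for (int64_t i = 0; i < k; ++i) {
      topk[i].first = 1.0f / static_cast<float>(k);
    }
  }
}
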
@@ -137,6 +144,12 @@ Status MoE::MoEImpl(OpKernelContext* context,

if (weight <= 0.0f) continue;

// Validate expert index to prevent out-of-bounds access
if (expert_idx < 0 || expert_idx >= num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Expert index out of bounds: ", expert_idx, " (valid range: 0-", num_experts - 1, ")");
}

// Calculate weight offsets based on layout
const float* fc1_expert_weights;
const float* fc2_expert_weights;
@@ -158,10 +171,10 @@ Status MoE::MoEImpl(OpKernelContext* context,
std::fill(expert_output.begin(), expert_output.end(), 0.0f);

// Process this expert
ProcessExpert(current_input, fc1_expert_weights, fc1_expert_bias,
fc2_expert_weights, fc2_expert_bias,
expert_output.data(), hidden_size, inter_size, fc1_inter_size,
legacy_shape);
ORT_RETURN_IF_ERROR(ProcessExpert(current_input, fc1_expert_weights, fc1_expert_bias,
fc2_expert_weights, fc2_expert_bias,
expert_output.data(), hidden_size, inter_size, fc1_inter_size,
legacy_shape));

// Accumulate weighted expert output
for (int64_t j = 0; j < hidden_size; ++j) {
@@ -173,56 +186,54 @@ Status MoE::MoEImpl(OpKernelContext* context,
return Status::OK();
}

void MoE::ProcessExpert(const float* input_data,
const float* fc1_weights,
const float* fc1_bias,
const float* fc2_weights,
const float* fc2_bias,
float* output_data,
int64_t hidden_size,
int64_t inter_size,
int64_t fc1_inter_size,
bool legacy_shape) const {
Status MoE::ProcessExpert(const float* input_data,
const float* fc1_weights,
const float* fc1_bias,
const float* fc2_weights,
const float* fc2_bias,
float* output_data,
int64_t hidden_size,
int64_t inter_size,
int64_t fc1_inter_size,
bool legacy_shape) const {
const bool is_swiglu = (activation_type_ == ActivationType::SwiGLU);

// DEBUG: Add logging for SwiGLU debugging
if (is_swiglu) {
printf("DEBUG MoE Kernel: Processing SwiGLU expert - hidden_size=%lld, inter_size=%lld, fc1_inter_size=%lld\n",
(long long)hidden_size, (long long)inter_size, (long long)fc1_inter_size);
printf("DEBUG MoE Kernel: activation_alpha_=%f, activation_beta_=%f, swiglu_limit_=%f\n",
activation_alpha_, activation_beta_, swiglu_limit_);
printf("DEBUG MoE Kernel: Input sample: [%.6f, %.6f, %.6f]\n",
input_data[0], input_data[1], input_data[2]);
}

// Allocate intermediate buffer
std::vector<float> fc1_output(fc1_inter_size);

// Validate buffer sizes to prevent memory corruption
if (fc1_inter_size <= 0 || hidden_size <= 0 || inter_size <= 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid tensor dimensions: hidden_size=", hidden_size,
", inter_size=", inter_size, ", fc1_inter_size=", fc1_inter_size);
}

// FC1: input -> intermediate using MLAS GEMM for better performance
// For legacy layout: weights are [hidden_size, fc1_inter_size], stored row-major
// For new layout: weights are [fc1_inter_size, hidden_size], stored row-major
// MLAS expects: C = A * B^T + bias (if beta=0, bias is ignored in GEMM)

MLAS_SGEMM_DATA_PARAMS fc1_params;
fc1_params.A = input_data; // input: [1, hidden_size]
fc1_params.lda = hidden_size;
fc1_params.A = input_data; // input: [1, hidden_size]
fc1_params.lda = hidden_size; // leading dimension of A (single row, so stride = hidden_size)
fc1_params.alpha = 1.0f;
fc1_params.beta = 0.0f;
fc1_params.C = fc1_output.data();
fc1_params.ldc = fc1_inter_size; // leading dimension of C (single row, so stride = fc1_inter_size)

if (legacy_shape) {
// Legacy: weights [hidden_size, fc1_inter_size] -> need transpose for GEMM
// A[1, hidden_size] * B^T[hidden_size, fc1_inter_size] = C[1, fc1_inter_size]
// Legacy: weights [hidden_size, fc1_inter_size] stored row-major
// GEMM: A[1, hidden_size] * B^T[hidden_size, fc1_inter_size] = C[1, fc1_inter_size]
// Before transpose, B is [hidden_size, fc1_inter_size] with ldb = fc1_inter_size
fc1_params.B = fc1_weights;
fc1_params.ldb = fc1_inter_size; // leading dimension of B before transpose
fc1_params.ldc = fc1_inter_size;
MlasGemm(CblasNoTrans, CblasTrans, 1, fc1_inter_size, hidden_size, fc1_params, nullptr);
} else {
// New: weights [fc1_inter_size, hidden_size] -> no transpose needed
// A[1, hidden_size] * B^T[fc1_inter_size, hidden_size] = C[1, fc1_inter_size]
// New: weights [fc1_inter_size, hidden_size] stored row-major
// GEMM: A[1, hidden_size] * B^T[fc1_inter_size, hidden_size] = C[1, fc1_inter_size]
// Before transpose, B is [fc1_inter_size, hidden_size] with ldb = hidden_size
fc1_params.B = fc1_weights;
fc1_params.ldb = hidden_size; // leading dimension of B before transpose
fc1_params.ldc = fc1_inter_size;
MlasGemm(CblasNoTrans, CblasTrans, 1, fc1_inter_size, hidden_size, fc1_params, nullptr);
}

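As a plain-C++ reference for what the FC1 step is intended to compute under each weight layout (per the comments above; a sketch of the math only, not of MLAS internals or its parameter conventions):

#include <cstdint>

// Naive reference: y[1 x fc1_inter_size] = x[1 x hidden_size] times the FC1 weights,
// for either weight layout described in the comments above (illustration only).
void Fc1Reference(const float* x, const float* w, float* y,
                  int64_t hidden_size, int64_t fc1_inter_size, bool legacy_shape) {
  for (int64_t j = 0; j < fc1_inter_size; ++j) {
    float acc = 0.0f;
    for (int64_t i = 0; i < hidden_size; ++i) {
      // legacy: w is [hidden_size, fc1_inter_size] row-major -> w[i * fc1_inter_size + j]
      // new:    w is [fc1_inter_size, hidden_size] row-major -> w[j * hidden_size + i]
      const float wij = legacy_shape ? w[i * fc1_inter_size + j] : w[j * hidden_size + i];
      acc += x[i] * wij;
    }
    y[j] = acc;
  }
}
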
@@ -244,49 +255,38 @@ void MoE::ProcessExpert(const float* input_data,
fc2_input_buffer.resize(inter_size);
fc2_input = fc2_input_buffer.data();

// DEBUG: Check FC1 output before SwiGLU
printf("DEBUG MoE Kernel: FC1 output sample before SwiGLU: [%.6f, %.6f, %.6f, %.6f]\n",
fc1_output[0], fc1_output[1], fc1_output[2], fc1_output[3]);

// Apply SwiGLU activation: transform fc1_output[2*inter_size] -> fc2_input[inter_size]
ApplySwiGLUActivation(fc1_output.data(), fc2_input, inter_size, true,
activation_alpha_, activation_beta_, swiglu_limit_);

// DEBUG: Check FC2 input after SwiGLU
printf("DEBUG MoE Kernel: FC2 input sample after SwiGLU: [%.6f, %.6f, %.6f]\n",
fc2_input[0], fc2_input[1], fc2_input[2]);
} else {
ApplyActivationInPlace(fc1_output.data(), fc1_inter_size);

// DEBUG: Check activation output for SiLU
printf("DEBUG MoE Kernel: After SiLU activation: [%.6f, %.6f, %.6f]\n",
fc1_output[0], fc1_output[1], fc1_output[2]);
}

// FC2: intermediate -> output using MLAS GEMM
// FC2 input is either the activated fc1_output (non-SwiGLU) or fc2_input_buffer (SwiGLU)
// Both have size inter_size

MLAS_SGEMM_DATA_PARAMS fc2_params;
fc2_params.A = fc2_input; // intermediate: [1, inter_size]
fc2_params.lda = inter_size;
fc2_params.A = fc2_input; // intermediate: [1, inter_size]
fc2_params.lda = inter_size; // leading dimension of A (single row, so stride = inter_size)
fc2_params.alpha = 1.0f;
fc2_params.beta = 0.0f;
fc2_params.C = output_data;
fc2_params.ldc = hidden_size; // leading dimension of C (single row, so stride = hidden_size)

if (legacy_shape) {
// Legacy: weights [inter_size, hidden_size] -> need transpose for GEMM
// A[1, inter_size] * B^T[inter_size, hidden_size] = C[1, hidden_size]
// Legacy: weights [inter_size, hidden_size] stored row-major
// GEMM: A[1, inter_size] * B^T[inter_size, hidden_size] = C[1, hidden_size]
// Before transpose, B is [inter_size, hidden_size] with ldb = hidden_size
fc2_params.B = fc2_weights;
fc2_params.ldb = hidden_size; // leading dimension of B before transpose
fc2_params.ldc = hidden_size;
MlasGemm(CblasNoTrans, CblasTrans, 1, hidden_size, inter_size, fc2_params, nullptr);
} else {
// New: weights [hidden_size, inter_size] -> no transpose needed
// A[1, inter_size] * B^T[hidden_size, inter_size] = C[1, hidden_size]
// New: weights [hidden_size, inter_size] stored row-major
// GEMM: A[1, inter_size] * B^T[hidden_size, inter_size] = C[1, hidden_size]
// Before transpose, B is [hidden_size, inter_size] with ldb = inter_size
fc2_params.B = fc2_weights;
fc2_params.ldb = inter_size; // leading dimension of B before transpose
fc2_params.ldc = hidden_size;
MlasGemm(CblasNoTrans, CblasTrans, 1, hidden_size, inter_size, fc2_params, nullptr);
}

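For context on the ApplySwiGLUActivation call earlier in this hunk, one common SwiGLU formulation over a gate/linear split of the FC1 output is sketched below. The exact roles of activation_alpha_, activation_beta_, and swiglu_limit_ are not visible in this diff, so the clamping and scaling here are assumptions for illustration, not the kernel's definition:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical SwiGLU sketch: fc1_out holds 2 * inter_size values (a gate half and a
// linear half); out receives inter_size values. The clamping and the use of alpha/beta
// are assumptions for illustration, not taken from ApplySwiGLUActivation.
void SwiGLUSketch(const float* fc1_out, float* out, int64_t inter_size,
                  float alpha, float beta, float limit) {
  for (int64_t i = 0; i < inter_size; ++i) {
    float gate = fc1_out[i];                 // assumed: gate half stored first
    float linear = fc1_out[inter_size + i];  // assumed: linear half stored second
    gate = std::min(gate, limit);
    linear = std::max(std::min(linear, limit), -limit);
    const float silu = gate / (1.0f + std::exp(-alpha * gate));  // gate * sigmoid(alpha * gate)
    out[i] = silu * (linear + beta);
  }
}
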
@@ -297,14 +297,7 @@ void MoE::ProcessExpert(const float* input_data,
}
}

// DEBUG: Final output sample
if (is_swiglu) {
printf("DEBUG MoE Kernel: Final SwiGLU output: [%.6f, %.6f, %.6f]\n",
output_data[0], output_data[1], output_data[2]);
} else {
printf("DEBUG MoE Kernel: Final SiLU output: [%.6f, %.6f, %.6f]\n",
output_data[0], output_data[1], output_data[2]);
}
return Status::OK();
}

void MoE::ApplyActivationInPlace(float* data, int64_t size, bool is_swiglu_format) const {
20 changes: 10 additions & 10 deletions onnxruntime/contrib_ops/cpu/moe/moe_cpu.h
@@ -27,16 +27,16 @@ class MoE final : public OpKernel, public MoEBaseCPU {
const Tensor* fc3_experts_bias,
Tensor* output) const;

void ProcessExpert(const float* input_data,
const float* fc1_weights,
const float* fc1_bias,
const float* fc2_weights,
const float* fc2_bias,
float* output_data,
int64_t hidden_size,
int64_t inter_size,
int64_t fc1_inter_size,
bool legacy_shape) const;
Status ProcessExpert(const float* input_data,
const float* fc1_weights,
const float* fc1_bias,
const float* fc2_weights,
const float* fc2_bias,
float* output_data,
int64_t hidden_size,
int64_t inter_size,
int64_t fc1_inter_size,
bool legacy_shape) const;

void ApplyActivationInPlace(float* data, int64_t size, bool is_swiglu_format = false) const;
};
64 changes: 32 additions & 32 deletions onnxruntime/test/python/transformers/test_qmoe_cpu.py
@@ -682,7 +682,7 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False
if hasattr(self, "quant_bits") and self.quant_bits > 0:
# QMoE: Pass raw logits directly (QMoE does softmax internally)
router_input = router_logits
print("DEBUG: Using QMoE routing (raw logits)")
# print("DEBUG: Using QMoE routing (raw logits)")
else:
# Regular MoE: Apply the same routing logic as PyTorch reference
# This converts raw logits to proper routing probabilities
@@ -700,12 +700,12 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False
expert_idx = selected_experts[i, j]
router_input[i, expert_idx] = routing_weights[i, j]

print("DEBUG: Using regular MoE routing (processed probabilities)")
# print("DEBUG: Using regular MoE routing (processed probabilities)")

print(f"DEBUG: router_input stats: mean={router_input.mean():.6f}, std={router_input.std():.6f}")
print(
f"DEBUG: hidden_states_flat stats: mean={hidden_states_flat.mean():.6f}, std={hidden_states_flat.std():.6f}"
)
# print(f"DEBUG: router_input stats: mean={router_input.mean():.6f}, std={router_input.std():.6f}")
# print(
# f"DEBUG: hidden_states_flat stats: mean={hidden_states_flat.mean():.6f}, std={hidden_states_flat.std():.6f}"
# )

torch_dtype = onnx_to_torch_type_map[self.onnx_dtype]

@@ -738,13 +738,13 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False
buffer_ptr=tensor.data_ptr(),
)

print("DEBUG: About to run ORT inference...")
# print("DEBUG: About to run ORT inference...")

iobinding.synchronize_inputs()
self.ort_sess.run_with_iobinding(iobinding)
iobinding.synchronize_outputs()

print("DEBUG: ORT inference completed successfully")
# print("DEBUG: ORT inference completed successfully")

if enable_performance_test:
repeat = 100
@@ -1290,28 +1290,28 @@ def parity_check(self):

# Debug: Print output statistics
activation_type = "SwiGLU" if self.use_swiglu else "SiLU"
print(
f"DEBUG - {activation_type}: torch_output stats: mean={torch_output.mean():.6f}, std={torch_output.std():.6f}, min={torch_output.min():.6f}, max={torch_output.max():.6f}"
)
print(
f"DEBUG - {activation_type}: ort_output stats: mean={ort_output.mean():.6f}, std={ort_output.std():.6f}, min={ort_output.min():.6f}, max={ort_output.max():.6f}"
)
# print(
# f"DEBUG - {activation_type}: torch_output stats: mean={torch_output.mean():.6f}, std={torch_output.std():.6f}, min={torch_output.min():.6f}, max={torch_output.max():.6f}"
# )
# print(
# f"DEBUG - {activation_type}: ort_output stats: mean={ort_output.mean():.6f}, std={ort_output.std():.6f}, min={ort_output.min():.6f}, max={ort_output.max():.6f}"
# )

# Debug: Check if tensors are sharing memory (for SwiGLU bug investigation)
if self.use_swiglu:
print("DEBUG - SwiGLU Memory Check:")
print(f" torch_output.data_ptr() = {torch_output.data_ptr()}")
print(f" ort_output.data_ptr() = {ort_output.data_ptr()}")
print(f" torch_output is ort_output = {torch_output is ort_output}")
print(
f" torch_output.shares_memory_with(ort_output) = {torch_output.storage().data_ptr() == ort_output.storage().data_ptr()}"
)
# print("DEBUG - SwiGLU Memory Check:")
# print(f" torch_output.data_ptr() = {torch_output.data_ptr()}")
# print(f" ort_output.data_ptr() = {ort_output.data_ptr()}")
# print(f" torch_output is ort_output = {torch_output is ort_output}")
# print(
# f" torch_output.shares_memory_with(ort_output) = {torch_output.storage().data_ptr() == ort_output.storage().data_ptr()}"
# )

# Check first few values for bit-for-bit comparison
torch_sample = torch_output.flatten()[:10]
ort_sample = ort_output.flatten()[:10]
print(f" torch_sample[:10] = {torch_sample.tolist()}")
print(f" ort_sample[:10] = {ort_sample.tolist()}")
# print(f" torch_sample[:10] = {torch_sample.tolist()}")
# print(f" ort_sample[:10] = {ort_sample.tolist()}")

# Force modification to check if they're linked
ort_output_modified = ort_output.clone()
@@ -1326,9 +1326,9 @@ def parity_check(self):
torch_has_inf = torch.isinf(torch_output).any()
ort_has_inf = torch.isinf(ort_output).any()

print(
f"DEBUG - {activation_type}: torch_has_nan={torch_has_nan}, ort_has_nan={ort_has_nan}, torch_has_inf={torch_has_inf}, ort_has_inf={ort_has_inf}"
)
# print(
# f"DEBUG - {activation_type}: torch_has_nan={torch_has_nan}, ort_has_nan={ort_has_nan}, torch_has_inf={torch_has_inf}, ort_has_inf={ort_has_inf}"
# )

if torch_has_nan or ort_has_nan or torch_has_inf or ort_has_inf:
torch_output_clean = torch.where(
@@ -1346,16 +1346,16 @@ def parity_check(self):
# if torch.equal(problematic_torch, problematic_ort):
# max_diff = 0.0

print(f"DEBUG - {activation_type}: max_diff after cleaning NaN/Inf = {max_diff:.6f}")
# print(f"DEBUG - {activation_type}: max_diff after cleaning NaN/Inf = {max_diff:.6f}")
else:
max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max()

# Debug: Show precise max_diff for SwiGLU case
if self.use_swiglu:
print(f"DEBUG - SwiGLU: Precise max_diff = {max_diff:.12f} (scientific: {max_diff:.6e})")
# Show a few actual differences
diff_tensor = (torch_output.cpu() - ort_output.cpu()).abs()
print(f"DEBUG - SwiGLU: Top 5 differences = {torch.topk(diff_tensor.flatten(), 5).values.tolist()}")
# if self.use_swiglu:
# print(f"DEBUG - SwiGLU: Precise max_diff = {max_diff:.12f} (scientific: {max_diff:.6e})")
# # Show a few actual differences
# diff_tensor = (torch_output.cpu() - ort_output.cpu()).abs()
# print(f"DEBUG - SwiGLU: Top 5 differences = {torch.topk(diff_tensor.flatten(), 5).values.tolist()}")

# Format output similar to SwiGLU tests
print(f"Parity check - {activation_type} 0-bit: max_diff = {max_diff:.6f}")