34 commits
9aad21c
DynamicQuantizeMatMul - handle case where B zero point input is provi…
edgchen1 Jul 26, 2025
b214da5
upgrade emsdk to v4.0.11 (#25477)
fs-eire Jul 26, 2025
7c0c29d
[build] Fix the file copy in get_docker_image.py (#25548)
fs-eire Jul 26, 2025
1b584c1
[webgpu] Enable per-run control for graph capture (#25367)
qjia7 Jul 27, 2025
51d3198
Refactor plugin EP support (#25541)
skottmckay Jul 28, 2025
2e0f717
Remove the python installation steps from win-qnn-arm64-ci-pipeline.y…
snnn Jul 28, 2025
413d38d
[QNN EP] Support more Einsum equation: bhwc,hkc->bhwk (#25518)
qti-yuduo Jul 28, 2025
f80697a
Cherry-pick round 1 (#25563)
snnn Jul 28, 2025
bac8af3
Upgrade xnnpack to latest (#25275)
fanchenkong1 Jul 28, 2025
38e660c
Fix webgpu_pix_frame_generator by adding missing present mode attribu…
shaoboyan091 Jul 28, 2025
a2b4546
[CUDA] Support SwiGlu in MoE and qMoE (#25530)
tianleiwu Jul 28, 2025
6ee4ea3
Fix C/C++ documentation generation (#25569)
adrianlizarraga Jul 29, 2025
87f1499
[MIGraphX EP] Fix compilation after cherry-picking from win-onnxrunti…
apwojcik Jul 29, 2025
2bd00ec
[webgpu] Optimize FlashAttention for prefill (#25395)
daijh Jul 29, 2025
8f20e30
[EP ABI] Node_GetAttrByName returns ORT_NOT_FOUND with non-existing a…
adrianlizarraga Jul 29, 2025
f53d7d8
[QNN-EP] Fix data type check to skip optional I/Os (#25514)
qti-yuduo Jul 29, 2025
a89b038
[build] upgrade Node.js for NPM packaging pipeline (#25568)
fs-eire Jul 29, 2025
c22f70d
[QNN-EP] Einsum equation ReduceSum Multiply on broadcast X (#25581)
qti-yuduo Jul 30, 2025
b957547
[build] fix multi-config for VCPKG (#25585)
fs-eire Jul 30, 2025
131cf40
Update OrtEpFactory in MiGraphX EP (#25567)
psakhamoori Jul 30, 2025
c29737d
[webgpu] use u32 to represent f16 in uniform (#25391)
fs-eire Jul 30, 2025
4a8a289
[EP ABI] Support for TENSOR type attribute (#25566)
chilo-ms Jul 30, 2025
f91d24c
[build] fix build break on arm (#25601)
fs-eire Jul 31, 2025
eade5fe
Add WhereDummyDq Transformer to form Node Unit (#25576)
qti-hungjuiw Jul 31, 2025
e753643
add session_id_ to LogEvaluationStart/Stop, LogSessionCreationStart (…
xieofxie Jul 31, 2025
866c7e3
[VitisAI] add new api to VitisAI to save graph as a string (#25602)
yifei410 Jul 31, 2025
68b9d9b
[CUDA] BF16 MoE and qMoE (#25572)
tianleiwu Jul 31, 2025
780c0e1
Disable Turing GPU for Nv Trt Rtx Ep (#25611)
ishwar-raut1 Jul 31, 2025
5c0a7d8
[QNN EP] Add Unit tests for LPBQ Fusions (#25592)
quic-tirupath Jul 31, 2025
ac7af24
Refactor Java Test Pipeline (#25608)
snnn Jul 31, 2025
a7bc727
Cache opSupportLimits to improve the performance and update tracing e…
qwu16 Jul 31, 2025
7b2f667
[build] disable CodeQL for NPM Packaging Pipeline (#25614)
fs-eire Aug 1, 2025
e57dc2a
[QNN EP] Lower Gemm with 2d bias to FC + ElementwiseAdd when targetin…
quic-muchhsu Aug 1, 2025
1109d03
[VitisAI] bugfix model_clone optimization
mingyueliuh Aug 1, 2025
[QNN-EP] Einsum equation ReduceSum Multiply on broadcast X (#25581)
qti-yuduo authored Jul 30, 2025
commit c22f70d48e012de2a480d6ce2ef9920f320f3f22
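This commit adds a QNN-EP lowering for Einsum equations of the form bhwc,wkc->bhwk, i.e. output[b, h, w, k] = sum_c X[b, h, w, c] * Y[w, k, c]. The equation is matched structurally and lowered to a Reshape of X to (b, h, w, 1, c), a broadcast Multiply with Y, and a ReduceSum over the contraction axis c. The new path is rejected on the GPU backend and for quantized inputs, where QAIRT 3.36.1 fails to validate or finalize the graph.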
169 changes: 168 additions & 1 deletion onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc
@@ -192,6 +192,50 @@ bool IsEquationMatMulBroadcastTransposeY(const Equation& equation) {
  return true;
}

bool IsEquationReduceSumMulBroadcastX(const Equation& equation) {
  // E.g., bhwc,wkc->bhwk
  const auto& [term_1, term_2, result] = equation;
  if (term_1.size() != 4) {
    return false;
  }
  if (term_2.size() != 3) {
    return false;
  }
  if (result.size() != 4) {
    return false;
  }

  // Check contraction over last axis (c)
  char c1 = term_1[3];
  char c2 = term_2[2];
  if (c1 != c2) {
    return false;
  }

  // Check w axis alignment
  if (term_1[2] != term_2[0]) {
    return false;
  }
  if (term_1[2] != result[2]) {
    return false;
  }

  // Check k axis alignment
  if (term_2[1] != result[3]) {
    return false;
  }

  // Check batch dimensions
  if (term_1[0] != result[0]) {
    return false;
  }
  if (term_1[1] != result[1]) {
    return false;
  }

  return true;
}
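// Illustrative examples (not part of the original patch) of how this matcher
// treats a few equations, given Equation = (term_1, term_2, result):
//   "bhwc,wkc->bhwk" -> true  (contracts over 'c'; shares 'w'; 'k' comes from term_2)
//   "bhwc,hkc->bhwk" -> false (term_1[2] == 'w' does not match term_2[0] == 'h')
//   "bij,jk->bik"    -> false (term_1 must have exactly 4 subscripts)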

/**
 * @brief Sets the parameter tensor names for a MatMul op.
 *
@@ -305,6 +349,113 @@ Status CreateMatMulTransposeAll(
  return Status::OK();
}

/**
 * @brief Lowers an Einsum of the form bhwc,wkc->bhwk into a Reshape of input X,
 *        a Multiply that broadcasts the reshaped X against the original input Y,
 *        and a ReduceSum over the contraction axis.
 *
 * @param qnn_model_wrapper Pointer to the QnnModelWrapper instance used to manage the QNN model.
 * @param node_unit The NodeUnit representing the ONNX node to be converted.
 * @param input_names The names of the node's input tensors.
 * @param do_op_validation A boolean flag indicating whether to perform operation validation.
 * @return Status indicating success or failure of the operation.
 */
Status CreateReduceSumMulBroadcastX(
    onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper,
    const onnxruntime::NodeUnit& node_unit,
    std::vector<std::string>&& input_names,
    bool do_op_validation) {
  // Reshape in0 to shape (b, h, w, 1, c), expanding a dimension before the contraction axis 'c'.
  // This allows broadcasting with in1 for multiplication and aligns the contraction axis for the reduce.
  onnxruntime::qnn::TensorInfo tensor_info_in0{}, tensor_info_in1{}, tensor_info_out{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[0], tensor_info_in0));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[1], tensor_info_in1));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Outputs()[0], tensor_info_out));
  const std::vector<uint32_t>& shape_in0 = tensor_info_in0.shape;
  const std::vector<uint32_t>& shape_in1 = tensor_info_in1.shape;
  ORT_RETURN_IF_NOT(shape_in0.size() == 4, "CreateReduceSumMulBroadcastX expects input 0 to be rank 4");
  ORT_RETURN_IF_NOT(shape_in1.size() == 3, "CreateReduceSumMulBroadcastX expects input 1 to be rank 3");
  const std::vector<uint32_t> new_shape_in0{shape_in0[0], shape_in0[1], shape_in0[2], 1, shape_in0[3]};
  const std::string reshape_out_name = input_names[0] + "_reshaped";
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddReshapeNode(
      /*input_name=*/input_names[0],
      /*output_name=*/reshape_out_name,
      /*input_shape=*/shape_in0,
      /*output_shape=*/new_shape_in0,
      /*tensor_data_type=*/tensor_info_in0.qnn_data_type,
      /*quantize_param=*/tensor_info_in0.quant_param.Copy(),
      /*do_op_validation=*/do_op_validation,
      /*is_for_input=*/qnn_model_wrapper->IsGraphInput(input_names[0])));

  // Multiply: reshaped in0 * in1
  // The output shape of the multiplication is determined by broadcasting the reshaped in0 of
  // (b, h, w, 1, c) and in1 (w, k, c) along the matching axes, resulting in (b, h, w, k, c).
  const std::string mul_out_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_mul";
  std::vector<uint32_t> shape_out_mul{new_shape_in0[0], new_shape_in0[1], new_shape_in0[2], shape_in1[1], new_shape_in0[4]};
  onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_mul(mul_out_name,
                                                        QNN_TENSOR_TYPE_NATIVE,
                                                        tensor_info_in0.qnn_data_type,
                                                        tensor_info_in0.quant_param.Copy(),
                                                        std::move(shape_out_mul));
  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_mul)),
                    "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper");
  ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(
                        /*qnn_node_name=*/mul_out_name,
                        /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW,
                        /*qnn_node_type=*/QNN_OP_ELEMENT_WISE_MULTIPLY,
                        /*input_names=*/{reshape_out_name, input_names[1]},
                        /*output_names=*/{mul_out_name},
                        /*param_tensor_names=*/{},
                        /*do_op_validation=*/do_op_validation),
                    "CreateReduceSumMulBroadcastX: failed to create Mul node");

  std::vector<std::string> param_tensor_names{};

  // ReduceSum on last axes={4}, keep_dims=False
  // Axis '4' corresponds to the last dimension ('c') of the reshaped tensor (b, h, w, k, c),
  // which is the contraction axis for the reduce sum op in the einsum equation (bhwc,wkc->bhwk).
  std::vector<uint32_t> axes_shape{SafeInt<uint32_t>(1)};
  std::vector<uint32_t> axes_value{SafeInt<uint32_t>(4)};
  onnxruntime::qnn::QnnParamWrapper param_axes(node_unit.Index(),
                                               node_unit.Name(),
                                               QNN_OP_REDUCE_SUM_PARAM_AXES,
                                               std::move(axes_shape),
                                               std::move(axes_value));
  param_tensor_names.push_back(param_axes.GetParamTensorName());
  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_axes)),
                    "CreateReduceSumMulBroadcastX: failed to add param axes");

  Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT;
  keep_dims_scalar.dataType = QNN_DATATYPE_BOOL_8;
  keep_dims_scalar.bool8Value = SafeInt<uint8_t>(0);
  onnxruntime::qnn::QnnParamWrapper param_keep_dims(node_unit.Index(),
                                                    node_unit.Name(),
                                                    QNN_OP_REDUCE_SUM_PARAM_KEEP_DIMS,
                                                    keep_dims_scalar);
  param_tensor_names.push_back(param_keep_dims.GetParamTensorName());
  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_keep_dims)),
                    "CreateReduceSumMulBroadcastX: failed to add param keep_dims");

  const std::string out_name = node_unit.Outputs()[0].node_arg.Name();
  Qnn_TensorType_t out_tensor_type = qnn_model_wrapper->IsGraphOutput(out_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
  onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_out(out_name,
                                                        out_tensor_type,
                                                        tensor_info_out.qnn_data_type,
                                                        tensor_info_out.quant_param.Copy(),
                                                        std::move(tensor_info_out.shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_out)),
                    "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper");

  ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(
                        /*qnn_node_name=*/out_name,
                        /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW,
                        /*qnn_node_type=*/QNN_OP_REDUCE_SUM,
                        /*input_names=*/{mul_out_name},
                        /*output_names=*/{out_name},
                        /*param_tensor_names=*/std::move(param_tensor_names),
                        /*do_op_validation=*/do_op_validation),
                    "CreateReduceSumMulBroadcastX: failed to create ReduceSum node");

  return Status::OK();
}

} // namespace

namespace onnxruntime {
@@ -356,9 +507,20 @@ Status EinsumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
   if (!IsEquationMatMul(parsed_equation.value()) &&
       !IsEquationMatMulTransposeY(parsed_equation.value()) &&
       !IsEquationMatMulBroadcastTransposeY(parsed_equation.value()) &&
-      !IsEquationMatMulTransposeAll(parsed_equation.value())) {
+      !IsEquationMatMulTransposeAll(parsed_equation.value()) &&
+      !IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
   }
   if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
     if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) {
       // QAIRT 3.36.1: Failed to validate on GPU.
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " on backend GPU");
     }
     if (node_unit.Inputs()[0].quant_param.has_value()) {
       // QAIRT 3.36.1: Failed to finalize QNN graph 1002.
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " for quantized inputs");
     }
   }
   return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true);
 }

@@ -408,6 +570,11 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
                                          /*node_unit=*/node_unit,
                                          /*input_names=*/std::move(input_names),
                                          /*do_op_validation=*/do_op_validation));
  } else if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
    ORT_RETURN_IF_ERROR(CreateReduceSumMulBroadcastX(/*qnn_model_wrapper=*/&qnn_model_wrapper,
                                                     /*node_unit=*/node_unit,
                                                     /*input_names=*/std::move(input_names),
                                                     /*do_op_validation=*/do_op_validation));
  } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
  }
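For intuition, the following standalone sketch (illustrative only, not part of the patch; the file name and the naive-loop reference are assumptions) checks that the Reshape -> broadcast Multiply -> ReduceSum lowering added above matches a direct evaluation of einsum bhwc,wkc->bhwk on small tensors:

// einsum_bhwc_wkc_sketch.cc -- illustrative sketch, not part of the commit.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int b = 2, h = 3, w = 4, k = 6, c = 5;  // shapes from the CPU test below
  std::vector<float> X(static_cast<size_t>(b) * h * w * c);
  std::vector<float> Y(static_cast<size_t>(w) * k * c);
  for (size_t i = 0; i < X.size(); ++i) X[i] = -0.1f + 0.05f * static_cast<float>(i);
  for (size_t i = 0; i < Y.size(); ++i) Y[i] = -0.1f + 0.05f * static_cast<float>(i);

  // Direct evaluation of the equation: out[b,h,w,k] = sum_c X[b,h,w,c] * Y[w,k,c].
  std::vector<float> direct(static_cast<size_t>(b) * h * w * k, 0.0f);
  for (int ib = 0; ib < b; ++ib)
    for (int ih = 0; ih < h; ++ih)
      for (int iw = 0; iw < w; ++iw)
        for (int ik = 0; ik < k; ++ik)
          for (int ic = 0; ic < c; ++ic)
            direct[((static_cast<size_t>(ib) * h + ih) * w + iw) * k + ik] +=
                X[((static_cast<size_t>(ib) * h + ih) * w + iw) * c + ic] *
                Y[(static_cast<size_t>(iw) * k + ik) * c + ic];

  // Lowered form, mirroring the QNN node sequence built above:
  // 1) Reshape X to (b, h, w, 1, c): inserts a size-1 axis, data layout unchanged.
  // 2) ElementWiseMultiply with Y (w, k, c), broadcasting to (b, h, w, k, c).
  std::vector<float> mul(static_cast<size_t>(b) * h * w * k * c);
  for (int ib = 0; ib < b; ++ib)
    for (int ih = 0; ih < h; ++ih)
      for (int iw = 0; iw < w; ++iw)
        for (int ik = 0; ik < k; ++ik)
          for (int ic = 0; ic < c; ++ic)
            mul[(((static_cast<size_t>(ib) * h + ih) * w + iw) * k + ik) * c + ic] =
                X[((static_cast<size_t>(ib) * h + ih) * w + iw) * c + ic] *
                Y[(static_cast<size_t>(iw) * k + ik) * c + ic];

  // 3) ReduceSum over axis 4 (the contraction axis 'c'), keep_dims = false.
  std::vector<float> lowered(static_cast<size_t>(b) * h * w * k, 0.0f);
  for (size_t i = 0; i < lowered.size(); ++i)
    for (int ic = 0; ic < c; ++ic)
      lowered[i] += mul[i * c + ic];

  for (size_t i = 0; i < direct.size(); ++i)
    assert(std::fabs(direct[i] - lowered[i]) < 1e-4f);
  std::printf("bhwc,wkc->bhwk lowering matches direct einsum evaluation\n");
  return 0;
}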
61 changes: 57 additions & 4 deletions onnxruntime/test/providers/qnn/einsum_op_test.cc
@@ -189,6 +189,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) {
      /*tolerance=*/1e-4f);
}

TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
  const std::vector<int64_t> shape0{1, 7, 1, 7};
  const std::vector<int64_t> shape1{1, 9, 1, 7};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/kQnnBackendTypeCpu,
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bkhq,bchk->bchq",
      /*tolerance=*/1e-4f);
}

TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) {
  const std::vector<int64_t> shape0{2, 3, 3, 4};
  const std::vector<int64_t> shape1{3, 3, 4};
@@ -202,16 +215,16 @@ TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) {
      /*tolerance=*/1e-4f);
}

-TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
-  const std::vector<int64_t> shape0{1, 7, 1, 7};
-  const std::vector<int64_t> shape1{1, 9, 1, 7};
+TEST_F(QnnCPUBackendTests, EinsumReduceSumMulBroadcastX) {
+  const std::vector<int64_t> shape0{2, 3, 4, 5};
+  const std::vector<int64_t> shape1{4, 6, 5};
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum<float>(
       /*backend=*/kQnnBackendTypeCpu,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
-      /*equation=*/"bkhq,bchk->bchq",
+      /*equation=*/"bhwc,wkc->bhwk",
       /*tolerance=*/1e-4f);
 }

Expand Down Expand Up @@ -299,6 +312,19 @@ TEST_F(QnnHTPBackendTests, EinsumF16MatMulBroadcastTransposeY) {
      /*tolerance=*/1e-2f);
}

TEST_F(QnnHTPBackendTests, EinsumF16ReduceSumMulBroadcastX) {
  const std::vector<int64_t> shape0{1, 3, 2, 4};
  const std::vector<int64_t> shape1{2, 3, 4};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/kQnnBackendTypeHtp,
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bhwc,wkc->bhwk",
      /*tolerance=*/1e-2f);
}

//
// QNN HTP QDQ
//
Expand Down Expand Up @@ -375,6 +401,19 @@ TEST_F(QnnHTPBackendTests, EinsumQdqMatMulBroadcastTransposeY) {
      /*tolerance=*/QDQTolerance());
}

// TODO: Re-enable. QAIRT 3.36.1: failed to finalize QNN graph 1002.
TEST_F(QnnHTPBackendTests, DISABLED_EinsumQdqReduceSumMulBroadcastX) {
  const std::vector<int64_t> shape0{1, 3, 2, 4};
  const std::vector<int64_t> shape1{2, 3, 4};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnHtpQdqEinsum<uint8_t, uint8_t>(
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bhwc,wkc->bhwk",
      /*tolerance=*/QDQTolerance());
}

#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

#if defined(_M_ARM64)
Expand Down Expand Up @@ -474,6 +513,20 @@ TEST_F(QnnGPUBackendTests, DISABLED_EinsumMatMulBroadcastTransposeY) {
      /*tolerance=*/1e-4f);
}

// TODO: Re-enable. Failed on QAIRT 3.36.1.
TEST_F(QnnGPUBackendTests, DISABLED_EinsumReduceSumMulBroadcastX) {
  const std::vector<int64_t> shape0{1, 3, 2, 4};
  const std::vector<int64_t> shape1{2, 3, 4};
  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
  RunQnnEinsum<float>(
      /*backend=*/kQnnBackendTypeGpu,
      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
      /*equation=*/"bhwc,wkc->bhwk",
      /*tolerance=*/1e-4f);
}

#endif // defined(_M_ARM64) GPU tests

} // namespace test
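Assuming a QNN-enabled build of onnxruntime, the new cases can be run selectively through the standard gtest runner, e.g. onnxruntime_test_all --gtest_filter=*EinsumReduceSumMulBroadcastX* (binary name per the usual build layout; adjust to your build).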