-
Notifications
You must be signed in to change notification settings - Fork 3.5k
[MLAS] Add 8-bit weights ARM64 Gemm implementation #25110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…nnxruntime into hari/matmul8bits_arm
|
Pending final perf validation and accuracy verification post PR comments addressing |
|
Merging as the React Native pipeline failure is unrelated to this change |
### Description
Enable 8-bit weights Gemm on ARM64 via MLAS
1. Supports 2 flavors of the 8-bit Gemm kernel - one uses `vdotq` (U8U8)
and the other uses `vusdotq` (U8S8) on platforms where I8MM is
supported.
2. Provides access to these new MLAS Gemm kernels via the `MatmulNBits`
contrib operator
3. Tests:
**MLAS**
3 new sets of tests:
- `SQ8BitQuantA` : Tests the dynamic activation quantization MLAS kernel
(`fp32 -> uint8_t` or `fp32 -> int8_t` on I8MM platforms)
- `SQ8BitPrepack`: Tests the prepacking of the weights for the 8-bit
Gemm kernels
- `SQ8BitGemm`: Tests the 8-bit Gemm kernels
**MatmulNBits contrib tests**
- Enables the 8-bit Gemm tests on ARM64 (previously only enabled on x86)
### Motivation and Context
Enable 8-bit weights Gemm on ARM64 via MLAS
Based on work and contribution by @fajin-corp
Phi-4-mini-instruct perf numbers (before and after this change):
<img width="593" height="179" alt="image"
src="https://github.com/user-attachments/assets/d81b9059-b8db-407c-8c0f-527099f9358c"
/>
---------
Co-authored-by: Jing Fang <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
### Description Cherry-pick the following PRs: #25943 #25937 #25917 #25909 #25898 #25897 #25888 #25881 #25830 #25619 #25575 #25572 #25558 #25530 #25474 #25455 #25110 Also two dependent PRs for qMoE cpu: #25877 #25822 --------- Co-authored-by: xiaomsft <[email protected]> Co-authored-by: Xiaoyan Hu <[email protected]> Co-authored-by: Akshay Sonawane <[email protected]> Co-authored-by: Kunal Vaishnavi <[email protected]> Co-authored-by: Pradeep Sakhamoori <[email protected]> Co-authored-by: mingyue <[email protected]> Co-authored-by: Maximilian Müller <[email protected]> Co-authored-by: Adrian Lizarraga <[email protected]> Co-authored-by: Dmitri Smirnov <[email protected]> Co-authored-by: Emmanuel <[email protected]> Co-authored-by: Emmanuel Assumang <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: praneshgo <[email protected]> Co-authored-by: Hariharan Seshadri <[email protected]> Co-authored-by: Jing Fang <[email protected]> Co-authored-by: Ishwar Raut <[email protected]>
…ts buffer (#25971) ### Description The memory alignment for the pre-packed weights buffer was accidentally changed for 8-bit Gemms on x86 while supporting the ARM64 equivalent 8-bit Gemm kernel in #25110. This change in alignment could either cause perf penalty or seg-fault depending on the platform while the corresponding aligned data load instruction is executed in the Gemm kernel. This changes fixes it as well as adds back a couple of tests to the MLAS 8-bit Gemm test suite and fixes a minor nit in the test file. ### Motivation and Context Resolve packaging pipeline crash
…ts buffer (microsoft#25971) ### Description The memory alignment for the pre-packed weights buffer was accidentally changed for 8-bit Gemms on x86 while supporting the ARM64 equivalent 8-bit Gemm kernel in microsoft#25110. This change in alignment could either cause perf penalty or seg-fault depending on the platform while the corresponding aligned data load instruction is executed in the Gemm kernel. This changes fixes it as well as adds back a couple of tests to the MLAS 8-bit Gemm test suite and fixes a minor nit in the test file. ### Motivation and Context Resolve packaging pipeline crash (cherry picked from commit 96f4595)
…ts buffer (#25971) ### Description The memory alignment for the pre-packed weights buffer was accidentally changed for 8-bit Gemms on x86 while supporting the ARM64 equivalent 8-bit Gemm kernel in #25110. This change in alignment could either cause perf penalty or seg-fault depending on the platform while the corresponding aligned data load instruction is executed in the Gemm kernel. This changes fixes it as well as adds back a couple of tests to the MLAS 8-bit Gemm test suite and fixes a minor nit in the test file. ### Motivation and Context Resolve packaging pipeline crash (cherry picked from commit 96f4595)
…ts buffer (#25971) ### Description The memory alignment for the pre-packed weights buffer was accidentally changed for 8-bit Gemms on x86 while supporting the ARM64 equivalent 8-bit Gemm kernel in #25110. This change in alignment could either cause perf penalty or seg-fault depending on the platform while the corresponding aligned data load instruction is executed in the Gemm kernel. This changes fixes it as well as adds back a couple of tests to the MLAS 8-bit Gemm test suite and fixes a minor nit in the test file. ### Motivation and Context Resolve packaging pipeline crash (cherry picked from commit 96f4595)
Description
Enable 8-bit weights Gemm on ARM64 via MLAS
Supports 2 flavors of the 8-bit Gemm kernel - one uses
vdotq(U8U8) and the other usesvusdotq(U8S8) on platforms where I8MM is supported.Provides access to these new MLAS Gemm kernels via the
MatmulNBitscontrib operatorTests:
MLAS
3 new sets of tests:
SQ8BitQuantA: Tests the dynamic activation quantization MLAS kernel (fp32 -> uint8_torfp32 -> int8_ton I8MM platforms)SQ8BitPrepack: Tests the prepacking of the weights for the 8-bit Gemm kernelsSQ8BitGemm: Tests the 8-bit Gemm kernelsMatmulNBits contrib tests
Motivation and Context
Enable 8-bit weights Gemm on ARM64 via MLAS
Based on work and contribution by @fajin-corp
Phi-4-mini-instruct perf numbers (before and after this change):