Merged
Changes from 23 commits
Commits
75 commits
b52a1ce
finished prepack
fajin-corp Apr 28, 2025
0523106
changed interface to support blocksum2
fajin-corp Apr 29, 2025
fd92ab8
finished quantb for quant a unsigned
fajin-corp Apr 30, 2025
ed5cf8d
finished quantize a
fajin-corp May 1, 2025
b9b9691
finished Q8Int8GemmR2xC8Neon
fajin-corp May 5, 2025
685baff
finished kernels
fajin-corp May 5, 2025
6747330
fixed build
fajin-corp May 6, 2025
b087317
passed prepack
fajin-corp May 8, 2025
196c04c
finished ut for quant a
fajin-corp May 9, 2025
353d460
fixed build
fajin-corp May 9, 2025
4d62e32
Merge remote-tracking branch 'origin/main' into hari/matmul8bits_arm
hariharans29 Jun 18, 2025
e88e32d
Comment out some 4 bit tests
hariharans29 Jun 19, 2025
58011b0
Apple I8MM check
hariharans29 Jun 20, 2025
acc4b81
Tests
hariharans29 Jun 20, 2025
2700493
Tests 2
hariharans29 Jun 20, 2025
76de326
Update onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp
hariharans29 Jun 23, 2025
159d4d3
Changes
hariharans29 Jun 23, 2025
e4bc74e
Fixes
hariharans29 Jun 23, 2025
e92055b
Re-enable 4 bit tests
hariharans29 Jun 23, 2025
94f3022
Stage
hariharans29 Jun 25, 2025
61c1872
Some tests work
hariharans29 Jun 25, 2025
16da92b
Git attempt
hariharans29 Jun 25, 2025
3ce481d
Lint attempt
hariharans29 Jun 25, 2025
29f66bd
Update onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp
hariharans29 Jun 25, 2025
987574b
More changesc
hariharans29 Jun 25, 2025
d921b06
Merge branch 'hari/matmul8bits_arm' of https://github.com/microsoft/o…
hariharans29 Jun 25, 2025
cf92e6f
Fix tests
hariharans29 Jun 25, 2025
8156fc7
Stage
hariharans29 Jun 26, 2025
9a1fe22
Stage
hariharans29 Jun 26, 2025
31c8f93
Update onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp
hariharans29 Jun 26, 2025
92ec5ff
Update onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp
hariharans29 Jun 26, 2025
7159d5e
Try fix x86 builds
hariharans29 Jun 26, 2025
7ad1d36
Merge branch 'hari/matmul8bits_arm' of https://github.com/microsoft/o…
hariharans29 Jun 26, 2025
03f2916
Try fix lint errors
hariharans29 Jun 26, 2025
47420b5
Yipee zero point tests are all passing
hariharans29 Jun 27, 2025
2a5100d
Comments and Nits
hariharans29 Jun 27, 2025
d64568b
Enable MatmulNBits test
hariharans29 Jun 27, 2025
0c55755
Fixes
hariharans29 Jun 27, 2025
01d4a98
Merge remote-tracking branch 'origin/main' into hari/matmul8bits_arm
hariharans29 Jun 27, 2025
c8188d4
a
hariharans29 Jun 27, 2025
635eec9
I8MM support re-enable
hariharans29 Jun 27, 2025
f736fae
Fix warning
hariharans29 Jun 27, 2025
aa79467
Enable tests with ZP = false
hariharans29 Jun 28, 2025
10e3afa
Update onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
hariharans29 Jun 28, 2025
c4331e0
I8MM fixes
hariharans29 Jun 28, 2025
5b7c3af
Remove unnecessary template
hariharans29 Jun 28, 2025
9ae58ee
Resolve conflicts and update PR with more fixes
hariharans29 Jul 31, 2025
b6cd309
Fix warning
hariharans29 Jul 31, 2025
98f5fe0
Properly remove warning
hariharans29 Jul 31, 2025
0d9442b
Merge remote-tracking branch 'origin' into hari/matmul8bits_arm
hariharans29 Aug 5, 2025
9c2faa6
PR feedback
hariharans29 Sep 2, 2025
47e2420
Refine
hariharans29 Sep 3, 2025
5eb9ed9
Update onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
hariharans29 Sep 3, 2025
bb978f7
Update onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
hariharans29 Sep 3, 2025
9b5c389
Update onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
hariharans29 Sep 3, 2025
76d085b
Update onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
hariharans29 Sep 3, 2025
12e3a1d
Update onnxruntime/test/contrib_ops/matmul_8bits_test.cc
hariharans29 Sep 3, 2025
3827317
Update onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
hariharans29 Sep 3, 2025
d8f4235
Update onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
hariharans29 Sep 3, 2025
46aa362
Ignore sending scales while pre-packing weights on ARM64
hariharans29 Sep 3, 2025
2c956ae
Fix warning
hariharans29 Sep 3, 2025
83296bb
4 bit fix
hariharans29 Sep 3, 2025
8f14500
Merge branch 'hari/matmul8bits_arm' of https://github.com/microsoft/o…
hariharans29 Sep 3, 2025
405105b
Update onnxruntime/test/contrib_ops/matmul_8bits_test.cc
hariharans29 Sep 3, 2025
303e867
Lint
hariharans29 Sep 3, 2025
91de908
Merge branch 'hari/matmul8bits_arm' of https://github.com/microsoft/o…
hariharans29 Sep 3, 2025
890a046
Fix lintrunner mess-up once and for all
hariharans29 Sep 3, 2025
ec0c8ab
Update onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
hariharans29 Sep 3, 2025
1f71f6c
Update onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
hariharans29 Sep 3, 2025
376fc1b
Lint
hariharans29 Sep 3, 2025
eefa72c
More fixes
hariharans29 Sep 3, 2025
e1da3d5
PR comments
hariharans29 Sep 4, 2025
77dff22
Missed out on one
hariharans29 Sep 4, 2025
7404cb3
Remove guards
hariharans29 Sep 4, 2025
edb3d72
Merge remote-tracking branch 'origin/main' into hari/matmul8bits_arm
hariharans29 Sep 4, 2025
5 changes: 5 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -107,6 +107,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/eltwise_kernel_neon.h
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
)

set(mlas_platform_preprocess_srcs
@@ -430,12 +431,16 @@ else()
${MLAS_SRC_DIR}/softmax_kernel_neon.cpp
${MLAS_SRC_DIR}/eltwise_kernel_neon.h
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
)
if (onnxruntime_USE_KLEIDIAI)
setup_kleidiai()
endif()
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")

if (NOT APPLE)
set(mlas_platform_srcs
${mlas_platform_srcs}
14 changes: 14 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -197,6 +197,20 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
has_zp_input_, nullptr, nullptr);
is_packed = true;
} else if (compute_type_ == SQNBIT_CompInt8) {
if (nbits_ == 8) {
if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
auto sptr = tensor.Data<float>();
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr,
has_zp_input_, nullptr, nullptr);
is_packed = false;
} else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr,
has_zp_input_, zptr, nullptr);
is_packed = false;
}
return Status::OK();
}
#ifdef MLAS_TARGET_AMD64_IX86
if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
auto sptr = tensor.Data<float>();
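Note on the new 8-bit branch above: PrePack is invoked once per constant input, so the scales and zero points arrive in separate calls after the weight tensor has already been packed. Each call folds its tensor into the existing packed buffer via MlasQNBitGemmPackQuantBData and deliberately leaves is_packed = false so the original scale/zero-point tensors stay available to the kernel. A rough sketch of the resulting call sequence is below; the wrapper and variable names are illustrative only, and the argument order follows the calls visible in this diff.

```cpp
#include <cstddef>
#include <cstdint>
#include "mlas_qnbit.h"  // MLAS quantized n-bit GEMM API (declares MlasQNBitGemmPackQuantBData)

// Hedged sketch, not ORT code: the three PrePack invocations for nbits == 8 on ARM64
// collapse to roughly this sequence against one shared packed-B workspace.
void PrepackFor8BitArm(size_t N, size_t K, size_t BlkLen, void* packed_b,
                       const void* quant_b_data, const float* scales,
                       const uint8_t* zero_points, bool has_zp) {
  // Call 1 (weights input): pack the quantized B payload into the workspace.
  MlasQNBitGemmPackQuantBData(N, K, /*BlkBitWidth*/ 8, BlkLen, SQNBIT_CompInt8,
                              quant_b_data, packed_b, nullptr, has_zp, nullptr, nullptr);
  // Call 2 (scales input): fold the scales into the packed scales / block sums.
  MlasQNBitGemmPackQuantBData(N, K, 8, BlkLen, SQNBIT_CompInt8,
                              nullptr, packed_b, scales, has_zp, nullptr, nullptr);
  // Call 3 (zero-points input, optional): fold the zero points into the block sums.
  if (has_zp) {
    MlasQNBitGemmPackQuantBData(N, K, 8, BlkLen, SQNBIT_CompInt8,
                                nullptr, packed_b, nullptr, has_zp, zero_points, nullptr);
  }
}
```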
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -48,6 +48,7 @@ struct MLAS_QNBIT_GEMM_DATA_PARAMS {
const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block
const T* QuantBBlkSum2 = nullptr; ///< optional address of scale * accumulate(quant - zp), one per block. Used when QuantA is uint8.
const T* Bias = nullptr; ///< optional address of Bias, vector size N
T* C = nullptr; ///< address of result matrix
size_t ldc = 0; ///< leading dimension of C
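For context on the new field: per the comments above, QuantBBlkSum is scale * zero-point and QuantBBlkSum2 is scale * sum(quant - zp) for each block of B, and the second sum is only needed when the activations are quantized to uint8. A scalar per-block sketch of the algebra both sums support follows; it is illustrative only, and the per-block A scale/zero-point names (sa, za) are assumptions, not MLAS identifiers.

```cpp
#include <cstddef>
#include <cstdint>

// Expanding (qa - za) * (qb - zb) and summing over one K-block gives
//   sum(a*b) ~= sa*sb*sum(qa*qb)                    (raw unsigned integer dot product)
//             - (sa*sum(qa))   * (sb*zb)            (A block sum  * QuantBBlkSum)
//             - (sa*za)        * (sb*sum(qb - zb))  (A zero point * QuantBBlkSum2)
float BlockDotUnsignedA(const uint8_t* qa, const uint8_t* qb, size_t blk_len,
                        float sa, uint8_t za, float sb, uint8_t zb) {
  int64_t dot = 0, sum_qa = 0, sum_qb_minus_zb = 0;
  for (size_t k = 0; k < blk_len; ++k) {
    dot += int64_t(qa[k]) * qb[k];
    sum_qa += qa[k];
    sum_qb_minus_zb += int64_t(qb[k]) - zb;
  }
  const float a_blk_sum = sa * float(sum_qa);            // role of ABlockSum
  const float b_blk_sum = sb * float(zb);                // role of QuantBBlkSum
  const float b_blk_sum2 = sb * float(sum_qb_minus_zb);  // role of QuantBBlkSum2
  return sa * sb * float(dot) - a_blk_sum * b_blk_sum - sa * float(za) * b_blk_sum2;
}
```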
4 changes: 3 additions & 1 deletion onnxruntime/core/mlas/lib/mlasi.h
@@ -1067,7 +1067,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH;

const MLAS_QNBIT_GEMM_DISPATCH&
GetMlasQNBitGemmDispatchNeon(
bool InitializeWithDotSupport
bool InitializeWithDotSupport,
bool InitializeWithI8MMSupport
);

extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2;
@@ -1164,6 +1165,7 @@ struct MLAS_PLATFORM {
// TODO: move to cpuinfo
bool Avx2Supported_ = false;
bool Avx512Supported_ = false;
bool ArmNeonQuantAUnsigned = false;

#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER)
MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel;
7 changes: 6 additions & 1 deletion onnxruntime/core/mlas/lib/platform.cpp
@@ -568,6 +568,8 @@ Return Value:
const bool HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot();

if (HasDotProductInstructions) {
this->ArmNeonQuantAUnsigned = true;

this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUdot;
this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUdot;
this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSdot;
@@ -576,16 +578,19 @@
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;
}

this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions);
this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, false);

#if defined(__linux__)
//
// Check if the processor supports ASIMD I8MM instructions.
//
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
this->ArmNeonQuantAUnsigned = false;

this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla;
this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla;
this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla;
this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, true);
}
#endif

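To summarize the new ArmNeonQuantAUnsigned flag above: when only the dot-product extension is present the activations are quantized to uint8 so the udot-based 8-bit kernels can be used, and when I8MM is detected (Linux-only check here) the smmla kernels take int8 activations, so the flag is cleared again. A minimal restatement of that selection, with the feature checks passed in as booleans for illustration:

```cpp
// Illustrative restatement of the platform.cpp selection above, not the actual code path.
// has_dot / has_i8mm correspond to HasArmNeonDot() / HasArmNeon_I8MM().
bool SelectUnsignedQuantA(bool has_dot, bool has_i8mm) {
  bool quant_a_unsigned = false;
  if (has_dot) {
    quant_a_unsigned = true;   // udot path: activations quantized to uint8
  }
  if (has_i8mm) {
    quant_a_unsigned = false;  // smmla path: activations quantized to int8
  }
  return quant_a_unsigned;
}
```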
48 changes: 32 additions & 16 deletions onnxruntime/core/mlas/lib/qnbitgemm.cpp
@@ -132,7 +132,7 @@ QNBitGemmPerGemmWorkspaceSize(
}

if (BlkBitWidth == 4 || BlkBitWidth == 8) {
return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType);
return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType, BlkBitWidth);
}

return 0;
@@ -266,7 +266,7 @@ MlasQNBitGemmPackQuantBData(
if (BlkBitWidth == 4) {
if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) {
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
PackedQuantBDataStruct<float, 4> packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen);
PackedQuantBDataStruct<float, 4> packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen, false);
Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum(
N,
K,
@@ -307,7 +307,7 @@
} else if (BlkBitWidth == 8) {
if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr) {
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
PackedQuantBDataStruct<float, 8> packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen);
PackedQuantBDataStruct<float, 8> packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen, GetMlasPlatform().ArmNeonQuantAUnsigned);
Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum(
N,
K,
@@ -742,6 +742,7 @@ SQ8BitGemm_CompInt8(
: static_cast<const std::byte*>(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes;
const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks;
const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks;
const float* QuantBBlkSum2 = DataParams->QuantBBlkSum2 ? DataParams->QuantBBlkSum2 + RangeStartN * k_blks : nullptr;
float* C = DataParams->C + RangeStartM * ldc + RangeStartN;

const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN;
@@ -759,6 +760,7 @@

if (GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) {
const float* b_blk_sum = QuantBBlkSum + n * k_blks;
const float* b_blk_sum2 = QuantBBlkSum2 ? QuantBBlkSum2 + n * k_blks : nullptr;
GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8(
BlkLen,
QuantA,
@@ -774,7 +776,8 @@
bias,
ldc,
ABlockSum,
b_blk_sum
b_blk_sum,
b_blk_sum2
);

if (DataParams->PostProcessor != nullptr) {
@@ -798,7 +801,8 @@ InitializeWorkspace_CompInt8(
const MLAS_QNBIT_GEMM_DATA_PARAMS<T>* DataParams,
void* Workspace,
size_t PerGemmWorkspaceStride,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
size_t BlkBitWidth
);

template <>
@@ -812,7 +816,8 @@ InitializeWorkspace_CompInt8<float>(
const MLAS_QNBIT_GEMM_DATA_PARAMS<float>* DataParams,
void* Workspace,
size_t PerGemmWorkspaceStride,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
size_t BlkBitWidth
)
{
MLAS_UNREFERENCED_PARAMETER(N);
@@ -825,16 +830,20 @@
const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen);

MLAS_UNREFERENCED_PARAMETER(QuantizeARow);
MLAS_UNREFERENCED_PARAMETER(QuantAStride);


// TODO: try parallel on BatchN * M threads because BatchN is usually 1.
if (UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) {
if (BlkBitWidth == 4 && UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) {
MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
const auto& data = DataParams[gemm_idx];

const float* ARowPtr = data.A;
std::byte* QuantARowPtr = static_cast<std::byte*>(Workspace) + gemm_idx * PerGemmWorkspaceStride;
QuantizeA_Packed(BlkLen, ARowPtr, M, K, QuantARowPtr);
});
} else if (QuantizeARow) {
} /* else if (QuantizeARow) {
MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
const auto& data = DataParams[gemm_idx];

@@ -847,7 +856,8 @@
QuantARowPtr += QuantAStride;
}
});
} else {
} */
else if (QuantizeARow2) {
MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) {
const auto& data = DataParams[gemm_idx];
const float* ARowPtr = data.A;
@@ -879,7 +889,8 @@ InitializeWorkspace_CompInt8<MLAS_FP16>(
const MLAS_QNBIT_GEMM_DATA_PARAMS<MLAS_FP16>* DataParams,
void* Workspace,
size_t PerGemmWorkspaceStride,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
size_t BlkBitWidth
) {
MLAS_UNREFERENCED_PARAMETER(M);
MLAS_UNREFERENCED_PARAMETER(N);
Expand All @@ -890,6 +901,7 @@ InitializeWorkspace_CompInt8<MLAS_FP16>(
MLAS_UNREFERENCED_PARAMETER(Workspace);
MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspaceStride);
MLAS_UNREFERENCED_PARAMETER(ThreadPool);
MLAS_UNREFERENCED_PARAMETER(BlkBitWidth);
}

template <typename T>
@@ -902,7 +914,8 @@ using InitializeWorkspaceFn = std::function<void(
const MLAS_QNBIT_GEMM_DATA_PARAMS<T>* DataParams,
void* Workspace,
size_t PerGemmWorkspaceStride,
MLAS_THREADPOOL* ThreadPool
MLAS_THREADPOOL* ThreadPool,
size_t BlkBitWidth
)>;

template <typename T>
@@ -1015,7 +1028,7 @@ MlasQNBitGemmBatch(
if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace<T>(Variant);
InitializeWorkspaceOperation != nullptr) {
InitializeWorkspaceOperation(
M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool
M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool, BlkBitWidth
);
}

@@ -1029,17 +1042,19 @@
void* PerGemmWorkspace =
reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) {
PackedQuantBDataStruct<T, 4> packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen);
PackedQuantBDataStruct<T, 4> packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, false);
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData;
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum;
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale;
PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen);
ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N);
} else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) {
PackedQuantBDataStruct<T, 8> packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen);
PackedQuantBDataStruct<T, 8> packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, GetMlasPlatform().ArmNeonQuantAUnsigned);
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData;
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum;
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale;
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->QuantBBlkSum2 = packed_quant_b.QuantBBlkSum2;

PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen);
ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N);
} else {
@@ -1107,18 +1122,19 @@ MlasQNBitGemmBatch(
void* PerGemmWorkspace =
reinterpret_cast<std::byte*>(Workspace) + gemm_i * PerGemmWorkspaceStride;
if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) {
PackedQuantBDataStruct<T, 4> packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen);
PackedQuantBDataStruct<T, 4> packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, false);
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData;
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum;
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale;

PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen);
ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
} else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) {
PackedQuantBDataStruct<T, 8> packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen);
PackedQuantBDataStruct<T, 8> packed_quant_b(const_cast<void*>(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, GetMlasPlatform().ArmNeonQuantAUnsigned);
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData;
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum;
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale;
const_cast<MLAS_QNBIT_GEMM_DATA_PARAMS<T>*>(Data)->QuantBBlkSum2 = packed_quant_b.QuantBBlkSum2;

PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen);
ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN);
26 changes: 22 additions & 4 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
@@ -48,24 +48,39 @@ MlasAlignAddress(void* addr, const size_t alignment)

template <typename T, int BlkBitWidth>
struct PackedQuantBDataStruct {
PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen)
PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen, bool QuantAUnsigned)
: QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen)
{
const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T);
if constexpr (BlkBitWidth == 8) {
PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32);
} else {
#if defined(MLAS_TARGET_AMD64_IX86)
// avx512 requires alignment on a 64-byte boundary
PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64);
#else
PackedQuantBData = (std::byte*)PackedQuantBWorkspace;
#endif
}

QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize);
QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment());
PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize);

if (QuantAUnsigned) {
QuantBBlkSum2 = (T*)((std::byte*)QuantBBlkSum + BlkSumSize);
QuantBBlkSum2 = (T*)MlasAlignAddress(QuantBBlkSum2, MlasQNBitQuantBBlkSumAlignment());
PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum2 + BlkSumSize);
} else {
QuantBBlkSum2 = nullptr;
PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize);
}
}

std::byte* PackedQuantBData;
T* PackedQuantBScale;
T* QuantBBlkSum;
T* QuantBBlkSum2;

void* QuantBWorkspace_;
size_t N_, BlockCountK_, BlkLen_;
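The constructor above lays the packed-B workspace out as [ PackedQuantBData | QuantBBlkSum | QuantBBlkSum2 (only when QuantAUnsigned) | PackedQuantBScale ], with the block-sum arrays realigned via MlasAlignAddress. A hedged sizing sketch of that layout follows; the real size comes from the MLAS workspace-size helpers (not shown in this diff), so the slack term and helper name here are assumptions.

```cpp
#include <cstddef>

// Illustrative estimate of the packed-B buffer implied by PackedQuantBDataStruct above.
size_t EstimatePackedQuantBSize(size_t N, size_t BlockCountK, size_t BlkLen,
                                size_t BlkBitWidth, bool QuantAUnsigned) {
  const size_t data_bytes = N * BlockCountK * (BlkLen * BlkBitWidth / 8);  // quantized B payload
  const size_t blk_sum_bytes = ((N + 15) / 16) * BlockCountK * 16 * sizeof(float);
  const size_t scale_bytes = N * BlockCountK * sizeof(float);
  const size_t align_slack = 3 * 64;  // headroom for the MlasAlignAddress fix-ups
  return data_bytes + (QuantAUnsigned ? 2 : 1) * blk_sum_bytes + scale_bytes + align_slack;
}
```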
@@ -178,7 +193,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
size_t K,
size_t BlkLen,
bool HasZeroPoint,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
size_t BlkBitWidth
);

QNBitGemmPerGemmWorkspaceSize_Fn* QNBitGemmPerGemmWorkspaceSize = nullptr;
@@ -387,6 +403,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
* @param ldc Number of elements between adjacent rows of C..
* @param ABlockSum Supplies the blksum of A.
* @param QuantBBlkSum Supplies the blksum of B.
* @param QuantBBlkSum2 Supplies the blksum of B when quant A is converted to uint8.
*/
typedef size_t(SQ8BitGemmKernel_BlkSum_CompInt8_Fn)(
size_t BlkLen,
@@ -403,7 +420,8 @@
const float* Bias,
size_t ldc,
const float* ABlockSum,
const float* QuantBBlkSum
const float* QuantBBlkSum,
const float* QuantBBlkSum2
);

SQ8BitGemmKernel_BlkSum_CompInt8_Fn* SQ8BitGemmKernel_BlkSum_CompInt8 = nullptr;