Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,25 @@ struct PackedQuantBDataStruct {
{
const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T);
if constexpr (BlkBitWidth == 8) {
PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32);
} else {
#if defined(MLAS_TARGET_AMD64_IX86)
// avx512 requires alignment on a 64-byte boundary
PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64);
#elif defined (MLAS_TARGET_ARM64)
// Only for 8-bit Gemms is `PackedQuantBData` required to be 32-byte aligned, and
// there is enough memory allocated to support this alignment.
// See QNBitGemmPackQuantBDataSize().
// When bit width is 4, there is no alignment guarantee.
// TODO(hasesh): Can we unify the alignment for 4-bit and 8-bit ARM64 Gemms so as to
// simplify this logic and make the code here cleaner?
if constexpr (BlkBitWidth == 8) {
PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32);
}
else {
PackedQuantBData = (std::byte*)PackedQuantBWorkspace;
}
#else
PackedQuantBData = (std::byte*)PackedQuantBWorkspace;
#endif
}

QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize);
QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment());
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,8 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase {
N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer,
nullptr, HasZp, inputZp, nullptr);

PackedQuantBDataStruct<float, 8> packedQuantB(packedBuffer, N, BlkCount, BlkLen, true);
const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned;
PackedQuantBDataStruct<float, 8> packedQuantB(packedBuffer, N, BlkCount, BlkLen, isQuantAUnsigned);

auto* C = C_.GetBuffer(M * ldc, true);
auto* ref = ref_.GetBuffer(M * ldc, true);
Expand Down Expand Up @@ -825,7 +826,9 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase {

void ExecuteShort(void) override {
Execute<1, 16, 1, 16>();
Execute<1, 1, 1, 16>();
Execute<7, 2, 4, 16>();
Execute<7, 128, 4, 16>();
Execute<8, 497, 5, 16>();
Execute<1, 3072, 128, 16>();
Execute<2, 3072, 128, 16>();
Expand Down
Loading