Fixes for DynamicQuantizeMatMul and Attention3D tests
Signed-off-by: Jonathan Clohessy <[email protected]>
JonathanC-ARM committed Aug 20, 2025
commit dbf055ed5e1fd4c0307550eb8be3127e32c63a4b
@@ -200,6 +200,20 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {

can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available);

// Kleidi dynamic path requires strictly positive, finite scales.
// Disable if any invalid scale is detected.
if (can_use_dynamic_quant_mlas_) {
const float* bs_data = b_scale_tensor->Data<float>();
const size_t bs_size = static_cast<size_t>(b_scale_tensor->Shape().Size());
for (size_t i = 0; i < bs_size; ++i) {
const float s = bs_data[i];
if (!std::isfinite(s) || s <= 0.0f) {
can_use_dynamic_quant_mlas_ = false;
break;
}
}
}

// Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
// We check that here too before attempting to use them.
if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) {
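Note on the hunk above: the Kleidi dynamic-quant path is now disabled whenever any B scale is non-finite or not strictly positive. A minimal standalone sketch of that guard follows; the helper name AllScalesArePositiveAndFinite is illustrative and not part of the change.

```cpp
#include <cmath>
#include <cstddef>

// Sketch of the added guard: the dynamic-quant path is only usable when every
// B scale is a strictly positive, finite float.
static bool AllScalesArePositiveAndFinite(const float* scales, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    const float s = scales[i];
    if (!std::isfinite(s) || s <= 0.0f) {
      return false;  // NaN, +/-inf, zero, or negative scale disables the path
    }
  }
  return true;
}

// Usage mirroring the hunk above (b_scale_tensor as in the original code):
//   can_use_dynamic_quant_mlas_ = can_use_dynamic_quant_mlas_ &&
//       AllScalesArePositiveAndFinite(b_scale_tensor->Data<float>(),
//                                     static_cast<size_t>(b_scale_tensor->Shape().Size()));
```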
87 changes: 56 additions & 31 deletions onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
@@ -153,30 +153,41 @@ ArmKleidiAI::MlasGemmBatch(
MLAS_THREADPOOL* ThreadPool
)
{
if(TransA == CblasTrans)
{
return false;
if (M == 0 || N == 0) {
return true;
}
if (TransA == CblasNoTrans && K == 0) {
if (Data->beta != 1.0f) {

if (Data->alpha == 0.0f) {
if (Data->beta == 0.0f) {
for (size_t i = 0; i < M; ++i) {
std::fill_n(Data->C + i * Data->ldc, N, 0.0f);
}
} else if (Data->beta != 1.0f) {
for (size_t i = 0; i < M; ++i) {
for (size_t j = 0; j < N; ++j) {
Data->C[i * Data->ldc + j] *= Data->beta;
}
}
}
return true;
}
if (Data->beta == 0.0f){
std::fill_n(Data->C, M * Data->ldc, 0.0f);
}
//Fallback in the case of unsupported cases
if (M == 0 || N == 0 || K == 0 ||
TransA != CblasNoTrans ||
(TransB != CblasNoTrans && !Data[0].BIsPacked))
{
return false;

if (K == 0) {
if (Data->beta == 0.0f) {
for (size_t i = 0; i < M; ++i) {
std::fill_n(Data->C + i * Data->ldc, N, 0.0f);
}
} else if (Data->beta != 1.0f) {
for (size_t i = 0; i < M; ++i) {
for (size_t j = 0; j < N; ++j) {
Data->C[i * Data->ldc + j] *= Data->beta;
}
}
}
return true;
}


if (TransA == CblasNoTrans) {
const size_t mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
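For reference, the early-outs introduced in the hunk above return immediately for empty outputs (M == 0 || N == 0) and reduce the alpha == 0 and K == 0 cases to C = beta * C, row by row with the leading dimension ldc. A standalone sketch of that reduction, assuming an illustrative helper name:

```cpp
#include <algorithm>
#include <cstddef>

// Sketch of the degenerate-shape handling: when alpha == 0 or K == 0,
// the GEMM reduces to C = beta * C (row by row, honoring ldc).
static void ScaleCByBeta(float* C, size_t M, size_t N, size_t ldc, float beta) {
  if (beta == 0.0f) {
    for (size_t i = 0; i < M; ++i) {
      std::fill_n(C + i * ldc, N, 0.0f);  // C := 0
    }
  } else if (beta != 1.0f) {
    for (size_t i = 0; i < M; ++i) {
      for (size_t j = 0; j < N; ++j) {
        C[i * ldc + j] *= beta;  // C := beta * C
      }
    }
  }
  // beta == 1.0f: C is left untouched
}
```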
@@ -185,11 +196,9 @@
auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();

if (M < m_step || N < n_step) {
if (GetMlasPlatform().MlasGemmBatchOverride != ArmKleidiAI::MlasGemmBatch){
//Fallback to MLAS
return false;
}
if ((M < m_step && N < n_step) && !Data->BIsPacked) {
// Fallback to MLAS
return false;
}

std::vector<MLAS_SGEMM_DATA_PARAMS> KaiPackedData;
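The reworked condition above falls back to the generic MLAS path only when both M and N are below the SME kernel's step sizes and B has not been pre-packed. A sketch of that decision, with an illustrative function name:

```cpp
#include <cstddef>

// Sketch of the revised fallback test: KleidiAI is skipped only when the
// problem is smaller than one SME tile in *both* dimensions and B is not
// already packed for the KleidiAI kernels.
static bool ShouldFallBackToMlas(size_t M, size_t N,
                                 size_t m_step, size_t n_step,
                                 bool b_is_packed) {
  return (M < m_step && N < n_step) && !b_is_packed;
}
```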
@@ -316,7 +325,7 @@
float* dst_tile = reinterpret_cast<float*>(CTile);

// quick copy of data in cases where we are not scaling or accumulating anything
// with bounds checking on tile sizing to ensure the data fits in the memory block
// with bounds checking on tile sizing to ensure the data fits in the memory block
bool can_memcpy = (
Data[BIdx].alpha == 1.0f &&
Data[BIdx].beta == 0.0f &&
@@ -328,21 +337,37 @@

if (can_memcpy) {
std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float));
}else {
// apply alpha scaling and beta to output files
for (size_t i = 0; i < TileSizeM; ++i) {
for (size_t j = 0; j < TileSizeN; ++j) {
const size_t idx = i * TileSizeN + j;
const size_t dst_idx = i * Data[BIdx].ldc + j;

float ab = temp_tile[idx];
float c_orig = dst_tile[dst_idx];
return true;
}

dst_tile[dst_idx] = Data[BIdx].alpha * ab + Data[BIdx].beta * c_orig;
float alpha = Data[BIdx].alpha;
float beta = Data[BIdx].beta;
size_t ldc = Data[BIdx].ldc;

for (size_t i = 0; i < TileSizeM; ++i) {
for (size_t j = 0; j < TileSizeN; ++j) {
const size_t temp_idx = i * TileSizeN + j;
const size_t dst_idx = i * ldc + j;

float ab = temp_tile[temp_idx];
float c_orig = dst_tile[dst_idx];

if (alpha == 1.0f && beta == 0.0f) {
dst_tile[dst_idx] = ab;
} else if (alpha == 1.0f) {
dst_tile[dst_idx] = ab + beta * c_orig;
} else if (beta == 0.0f) {
dst_tile[dst_idx] = alpha * ab;
} else {
dst_tile[dst_idx] = alpha * ab + beta * c_orig;
}
}
}
return true;
});
return true;
}
else {
return false;
}
return true;
}
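The rewritten tile write-back above takes a fast memcpy path when no scaling or accumulation is needed and otherwise branches on the (alpha, beta) combination instead of always computing the full alpha * AB + beta * C. A self-contained sketch of the same idea; the function name is illustrative and a simplified contiguity check (ldc == TileSizeN) stands in for the original bounds checks:

```cpp
#include <cstddef>
#include <cstring>

// Sketch of the tile write-back specialization: memcpy when the result can be
// copied verbatim, otherwise pick the cheapest update per (alpha, beta) case.
static void WriteBackTile(const float* temp_tile, float* dst_tile,
                          size_t TileSizeM, size_t TileSizeN, size_t ldc,
                          float alpha, float beta) {
  if (alpha == 1.0f && beta == 0.0f && ldc == TileSizeN) {
    std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float));
    return;
  }
  for (size_t i = 0; i < TileSizeM; ++i) {
    for (size_t j = 0; j < TileSizeN; ++j) {
      const float ab = temp_tile[i * TileSizeN + j];
      float& c = dst_tile[i * ldc + j];
      if (alpha == 1.0f && beta == 0.0f) {
        c = ab;                      // plain copy
      } else if (alpha == 1.0f) {
        c = ab + beta * c;           // accumulate only
      } else if (beta == 0.0f) {
        c = alpha * ab;              // scale only
      } else {
        c = alpha * ab + beta * c;   // general case
      }
    }
  }
}
```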