@@ -200,6 +200,19 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
 
   can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available);
 
+  // Kleidi dynamic path requires strictly positive, finite scales.
+  // Disable if any invalid scale is detected.
+  if (can_use_dynamic_quant_mlas_) {
+    const auto bs = b_scale_tensor->DataAsSpan<float>();
+    const bool has_invalid =
+        std::any_of(bs.begin(), bs.end(),
+                    [](float s) { return !std::isfinite(s) || s <= 0.0f; });
+
+    if (has_invalid) {
+      can_use_dynamic_quant_mlas_ = false;
+    }
+  }
+
   // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
   // We check that here too before attempting to use them.
   if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) {
@@ -379,7 +392,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
   if (y->Shape().Size() == 0)
     return Status::OK();
 
-  auto a_data = static_cast<const uint8_t*>(ctx->Input<Tensor>(IN_A)->DataRaw());
+  const float* a_data = ctx->Input<Tensor>(IN_A)->Data<float>();
   auto* y_data = y->MutableData<float>();
 
   // batch gemm
@@ -393,7 +406,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
 
   for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) {
     auto& params = gemm_data_vec[gemm_idx];
-    params.A = reinterpret_cast<const float*>(a_data + helper.LeftOffsets()[gemm_idx]);
+    params.A = a_data + helper.LeftOffsets()[gemm_idx];
     params.lda = gemm_shape.K;
     params.PackedB = packed_b_.get();
     params.C = y_data + helper.OutputOffsets()[gemm_idx];
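Review note (not part of the diff): the gate added above reduces to a single predicate over the B-scale tensor. A minimal standalone sketch; the helper name and the `std::span` signature are illustrative, not the PR's API:

```cpp
#include <algorithm>
#include <cmath>
#include <span>

// The Kleidi dynamic path tolerates only finite, strictly positive
// per-channel B scales; any zero, negative, NaN, or Inf entry disables it.
bool all_b_scales_valid(std::span<const float> b_scales) {
  return std::none_of(b_scales.begin(), b_scales.end(),
                      [](float s) { return !std::isfinite(s) || s <= 0.0f; });
}
```

The `s <= 0.0f` test also rejects negative zero, since `-0.0f <= 0.0f` evaluates true under IEEE-754 comparison rules.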
73 changes: 41 additions & 32 deletions onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
@@ -153,28 +153,23 @@ ArmKleidiAI::MlasGemmBatch(
     MLAS_THREADPOOL* ThreadPool
     )
 {
-    if(TransA == CblasTrans)
-    {
-        return false;
+    if (M == 0 || N == 0) {
+        return true;
     }
-    if (TransA == CblasNoTrans && K == 0) {
-        if (Data->beta != 1.0f) {
+
+    if (Data->alpha == 0.0f || K == 0) {
+        if (Data->beta == 0.0f) {
+            for (size_t i = 0; i < M; ++i) {
+                std::fill_n(Data->C + i * Data->ldc, N, 0.0f);
+            }
+        } else if (Data->beta != 1.0f) {
             for (size_t i = 0; i < M; ++i) {
                 for (size_t j = 0; j < N; ++j) {
                     Data->C[i * Data->ldc + j] *= Data->beta;
                 }
             }
         }
-    }
-    if (Data->beta == 0.0f){
-        std::fill_n(Data->C, M * Data->ldc, 0.0f);
-    }
-    //Fallback in the case of unsupported cases
-    if (M == 0 || N == 0 || K == 0 ||
-        TransA != CblasNoTrans ||
-        (TransB != CblasNoTrans && !Data[0].BIsPacked))
-    {
-        return false;
+        return true;
     }
 
     if (TransA == CblasNoTrans) {
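Review note (not part of the diff): the early-out added above follows the BLAS convention for a degenerate GEMM. When `alpha == 0` or `K == 0` the `A*B` term contributes nothing, leaving `C := beta * C`; the `beta == 0` case must store zeros explicitly instead of multiplying, so stale NaN or Inf values in C cannot survive as `0 * x`. A minimal sketch of that rule (free-standing helper, illustrative only):

```cpp
#include <algorithm>
#include <cstddef>

// C := beta * C over an M x N block with leading dimension ldc.
// beta == 0 writes zeros directly: the output must be cleared even if C
// holds NaN/Inf, which 0.0f * x would otherwise propagate.
void scale_c(float* C, std::size_t M, std::size_t N, std::size_t ldc, float beta) {
    for (std::size_t i = 0; i < M; ++i) {
        float* row = C + i * ldc;
        if (beta == 0.0f) {
            std::fill_n(row, N, 0.0f);
        } else if (beta != 1.0f) {
            for (std::size_t j = 0; j < N; ++j) {
                row[j] *= beta;
            }
        }
    }
}
```

This also fixes a real bug in the removed code: `std::fill_n(Data->C, M * Data->ldc, 0.0f)` zeroed `M * ldc` contiguous floats, clobbering the padding columns between `N` and `ldc`, whereas the new per-row fill touches only the `N` valid columns of each row.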
@@ -185,11 +180,9 @@
         auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
         auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
 
-        if (M < m_step || N < n_step) {
-            if (GetMlasPlatform().MlasGemmBatchOverride != ArmKleidiAI::MlasGemmBatch){
-                //Fallback to MLAS
-                return false;
-            }
+        if (M < m_step && N < n_step && !Data->BIsPacked) {
+            // Fallback to MLAS
+            return false;
         }
 
         std::vector<MLAS_SGEMM_DATA_PARAMS> KaiPackedData;
@@ -316,7 +309,7 @@
             float* dst_tile = reinterpret_cast<float*>(CTile);
 
             // quick copy of data in cases where we are not scaling or accumulating anything
             // with bounds checking on tile sizing to ensure the data fits in the memory block
             bool can_memcpy = (
                 Data[BIdx].alpha == 1.0f &&
                 Data[BIdx].beta == 0.0f &&
@@ -328,21 +321,37 @@
 
             if (can_memcpy) {
                 std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float));
-            }else {
-                // apply alpha scaling and beta to output files
-                for (size_t i = 0; i < TileSizeM; ++i) {
-                    for (size_t j = 0; j < TileSizeN; ++j) {
-                        const size_t idx = i * TileSizeN + j;
-                        const size_t dst_idx = i * Data[BIdx].ldc + j;
-
-                        float ab = temp_tile[idx];
-                        float c_orig = dst_tile[dst_idx];
-
-                        dst_tile[dst_idx] = Data[BIdx].alpha * ab + Data[BIdx].beta * c_orig;
-                    }
-                }
-            }
-            return;
+                return;
+            }
+
+            float alpha = Data[BIdx].alpha;
+            float beta = Data[BIdx].beta;
+            size_t ldc = Data[BIdx].ldc;
+
+            for (size_t i = 0; i < TileSizeM; ++i) {
+                for (size_t j = 0; j < TileSizeN; ++j) {
+                    const size_t temp_idx = i * TileSizeN + j;
+                    const size_t dst_idx = i * ldc + j;
+
+                    float ab = temp_tile[temp_idx];
+                    float c_orig = dst_tile[dst_idx];
+
+                    if (alpha == 1.0f && beta == 0.0f) {
+                        dst_tile[dst_idx] = ab;
+                    } else if (alpha == 1.0f) {
+                        dst_tile[dst_idx] = ab + beta * c_orig;
+                    } else if (beta == 0.0f) {
+                        dst_tile[dst_idx] = alpha * ab;
+                    } else {
+                        dst_tile[dst_idx] = alpha * ab + beta * c_orig;
+                    }
+                }
+            }
         });
-        return true;
     }
     else {
         return false;
     }
+    return true;
 }
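Review note (not part of the diff): the branchy tile epilogue above is a strength-reduced form of the single SGEMM update `C = alpha * (A*B) + beta * C_orig`. A reference version, with illustrative names, against which the four branches can be checked:

```cpp
#include <cstddef>

// Reference epilogue: temp_tile is a densely packed tile_m x tile_n block
// (row stride tile_n) holding A*B; dst_tile uses the caller's leading
// dimension ldc. The PR's alpha == 1 / beta == 0 branches are shortcuts
// of this one formula that skip redundant multiplies.
void apply_epilogue(const float* temp_tile, float* dst_tile,
                    std::size_t tile_m, std::size_t tile_n,
                    std::size_t ldc, float alpha, float beta) {
    for (std::size_t i = 0; i < tile_m; ++i) {
        for (std::size_t j = 0; j < tile_n; ++j) {
            dst_tile[i * ldc + j] =
                alpha * temp_tile[i * tile_n + j] + beta * dst_tile[i * ldc + j];
        }
    }
}
```

The `beta == 0.0f` branch is a correctness point, not just an optimization: storing `alpha * ab` directly avoids reading possibly uninitialized output memory, where `beta * c_orig` could inject NaN into the result.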