Merged
2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.9.0.tar.gz;a2765979f64efb173a4b8ba4de39dcba9c655786
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a
duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
@@ -192,6 +192,7 @@ void CPUIDInfo::ArmLinuxInit() {
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();
has_arm_sme_ = cpuinfo_has_arm_sme();
has_arm_sme2_ = cpuinfo_has_arm_sme2();

const uint32_t core_cnt = cpuinfo_get_cores_count();
core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
@@ -332,6 +333,7 @@ void CPUIDInfo::ArmAppleInit() {
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();
has_arm_sme_ = cpuinfo_has_arm_sme();
has_arm_sme2_ = cpuinfo_has_arm_sme2();

// Note: We leave is_armv8_narrow_ld_ unset because it only applies to a limited set of uarchs that we don't expect
// to encounter on Apple platforms.
2 changes: 2 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
@@ -41,6 +41,7 @@ class CPUIDInfo {
bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; }
bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; }
bool HasArm_SME() const { return has_arm_sme_; }
bool HasArm_SME2() const { return has_arm_sme2_; }

uint32_t GetCurrentCoreIdx() const;

@@ -162,6 +163,7 @@ class CPUIDInfo {
bool has_arm_sve_i8mm_{false};
bool has_arm_neon_bf16_{false};
bool has_arm_sme_{false};
bool has_arm_sme2_{false};

std::string vendor_;
uint32_t vendor_id_;
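As context for the new accessor: a minimal sketch of how calling code can branch on the reported SME2 capability. The helper below is hypothetical; only HasArm_SME2() comes from this change.

#include "core/common/cpuid_info.h"

// Hypothetical helper: prefer an SME2 code path when the CPU reports SME2,
// otherwise fall back to the baseline SME path.
inline bool PreferSme2Path() {
    const auto& cpuid = onnxruntime::CPUIDInfo::GetCPUIDInfo();
    return cpuid.HasArm_SME2();
}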
3 changes: 3 additions & 0 deletions onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -15,6 +15,9 @@
#define RESTRICT __restrict__
#endif
namespace ArmKleidiAI {
// By default, prefer SME2 kernels and fall back to SME when SME2 is unavailable.
inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();

//
// Buffer packing routines.
//
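UseSME2 is evaluated once from CPUID during static initialization, and the GEMM code then selects between the SME2 and SME ukernels at each call site. A condensed sketch of that pattern, using two getter names from the diff below (it mirrors, rather than replaces, the real code):

// Fetch the packing parameter from the SME2 ukernel when available, otherwise
// from the plain SME ukernel; the same ternary pattern repeats for kr, sr,
// m_step, n_step, and the packed-offset getters.
const size_t mr = ArmKleidiAI::UseSME2
                      ? kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                      : kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();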
287 changes: 169 additions & 118 deletions onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
@@ -8,6 +8,7 @@
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"
@@ -107,7 +108,8 @@ Routine Description:

Return Value:

None.
Returns true if the packing operation was handled by KleidiAI.
Returns false if the configuration requires a fallback to the default MLAS implementation.

--*/
{
@@ -116,9 +118,12 @@
}

if (TransA == CblasNoTrans) {
const size_t nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
const size_t sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
const size_t nr = UseSME2 ? kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

// pass zeroed bias values; the KleidiAI RHS packing fuses a bias term into the packed buffer and the MLAS packing API does not supply one
const std::vector<float> bias(N);
@@ -152,6 +157,42 @@ ArmKleidiAI::MlasGemmBatch(
size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
)
/*++

Routine Description:

This routine performs a batched matrix multiplication (GEMM) operation using KleidiAI kernels.
It handles both packed and unpacked inputs and manages tiling and kernel selection depending on
SME2 availability. If packing is needed, it prepares the required buffers and invokes the
appropriate left-hand side (LHS) and right-hand side (RHS) pack functions.

The function also applies alpha and beta scaling to the result, supports efficient memcpy
paths where possible, and dispatches tile-level GEMM work using multithreading.

Arguments:

TransA - Supplies the transpose operation for matrix A.

TransB - Supplies the transpose operation for matrix B.

M - Supplies the number of rows of matrix A and matrix C.

N - Supplies the number of columns of matrix B and matrix C.

K - Supplies the number of columns of matrix A and rows of matrix B.

Data - Supplies a pointer to the MLAS_SGEMM_DATA_PARAMS array containing per-batch input/output pointers and parameters.

BatchSize - Supplies the number of independent GEMM computations to perform in the batch.

ThreadPool - Supplies the thread pool to parallelize computation across batches and tiles.

Return Value:

Returns true if the GEMM operation was handled by KleidiAI.
Returns false if the configuration requires a fallback to the default MLAS implementation.

--*/
{
if (M == 0 || N == 0) {
return true;
@@ -172,130 +213,134 @@ ArmKleidiAI::MlasGemmBatch(
return true;
}

if (TransA == CblasNoTrans) {
const size_t mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
const size_t sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();

auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();
auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa();

if (M < m_step && N < n_step && !Data->BIsPacked) {
// Fallback to MLAS
return false;
}

std::vector<MLAS_SGEMM_DATA_PARAMS> KaiPackedData;
KaiPackedData.resize(BatchSize);

size_t LhsPackedStride = 0;
std::byte* LhsPackedData = nullptr;

LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr);
auto LhsPacked = std::make_unique<std::byte[]>(LhsPackedStride * BatchSize);
LhsPackedData = LhsPacked.get();

std::unique_ptr<std::byte[]> RhsPacked{nullptr};

// It is assumed that either all B batches require packing or none do
if (Data[0].BIsPacked) {
// We have already decided on the matmul variant before the values of M, N, and K are known
MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);

kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);

KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
KaiPackedData[batch_idx].B = Data[batch_idx].B;
});
} else {
// Multithread pack lhs and rhs
size_t RhsPackedStride = 0;
std::byte* RhsPackedData = nullptr;

RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K);
RhsPacked = std::make_unique<std::byte[]>(RhsPackedStride * BatchSize);
RhsPackedData = RhsPacked.get();

MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) {
// odd work indices pack the lhs, even indices pack the rhs
if (batch_idx & 0x1) {
batch_idx >>= 1;

std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
} else {
batch_idx >>= 1;

std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]);

ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K, reinterpret_cast<const float*>(Data[batch_idx].B), Data[batch_idx].ldb, RhsPackedPtr);

KaiPackedData[batch_idx].B = reinterpret_cast<const float*>(RhsPackedPtr);
}
});
}
if (M < m_step && N < n_step && !Data->BIsPacked) {
// Fallback to MLAS
return false;
}

// tile iteration dimensions
std::array<size_t, 3> dim;
dim[0] = BatchSize; // B
dim[1] = MlasDivRoundup(M, m_step); // M
dim[2] = MlasDivRoundup(N, n_step); // N
std::vector<MLAS_SGEMM_DATA_PARAMS> KaiPackedData;
KaiPackedData.resize(BatchSize);

// Minimize the kernel call count for the number of available threads
auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]);
size_t LhsPackedStride = 0;
std::byte* LhsPackedData = nullptr;

// scale required tiles over available tile processors
dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);
LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr);
auto LhsPacked = std::make_unique<std::byte[]>(LhsPackedStride * BatchSize);
LhsPackedData = LhsPacked.get();

// compute new step sizes
m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step);
n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step);
std::unique_ptr<std::byte[]> RhsPacked{nullptr};

// update tile iterations
dim[1] = MlasDivRoundup(M, m_step);
dim[2] = MlasDivRoundup(N, n_step);
// It is assumed that either all B batches require packing or none do
if (Data[0].BIsPacked) {
// We have already decided on the matmul variant before the values of M, N, and K are known
MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
KaiPackedData[batch_idx].B = Data[batch_idx].B;
});
} else {
// Multithread pack lhs and rhs
size_t RhsPackedStride = 0;
std::byte* RhsPackedData = nullptr;

MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {
// compute B,M,N index from iteration index
ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K);
RhsPacked = std::make_unique<std::byte[]>(RhsPackedStride * BatchSize);
RhsPackedData = RhsPacked.get();

// Get rhs tile, B
const size_t rhs_packed_offset =
kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K);
MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) {
// odd work indices pack the lhs, even indices pack the rhs
if (batch_idx & 0x1) {
batch_idx >>= 1;

auto BTile = reinterpret_cast<const void*>(
reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].B) + rhs_packed_offset
);
std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);

// Get lhs tile, A
const size_t lhs_packed_offset =
kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K);
kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);

auto ATile = reinterpret_cast<const float*>(
reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].A) + lhs_packed_offset
);
KaiPackedData[batch_idx].A = reinterpret_cast<const float*>(LhsPackedPtr);
} else {
batch_idx >>= 1;

auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step;
auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step;
std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]);

// Get result tile, C
auto CTile = reinterpret_cast<void*>(
reinterpret_cast<std::byte*>(Data[BIdx].C) +
MIdx * m_step * Data[BIdx].ldc * sizeof(float) +
NIdx * n_step * sizeof(float)
);
// Allocate temporary buffer for raw A*B result
std::vector<float> OutputTile(TileSizeM * TileSizeN, 0.0f);
float* temp_tile = OutputTile.data();
ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K, reinterpret_cast<const float*>(Data[batch_idx].B), Data[batch_idx].ldb, RhsPackedPtr);

KaiPackedData[batch_idx].B = reinterpret_cast<const float*>(RhsPackedPtr);
}
});
}

// tile iteration dimensions
std::array<size_t, 3> dim;
dim[0] = BatchSize; // B
dim[1] = MlasDivRoundup(M, m_step); // M
dim[2] = MlasDivRoundup(N, n_step); // N

// Minimize the kernel call count for the number of available threads
auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]);

// scale required tiles over available tile processors
dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);

// compute new step sizes
m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step);
n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step);

// update tile iterations
dim[1] = MlasDivRoundup(M, m_step);
dim[2] = MlasDivRoundup(N, n_step);

MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {
// compute B,M,N index from iteration index
ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];

// Get rhs tile, B
const size_t rhs_packed_offset =
UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K)
: kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, K);

auto BTile = reinterpret_cast<const void*>(
reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].B) + rhs_packed_offset
);

// Get lhs tile, A
const size_t lhs_packed_offset =
UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K)
: kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, K);

auto ATile = reinterpret_cast<const float*>(
reinterpret_cast<const std::byte*>(KaiPackedData[BIdx].A) + lhs_packed_offset
);

auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step;
auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step;

// Get result tile, C
auto CTile = reinterpret_cast<void*>(
reinterpret_cast<std::byte*>(Data[BIdx].C) +
MIdx * m_step * Data[BIdx].ldc * sizeof(float) +
NIdx * n_step * sizeof(float)
);
// Allocate temporary buffer for raw A*B result
std::vector<float> OutputTile(TileSizeM * TileSizeN, 0.0f);
float* temp_tile = OutputTile.data();

if (UseSME2) {
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
TileSizeM,
TileSizeN,
@@ -304,9 +349,19 @@ ArmKleidiAI::MlasGemmBatch(
TileSizeN * sizeof(float), sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
);
} else {
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
TileSizeM,
TileSizeN,
K,
ATile, BTile, temp_tile,
TileSizeN * sizeof(float), sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
);
}

// Final output tile pointer
float* dst_tile = reinterpret_cast<float*>(CTile);
// Final output tile pointer
float* dst_tile = reinterpret_cast<float*>(CTile);

// quick copy path for the common case where no alpha/beta scaling or accumulation is needed,
// with bounds checks on the tile size to ensure the data fits in the destination memory block
@@ -350,8 +405,4 @@ ArmKleidiAI::MlasGemmBatch(
return;
});
return true;
}
else {
return false;
}
}
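For readers following the tile-rebalancing arithmetic in MlasGemmBatch, a self-contained sketch with hypothetical sizes (DivRoundup stands in for MlasDivRoundup; the shapes, step sizes, and thread count are illustrative only):

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdio>

// Stand-in for MlasDivRoundup.
static size_t DivRoundup(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
    // Hypothetical workload: M = N = 256, base kernel steps of 64, one batch,
    // and 16 threads available.
    size_t M = 256, N = 256, m_step = 64, n_step = 64;
    size_t BatchSize = 1, threads = 16;

    // Natural tile grid: {batch, M tiles, N tiles} = {1, 4, 4} -> 16 tiles.
    std::array<size_t, 3> dim{BatchSize, DivRoundup(M, m_step), DivRoundup(N, n_step)};

    // Minimize the kernel call count for the available threads.
    size_t RequiredTiles = std::min(threads, dim[0] * dim[1] * dim[2]);

    // Rebalance the tile grid (note dim[1] is updated before dim[2] is
    // computed, matching the statement order in MlasGemmBatch).
    dim[1] = DivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
    dim[2] = DivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);

    // Grow the step sizes to cover the rebalanced grid, then recompute the
    // iteration counts.
    m_step *= DivRoundup(DivRoundup(M, dim[1]), m_step);
    n_step *= DivRoundup(DivRoundup(N, dim[2]), n_step);
    dim[1] = DivRoundup(M, m_step);
    dim[2] = DivRoundup(N, n_step);

    // Prints: m_step=64 n_step=64 tiles=16 (one 64x64 tile per thread).
    std::printf("m_step=%zu n_step=%zu tiles=%zu\n",
                m_step, n_step, dim[0] * dim[1] * dim[2]);
    return 0;
}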