4 changes: 2 additions & 2 deletions .github/workflows/android.yml
@@ -78,8 +78,8 @@ jobs:
  run: |
    set -e -x
    BINARY_SIZE_THRESHOLD_ARGS=""
-   echo "Binary size threshold in bytes: 1306224"
-   BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1306224"
+   echo "Binary size threshold in bytes: 1436672"
+   BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1436672"

# Ensure ANDROID_NDK_HOME is available and get its real path
if [ -z "$ANDROID_NDK_HOME" ]; then
16 changes: 6 additions & 10 deletions include/onnxruntime/core/framework/ortmemoryinfo.h
@@ -13,18 +13,14 @@
OrtMemoryInfo() = default; // to allow default construction of Tensor

// use string for name, so we could have customized allocator in execution provider.
- const char* name = nullptr;
+ std::string name;
OrtMemType mem_type = OrtMemTypeDefault;
OrtAllocatorType alloc_type = OrtInvalidAllocator;
OrtDevice device;

- constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(),
-                         OrtMemType mem_type_ = OrtMemTypeDefault)
- #if ((defined(__GNUC__) && __GNUC__ > 4) || defined(__clang__))
-     // this causes a spurious error in CentOS gcc 4.8 build so disable if GCC version < 5
-     __attribute__((nonnull))
- #endif
-     : name(name_),
+ OrtMemoryInfo(std::string name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(),
+               OrtMemType mem_type_ = OrtMemTypeDefault)
+     : name(std::move(name_)),

Check warning on line 23 in include/onnxruntime/core/framework/ortmemoryinfo.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/framework/ortmemoryinfo.h:23: Add #include <utility> for move [build/include_what_you_use] [4]
mem_type(mem_type_),
alloc_type(type_),
device(device_) {
@@ -39,7 +35,7 @@
if (device != other.device)
return device < other.device;

- return strcmp(name, other.name) < 0;
+ return name < other.name;
}

// This is to make OrtMemoryInfo a valid key in hash tables
@@ -68,7 +64,7 @@
return left.mem_type == other.mem_type &&
left.alloc_type == other.alloc_type &&
left.device == other.device &&
-        strcmp(left.name, other.name) == 0;
+        left.name == other.name;
}

inline bool operator!=(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) { return !(lhs == rhs); }
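Worth noting for reviewers: `std::string`'s `operator<` and `operator==` compare byte-wise (via `char_traits<char>`, which compares characters as `unsigned char`), so the ordering and equality above should match the previous `strcmp`-based behavior. A minimal standalone sketch of that equivalence, not part of this PR:

#include <cassert>
#include <cstring>
#include <string>

int main() {
  const std::string a = "Cpu";
  const std::string b = "CudaPinned";

  // Equality: a == b iff strcmp(a.c_str(), b.c_str()) == 0
  assert((a == b) == (std::strcmp(a.c_str(), b.c_str()) == 0));

  // Ordering: a < b iff strcmp(a.c_str(), b.c_str()) < 0
  // (both compare byte values as unsigned char)
  assert((a < b) == (std::strcmp(a.c_str(), b.c_str()) < 0));

  return 0;
}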
43 changes: 33 additions & 10 deletions onnxruntime/core/framework/allocator.cc
@@ -6,6 +6,7 @@
#include "core/common/safeint.h"
#include "core/common/status.h"
#include "core/framework/allocator.h"
#include "core/framework/error_code_helper.h"
#include "core/mlas/inc/mlas.h"
#include "core/framework/utils.h"
#include "core/session/ort_apis.h"
@@ -185,22 +186,32 @@ std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) { return
#endif
ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1,
enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) {
+ API_IMPL_BEGIN
+
+ if (name1 == nullptr) {
+   return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "MemoryInfo name cannot be null.");
+ }
+
+ if (out == nullptr) {
+   return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output memory info cannot be null.");
+ }
+
auto device_id = static_cast<OrtDevice::DeviceId>(id1);
if (strcmp(name1, onnxruntime::CPU) == 0) {
*out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA) == 0) {
*out = new OrtMemoryInfo(
-     name1, type,
+     onnxruntime::CUDA, type,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device_id),
mem_type1);
} else if (strcmp(name1, onnxruntime::OpenVINO_GPU) == 0) {
*out = new OrtMemoryInfo(
-     name1, type,
+     onnxruntime::OpenVINO_GPU, type,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, device_id),
mem_type1);
} else if (strcmp(name1, onnxruntime::HIP) == 0) {
*out = new OrtMemoryInfo(
-     name1, type,
+     onnxruntime::HIP, type,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device_id),
mem_type1);
} else if (strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 ||
@@ -212,45 +223,56 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA

} else if (strcmp(name1, onnxruntime::DML) == 0) {
*out = new OrtMemoryInfo(
-     name1, type,
+     onnxruntime::DML, type,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device_id),
mem_type1);
} else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) {
*out = new OrtMemoryInfo(
-     name1, type,
+     onnxruntime::OpenVINO_RT_NPU, type,
OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, device_id),
mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) {
*out = new OrtMemoryInfo(
-     name1, type,
+     onnxruntime::CUDA_PINNED, type,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, device_id),
mem_type1);
} else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) {
*out = new OrtMemoryInfo(
-     name1, type,
+     onnxruntime::HIP_PINNED, type,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, device_id),
mem_type1);
} else if (strcmp(name1, onnxruntime::QNN_HTP_SHARED) == 0) {
*out = new OrtMemoryInfo(
-     name1, type,
+     onnxruntime::QNN_HTP_SHARED, type,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::QUALCOMM, device_id),
mem_type1);
} else if (strcmp(name1, onnxruntime::CPU_ALIGNED_4K) == 0) {
*out = new OrtMemoryInfo(
-     name1, type,
+     onnxruntime::CPU_ALIGNED_4K, type,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, device_id,
onnxruntime::kAlloc4KAlignment),
mem_type1);
} else {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported. Try CreateMemoryInfo_V2.");
}
+ API_IMPL_END
return nullptr;
}

ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type,
_In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type,
_In_ size_t alignment, enum OrtAllocatorType type,
_Outptr_ OrtMemoryInfo** out) {
+ API_IMPL_BEGIN
+
+ if (name == nullptr) {
+   return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "MemoryInfo name cannot be null.");
+ }
+
+ if (out == nullptr) {
+   return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output memory info cannot be null.");
+ }
+
// map the public enum values to internal OrtDevice values
OrtDevice::MemoryType mt = mem_type == OrtDeviceMemoryType_DEFAULT ? OrtDevice::MemType::DEFAULT
: OrtDevice::MemType::HOST_ACCESSIBLE;
@@ -275,6 +297,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ en

*out = new OrtMemoryInfo(name, type, OrtDevice{dt, mt, vendor_id, narrow<int16_t>(device_id), alignment},
mem_type == OrtDeviceMemoryType_DEFAULT ? OrtMemTypeDefault : OrtMemTypeCPU);
+ API_IMPL_END
return nullptr;
}

@@ -283,7 +306,7 @@ ORT_API(void, OrtApis::ReleaseMemoryInfo, _Frees_ptr_opt_ OrtMemoryInfo* p) { de
#pragma warning(pop)
#endif
ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out) {
-  *out = ptr->name;
+  *out = ptr->name.c_str();
return nullptr;
}
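A hedged usage sketch of the public C API path touched above (assumptions: the standard `OrtGetApiBase()` entry point, and "Cpu" as the built-in CPU allocator name). With this change a null name or output pointer yields ORT_INVALID_ARGUMENT instead of undefined behavior, and `MemoryInfoGetName` is now backed by `std::string::c_str()`:

#include <cstdio>
#include "onnxruntime_c_api.h"

int main() {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtMemoryInfo* info = nullptr;
  OrtStatus* status = api->CreateMemoryInfo("Cpu", OrtArenaAllocator,
                                            /*id*/ 0, OrtMemTypeDefault, &info);
  if (status != nullptr) {  // e.g. null name or an unrecognized device name
    std::fprintf(stderr, "%s\n", api->GetErrorMessage(status));
    api->ReleaseStatus(status);
    return 1;
  }

  const char* name = nullptr;
  api->MemoryInfoGetName(info, &name);  // points into the OrtMemoryInfo's std::string
  std::printf("memory info name: %s\n", name);

  api->ReleaseMemoryInfo(info);
  return 0;
}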

2 changes: 1 addition & 1 deletion onnxruntime/core/framework/bfc_arena.cc
@@ -13,7 +13,7 @@ BFCArena::BFCArena(std::unique_ptr<IAllocator> resource_allocator,
int max_dead_bytes_per_chunk,
int initial_growth_chunk_size_bytes,
int64_t max_power_of_two_extend_bytes)
-    : IAllocator(OrtMemoryInfo(resource_allocator->Info().name,
+    : IAllocator(OrtMemoryInfo(resource_allocator->Info().name.c_str(),
OrtAllocatorType::OrtArenaAllocator,
resource_allocator->Info().device,
resource_allocator->Info().mem_type)),
17 changes: 13 additions & 4 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
@@ -53,16 +53,25 @@ struct PackedQuantBDataStruct {
{
const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T);
- if constexpr (BlkBitWidth == 8) {
-   PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32);
- } else {
#if defined(MLAS_TARGET_AMD64_IX86)
    // avx512 requires alignment on a 64-byte boundary
    PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64);
+ #elif defined (MLAS_TARGET_ARM64)
+   // Only for 8-bit Gemms is `PackedQuantBData` required to be 32-byte aligned,
+   // and there is enough memory allocated to support this alignment.
+   // See QNBitGemmPackQuantBDataSize().
+   // When the bit width is 4, there is no alignment guarantee.
+   // TODO(hasesh): Can we unify the alignment for 4-bit and 8-bit ARM64 Gemms so as to
+   // simplify this logic and make the code here cleaner?
+   if constexpr (BlkBitWidth == 8) {
+     PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32);
+   }
+   else {
+     PackedQuantBData = (std::byte*)PackedQuantBWorkspace;
+   }
#else
    PackedQuantBData = (std::byte*)PackedQuantBWorkspace;
#endif
- }

QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize);
QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment());
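For context, `MlasAlignAddress` rounds a pointer up to a power-of-two boundary, judging by how it is used here. The sketch below is a generic stand-in for that operation, not the MLAS implementation:

#include <cassert>
#include <cstddef>
#include <cstdint>

// Round `addr` up to the next multiple of `alignment` (a power of two).
void* AlignAddress(void* addr, std::size_t alignment) {
  auto value = reinterpret_cast<std::uintptr_t>(addr);
  value = (value + alignment - 1) & ~(static_cast<std::uintptr_t>(alignment) - 1);
  return reinterpret_cast<void*>(value);
}

int main() {
  alignas(64) unsigned char buffer[128];
  void* misaligned = buffer + 3;               // deliberately off-alignment
  void* aligned = AlignAddress(misaligned, 32);
  assert(reinterpret_cast<std::uintptr_t>(aligned) % 32 == 0);
  assert(aligned >= misaligned);               // rounds up, never back
  return 0;
}

This also illustrates why the packing-size computation must reserve extra bytes: rounding up can advance the pointer by as much as alignment - 1 bytes.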
@@ -927,7 +927,7 @@ namespace Dml

bool IsGpuTensor(const onnxruntime::Tensor& tensor)
{
-   return strcmp(tensor.Location().name, onnxruntime::CPU) &&
+   return strcmp(tensor.Location().name.c_str(), onnxruntime::CPU) &&
!(tensor.Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput || tensor.Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput);
}

@@ -98,7 +98,7 @@ namespace Windows::AI::MachineLearning::Adapter

bool IsAllocationInterface(const ::OrtMemoryInfo& info)
{
-   return strcmp(info.name, onnxruntime::CPU) && !(info.mem_type == ::OrtMemType::OrtMemTypeCPUOutput || info.mem_type == ::OrtMemType::OrtMemTypeCPUInput);
+   return strcmp(info.name.c_str(), onnxruntime::CPU) && !(info.mem_type == ::OrtMemType::OrtMemTypeCPUOutput || info.mem_type == ::OrtMemType::OrtMemTypeCPUInput);
}

// Translate the data object stored in a tensor to the type which will be returned through
@@ -1774,7 +1774,9 @@ namespace Windows::AI::MachineLearning::Adapter
}

// tells caller whether this tensor is in CPU memory
-   return !strcmp(m_impl->Location().name, onnxruntime::CPU) || m_impl->Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput || m_impl->Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput;
+   return !strcmp(m_impl->Location().name.c_str(), onnxruntime::CPU)
+       || m_impl->Location().mem_type == ::OrtMemType::OrtMemTypeCPUOutput
+       || m_impl->Location().mem_type == ::OrtMemType::OrtMemTypeCPUInput;
}

bool STDMETHODCALLTYPE TensorWrapper::IsDataInterface() const noexcept
8 changes: 5 additions & 3 deletions onnxruntime/core/providers/qnn/qnn_provider_factory.cc
@@ -219,9 +219,11 @@ struct QnnEpFactory : OrtEpFactory {
OrtKeyValuePairs* ep_options = nullptr;
factory->ort_api.CreateKeyValuePairs(&ep_options);
factory->ort_api.AddKeyValuePair(ep_options, "backend_path", factory->qnn_backend_path.c_str());
- ORT_API_RETURN_IF_ERROR(
-     factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options,
-                                                 &ep_devices[num_ep_devices++]));
+ OrtStatus* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options,
+                                                                 &ep_devices[num_ep_devices++]);
+
+ factory->ort_api.ReleaseKeyValuePairs(ep_options);
+ ORT_API_RETURN_IF_ERROR(status);
}
}
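The reordering above fixes a leak: the early-return macro previously skipped `ReleaseKeyValuePairs` whenever `CreateEpDevice` failed. A standalone sketch of the capture-then-clean-up pattern (illustrative names, not the real EP API):

#include <cstdio>

struct Status { bool ok; };
struct Options { bool released = false; };

Status CreateDevice(Options&) { return {false}; }   // simulate a failure
void Release(Options& opts) { opts.released = true; }

Status AddDevice(Options& opts) {
  Status status = CreateDevice(opts);  // capture, but do not return yet
  Release(opts);                       // cleanup runs on success and failure alike
  return status;                       // propagate the error only after cleanup
}

int main() {
  Options opts;
  Status s = AddDevice(opts);
  std::printf("released=%d ok=%d\n", opts.released, s.ok);  // released=1 ok=0
  return 0;
}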

4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -210,7 +210,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
return tensor != nullptr &&
tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault &&
tensor->Location().device.Type() == OrtDevice::GPU &&
-        !strcmp(tensor->Location().name, WEBGPU_BUFFER);
+        !strcmp(tensor->Location().name.c_str(), WEBGPU_BUFFER);
}),
"All inputs must be tensors on WebGPU buffers.");

@@ -219,7 +219,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
return tensor != nullptr &&
tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault &&
tensor->Location().device.Type() == OrtDevice::GPU &&
-        !strcmp(tensor->Location().name, WEBGPU_BUFFER);
+        !strcmp(tensor->Location().name.c_str(), WEBGPU_BUFFER);
}),
"All outputs must be tensors on WebGPU buffers.");
}
2 changes: 1 addition & 1 deletion onnxruntime/core/session/environment.cc
@@ -79,7 +79,7 @@ static bool AreOrtMemoryInfosEquivalent(
bool ignore_alignment = false) {
return left.mem_type == right.mem_type &&
(ignore_alignment ? left.device.EqualIgnoringAlignment(right.device) : left.device == right.device) &&
-        (!match_name || strcmp(left.name, right.name) == 0);
+        (!match_name || left.name == right.name);
}

std::vector<AllocatorPtr>::const_iterator FindExistingAllocator(const std::vector<AllocatorPtr>& allocators,
4 changes: 2 additions & 2 deletions onnxruntime/core/session/lora_adapters.cc
@@ -53,11 +53,11 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) {
static std::unique_ptr<IDataTransfer> GetDataTransfer(const OrtMemoryInfo& mem_info) {
std::unique_ptr<IDataTransfer> data_transfer;

- if (strcmp(mem_info.name, onnxruntime::CPU) == 0) {
+ if (mem_info.name == onnxruntime::CPU) {
return data_transfer;
}

- if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) {
+ if (mem_info.name == onnxruntime::CUDA) {
#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)
auto* cuda_provider_info = TryGetProviderInfo_CUDA();
if (cuda_provider_info != nullptr) {
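The `strcmp(...) == 0` to `==` rewrite in this hunk works because `std::string` provides a mixed `operator==` overload against `const char*` that compares contents, not pointer identity. A minimal sketch (assuming `onnxruntime::CPU` is the usual `const char*` constant):

#include <cassert>
#include <string>

int main() {
  const char* kCpu = "Cpu";       // stand-in for onnxruntime::CPU
  std::string name = "Cpu";
  assert(name == kCpu);           // content comparison, not pointer comparison
  assert(std::string("Cuda") != kCpu);
  return 0;
}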
2 changes: 1 addition & 1 deletion onnxruntime/test/framework/TestAllocatorManager.cc
@@ -10,7 +10,7 @@ namespace test {
class DummyArena : public IAllocator {
public:
explicit DummyArena(std::unique_ptr<IAllocator> resource_allocator)
-    : IAllocator(OrtMemoryInfo(resource_allocator->Info().name,
+    : IAllocator(OrtMemoryInfo(resource_allocator->Info().name.c_str(),
OrtAllocatorType::OrtDeviceAllocator,
resource_allocator->Info().device,
resource_allocator->Info().mem_type)),
2 changes: 1 addition & 1 deletion onnxruntime/test/framework/allocator_test.cc
@@ -13,7 +13,7 @@ namespace test {
TEST(AllocatorTest, CPUAllocatorTest) {
auto cpu_arena = TestCPUExecutionProvider()->CreatePreferredAllocators()[0];

- ASSERT_STREQ(cpu_arena->Info().name, CPU);
+ ASSERT_STREQ(cpu_arena->Info().name.c_str(), CPU);
EXPECT_EQ(cpu_arena->Info().device.Id(), 0);

const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage()
8 changes: 4 additions & 4 deletions onnxruntime/test/framework/tensor_test.cc
@@ -29,7 +29,7 @@ void CPUTensorTest(std::vector<int64_t> dims, const int offset_elements = 0) {
EXPECT_EQ(shape.GetDims(), tensor_shape.GetDims());
EXPECT_EQ(t.DataType(), DataTypeImpl::GetType<T>());
auto& location = t.Location();
- EXPECT_STREQ(location.name, CPU);
+ EXPECT_STREQ(location.name.c_str(), CPU);
EXPECT_EQ(location.device.Id(), 0);

const T* t_data = t.Data<T>();
Expand All @@ -47,7 +47,7 @@ void CPUTensorTest(std::vector<int64_t> dims, const int offset_elements = 0) {
EXPECT_EQ(shape.GetDims(), tensor_shape.GetDims());
EXPECT_EQ(new_t.DataType(), DataTypeImpl::GetType<T>());
auto& new_location = new_t.Location();
- ASSERT_STREQ(new_location.name, CPU);
+ ASSERT_STREQ(new_location.name.c_str(), CPU);
EXPECT_EQ(new_location.device.Id(), 0);
}
}
@@ -135,7 +135,7 @@ TEST(TensorTest, EmptyTensorTest) {
EXPECT_TRUE(!data);

auto& location = t.Location();
- ASSERT_STREQ(location.name, CPU);
+ ASSERT_STREQ(location.name.c_str(), CPU);
EXPECT_EQ(location.device.Id(), 0);

const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage()
@@ -160,7 +160,7 @@ TEST(TensorTest, StringTensorTest) {
EXPECT_EQ(shape, tensor_shape);
EXPECT_EQ(t.DataType(), DataTypeImpl::GetType<std::string>());
auto& location = t.Location();
- ASSERT_STREQ(location.name, CPU);
+ ASSERT_EQ(location.name, CPU);
EXPECT_EQ(location.device.Id(), 0);

std::string* new_data = t.MutableData<std::string>();
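One test-side subtlety in this hunk: `ASSERT_STREQ` requires C strings, so most call sites switch to `.c_str()`, while the one line rewritten as `ASSERT_EQ(location.name, CPU)` also works because GoogleTest's `ASSERT_EQ`/`EXPECT_EQ` fall back to `operator==`, and `std::string == const char*` compares contents. A minimal sketch (assumes GoogleTest is available):

#include <string>
#include <gtest/gtest.h>

TEST(StringCompare, EqVsStreq) {
  std::string name = "Cpu";
  EXPECT_STREQ(name.c_str(), "Cpu");  // STREQ needs const char* operands
  EXPECT_EQ(name, "Cpu");             // EQ uses std::string::operator==
}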
2 changes: 1 addition & 1 deletion onnxruntime/test/lora/lora_test.cc
@@ -216,7 +216,7 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) {
for (; begin != end; ++begin) {
const auto& [_, param] = *begin;
const auto& tensor_device = param.GetDeviceOrMapped().Get<Tensor>();
- ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::CUDA));
+ ASSERT_EQ(0, strcmp(tensor_device.Location().name.c_str(), onnxruntime::CUDA));

const auto& tensor_cpu = param.GetMapped().Get<Tensor>();
ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size());
5 changes: 4 additions & 1 deletion onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp
@@ -773,7 +773,8 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase {
N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer,
nullptr, HasZp, inputZp, nullptr);

- PackedQuantBDataStruct<float, 8> packedQuantB(packedBuffer, N, BlkCount, BlkLen, true);
+ const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned;
+ PackedQuantBDataStruct<float, 8> packedQuantB(packedBuffer, N, BlkCount, BlkLen, isQuantAUnsigned);

auto* C = C_.GetBuffer(M * ldc, true);
auto* ref = ref_.GetBuffer(M * ldc, true);
@@ -825,7 +826,9 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase {

void ExecuteShort(void) override {
  Execute<1, 16, 1, 16>();
+ Execute<1, 1, 1, 16>();
  Execute<7, 2, 4, 16>();
+ Execute<7, 128, 4, 16>();
Execute<8, 497, 5, 16>();
Execute<1, 3072, 128, 16>();
Execute<2, 3072, 128, 16>();