From b4d090226b28ff93b64499c68430f77acf887195 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 11 Aug 2025 19:06:47 -0700 Subject: [PATCH 01/12] CXX EP related API beings --- .../core/graph/indexed_sub_graph.h | 2 +- .../core/session/onnxruntime_c_api.h | 2 +- .../core/session/onnxruntime_cxx_api.h | 233 ++++++++++-------- .../core/session/onnxruntime_cxx_inline.h | 161 ++++++++++-- .../core/session/abi_key_value_pairs.h | 4 +- onnxruntime/core/session/onnxruntime_c_api.cc | 15 +- onnxruntime/core/session/ort_apis.h | 4 +- onnxruntime/test/autoep/test_allocators.cc | 54 ++-- onnxruntime/test/autoep/test_data_transfer.cc | 28 +-- onnxruntime/test/autoep/test_execution.cc | 12 +- onnxruntime/test/providers/cpu/model_tests.cc | 30 +-- onnxruntime/test/shared_lib/test_allocator.cc | 10 +- onnxruntime/test/shared_lib/test_data_copy.cc | 64 ++--- onnxruntime/test/shared_lib/test_inference.cc | 61 ++--- onnxruntime/test/util/include/api_asserts.h | 43 ++-- 15 files changed, 404 insertions(+), 319 deletions(-) diff --git a/include/onnxruntime/core/graph/indexed_sub_graph.h b/include/onnxruntime/core/graph/indexed_sub_graph.h index 088db79a7e005..8ef4fdb66e1e6 100644 --- a/include/onnxruntime/core/graph/indexed_sub_graph.h +++ b/include/onnxruntime/core/graph/indexed_sub_graph.h @@ -31,7 +31,7 @@ struct IndexedSubGraph { std::string domain; ///< Domain of customized SubGraph/FunctionProto int since_version; ///< Since version of customized SubGraph/FunctionProto. - ONNX_NAMESPACE::OperatorStatus status; ///< Status of customized SubGraph/FunctionProto. + ONNX_NAMESPACE::OperatorStatus status{ONNX_NAMESPACE::OperatorStatus::STABLE}; ///< Status of customized SubGraph/FunctionProto. std::vector inputs; ///< Inputs of customized SubGraph/FunctionProto. std::vector outputs; ///< Outputs of customized SubGraph/FunctionProto. diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6eb15280a4aa4..3974dadc985f6 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5829,7 +5829,7 @@ struct OrtApi { * * \since Version 1.23. */ - ORT_API2_STATUS(Graph_GetNodes, const OrtGraph* graph, + ORT_API2_STATUS(Graph_GetNodes, _In_ const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); /** \brief Get the parent node for the given graph, if any exists. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d1b08f127fa2a..cee2300daae54 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -566,6 +566,7 @@ ORT_DEFINE_RELEASE(ModelMetadata); ORT_DEFINE_RELEASE(IoBinding); ORT_DEFINE_RELEASE(ArenaCfg); ORT_DEFINE_RELEASE(Status); +ORT_DEFINE_RELEASE(SyncStream); ORT_DEFINE_RELEASE(OpAttr); ORT_DEFINE_RELEASE(Op); ORT_DEFINE_RELEASE(KernelInfo); @@ -697,6 +698,9 @@ struct Model; struct Node; struct ModelMetadata; struct TypeInfo; +struct Session; +struct SessionOptions; +struct SyncStream; struct Value; struct ValueInfo; @@ -715,7 +719,7 @@ struct Status : detail::Base { using Base::Base; explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used - explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. 
+ Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message. @@ -793,6 +797,111 @@ struct KeyValuePairs : detail::KeyValuePairsImpl { ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; } }; +namespace detail { +template +struct MemoryInfoImpl : Base { + using B = Base; + using B::B; + + std::string GetAllocatorName() const; ///< Wrapper MemoryInfoGetName + OrtAllocatorType GetAllocatorType() const; ///< Wrapper MemoryInfoGetType + int GetDeviceId() const; ///< Wrapper MemoryInfoGetId + OrtMemoryInfoDeviceType GetDeviceType() const; ///< Wrapper MemoryInfoGetDeviceType + OrtMemType GetMemoryType() const; ///< Wrapper MemoryInfoGetMemType + OrtDeviceMemoryType GetDeviceMemoryType() const; ///< Wrapper MemoryInfoGetDeviceMemType + uint32_t GetVendorId() const; ///< Wrapper MemoryInfoGetVendorId + + template + bool operator==(const MemoryInfoImpl& o) const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstMemoryInfo = detail::MemoryInfoImpl>; + +/** \brief Wrapper around ::OrtMemoryInfo + * + */ +struct MemoryInfo : detail::MemoryInfoImpl { + static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); + explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created + explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API + MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); + MemoryInfo(const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id, + OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type); ///< Wrapper around CreateMemoryInfo_V2 + ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } +}; + +/// +/// Represents native memory allocation coming from one of the +/// OrtAllocators registered with OnnxRuntime. +/// Use it to wrap an allocation made by an allocator +/// so it can be automatically released when no longer needed. +/// +struct MemoryAllocation { + MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); + ~MemoryAllocation(); + MemoryAllocation(const MemoryAllocation&) = delete; + MemoryAllocation& operator=(const MemoryAllocation&) = delete; + MemoryAllocation(MemoryAllocation&&) noexcept; + MemoryAllocation& operator=(MemoryAllocation&&) noexcept; + + void* get() { return p_; } + size_t size() const { return size_; } + + private: + OrtAllocator* allocator_; + void* p_; + size_t size_; +}; + +namespace detail { +template +struct AllocatorImpl : Base { + using B = Base; + using B::B; + + void* Alloc(size_t size); + MemoryAllocation GetAllocation(size_t size); + void Free(void* p); + ConstMemoryInfo GetInfo() const; + + /** \brief Function that returns the statistics of the allocator. + * + * \return A pointer to a KeyValuePairs object that will be filled with the allocator statistics. 
+ */ + KeyValuePairs GetStats() const; +}; +} // namespace detail + +/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime + * + */ +struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { + explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + AllocatorWithDefaultOptions(); +}; + +/** \brief Wrapper around ::OrtAllocator + * + */ + +struct Allocator : detail::AllocatorImpl { + explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + Allocator(const Session& session, const OrtMemoryInfo*); +}; + +using UnownedAllocator = detail::AllocatorImpl>; + +/** \brief Wrapper around ::OrtSyncStream + * + */ +struct SyncStream : detail::Base { + explicit SyncStream(std::nullptr_t) {} ///< Create an empty SyncStream object, must be assigned a valid one to be used + explicit SyncStream(OrtSyncStream* p) : Base{p} {} ///< Take ownership of a pointer created by C API + void* GetHandle() const; ///< Wraps SyncStream_GetHandle +}; + namespace detail { template struct HardwareDeviceImpl : Ort::detail::Base { @@ -823,6 +932,8 @@ struct EpDeviceImpl : Ort::detail::Base { ConstKeyValuePairs EpMetadata() const; ConstKeyValuePairs EpOptions() const; ConstHardwareDevice Device() const; + ConstMemoryInfo GetMemoryInfo(OrtDeviceMemoryType memory_type) const; ///< Wraps EpDevice_MemoryInfo + SyncStream CreateSyncStream(ConstKeyValuePairs stream_options = {}) const; /// Wraps EpDevice_CreateSyncStream }; } // namespace detail @@ -877,10 +988,28 @@ struct Env : detail::Base { const std::unordered_map& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 + Env& RegisterAllocator(OrtAllocator* allocator); ///< Wraps OrtApi::RegisterAllocator + + Env& UnregisterAllocator(const OrtMemoryInfo* mem_info); ///< Wraps OrtApi::UnregisterAllocator + + UnownedAllocator CreateSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type, + OrtAllocatorType allocator_type, + const OrtKeyValuePairs* allocator_options); ///< Wraps OrtApi::CreateSharedAllocator + + // Result may be nullptr + UnownedAllocator GetSharedAllocator(const OrtMemoryInfo* mem_info); ///< Wraps OrtApi::GetSharedAllocator + + void ReleaseSharedAllocator(const OrtEpDevice* ep_device, + OrtDeviceMemoryType mem_type); ///< Wraps OrtApi::ReleaseSharedAllocator + Env& RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string& path); ///< Wraps OrtApi::RegisterExecutionProviderLibrary Env& UnregisterExecutionProviderLibrary(const char* registration_name); ///< Wraps OrtApi::UnregisterExecutionProviderLibrary std::vector GetEpDevices() const; + + Status CopyTensors(const std::vector& src_tensors, + const std::vector& dst_tensors, + OrtSyncStream* stream) const; ///< Wraps OrtApi::CopyTensors }; /** \brief Custom Op Domain @@ -1018,8 +1147,6 @@ struct CustomOpConfigs { * Wraps ::OrtSessionOptions object and methods */ -struct SessionOptions; - namespace detail { // we separate const-only methods because passing const ptr to non-const methods // is only discovered when inline methods are compiled which is counter-intuitive @@ -1264,6 +1391,10 @@ struct ConstSessionImpl : Base { std::vector GetOutputNames() const; std::vector GetOverridableInitializerNames() const; + std::vector GetMemoryInfoForInputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForInputs + std::vector GetMemoryInfoForOutputs() const; ///< 
Wrapper for OrtApi::SessionGetMemoryInfoForOutputs + std::vector GetEpDeviceForInputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForInputs + /** \brief Returns a copy of input name at the specified index. * * \param index must less than the value returned by GetInputCount() @@ -1427,37 +1558,6 @@ struct Session : detail::SessionImpl { UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } }; -namespace detail { -template -struct MemoryInfoImpl : Base { - using B = Base; - using B::B; - - std::string GetAllocatorName() const; - OrtAllocatorType GetAllocatorType() const; - int GetDeviceId() const; - OrtMemoryInfoDeviceType GetDeviceType() const; - OrtMemType GetMemoryType() const; - - template - bool operator==(const MemoryInfoImpl& o) const; -}; -} // namespace detail - -// Const object holder that does not own the underlying object -using ConstMemoryInfo = detail::MemoryInfoImpl>; - -/** \brief Wrapper around ::OrtMemoryInfo - * - */ -struct MemoryInfo : detail::MemoryInfoImpl { - static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); - explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created - explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API - MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); - ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } -}; - namespace detail { template struct TensorTypeAndShapeInfoImpl : Base { @@ -1686,7 +1786,7 @@ struct ConstValueImpl : Base { /// /// const pointer to data, no copies made template - const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// + const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorData /// /// /// Returns a non-typed pointer to a tensor contained data. @@ -1956,7 +2056,7 @@ struct Value : detail::ValueImpl { using OrtSparseValuesParam = detail::OrtSparseValuesParam; using Shape = detail::Shape; - explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used + Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used Value(Value&&) = default; Value& operator=(Value&&) = default; @@ -2121,67 +2221,6 @@ struct Value : detail::ValueImpl { #endif // !defined(DISABLE_SPARSE_TENSORS) }; -/// -/// Represents native memory allocation coming from one of the -/// OrtAllocators registered with OnnxRuntime. -/// Use it to wrap an allocation made by an allocator -/// so it can be automatically released when no longer needed. -/// -struct MemoryAllocation { - MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); - ~MemoryAllocation(); - MemoryAllocation(const MemoryAllocation&) = delete; - MemoryAllocation& operator=(const MemoryAllocation&) = delete; - MemoryAllocation(MemoryAllocation&&) noexcept; - MemoryAllocation& operator=(MemoryAllocation&&) noexcept; - - void* get() { return p_; } - size_t size() const { return size_; } - - private: - OrtAllocator* allocator_; - void* p_; - size_t size_; -}; - -namespace detail { -template -struct AllocatorImpl : Base { - using B = Base; - using B::B; - - void* Alloc(size_t size); - MemoryAllocation GetAllocation(size_t size); - void Free(void* p); - ConstMemoryInfo GetInfo() const; - - /** \brief Function that returns the statistics of the allocator. - * - * \return A pointer to a KeyValuePairs object that will be filled with the allocator statistics. 
- */ - KeyValuePairs GetStats() const; -}; - -} // namespace detail - -/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime - * - */ -struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { - explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance - AllocatorWithDefaultOptions(); -}; - -/** \brief Wrapper around ::OrtAllocator - * - */ -struct Allocator : detail::AllocatorImpl { - explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance - Allocator(const Session& session, const OrtMemoryInfo*); -}; - -using UnownedAllocator = detail::AllocatorImpl>; - namespace detail { namespace binding_utils { // Bring these out of template diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 705f17c5d6f43..201278c5c132b 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -13,6 +13,7 @@ #include #include #include +#include "onnxruntime_cxx_api.h" // Convert OrtStatus to Ort::Status and return // instead of throwing @@ -296,6 +297,16 @@ inline OrtMemType MemoryInfoImpl::GetMemoryType() const { return type; } +template +inline OrtDeviceMemoryType MemoryInfoImpl::GetDeviceMemoryType() const { + return GetApi().MemoryInfoGetDeviceMemType(this->p_); +} + +template +inline uint32_t MemoryInfoImpl::GetVendorId() const { + return GetApi().MemoryInfoGetVendorId(this->p_); +} + template template inline bool MemoryInfoImpl::operator==(const MemoryInfoImpl& o) const { @@ -316,6 +327,12 @@ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, O ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_)); } +inline MemoryInfo::MemoryInfo(const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id, + OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type) { + ThrowOnError(GetApi().CreateMemoryInfo_V2(name, device_type, vendor_id, device_id, mem_type, alignment, + allocator_type, &this->p_)); +} + namespace detail { template inline std::vector ConstIoBindingImpl::GetOutputNames() const { @@ -404,20 +421,7 @@ inline std::vector GetOutputNamesHelper(const OrtIoBinding* binding inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { std::vector result; - size_t owned = 0; size_t output_count = 0; - // Lambda to release the buffer when no longer needed and - // make sure that we destroy all instances on exception - auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) { - if (buffer) { - while (owned < output_count) { - auto* p = buffer + owned++; - GetApi().ReleaseValue(*p); - } - allocator->Free(allocator, buffer); - } - }; - using Ptr = std::unique_ptr; OrtValue** output_buffer = nullptr; ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count)); @@ -425,12 +429,11 @@ inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, Ort return result; } - Ptr buffer_g(output_buffer, free_fn); + std::unique_ptr buffer_g(output_buffer, AllocatedFree(allocator)); result.reserve(output_count); for (size_t i = 0; i < output_count; ++i) { result.emplace_back(output_buffer[i]); - ++owned; } return result; } @@ -547,6 +550,10 @@ inline void KeyValuePairs::Remove(const char* key) { 
GetApi().RemoveKeyValuePair(this->p_, key); } +inline void* SyncStream::GetHandle() const { + return GetApi().SyncStream_GetHandle(this->p_); +} + namespace detail { template inline OrtHardwareDeviceType HardwareDeviceImpl::Type() const { @@ -597,6 +604,19 @@ template inline ConstHardwareDevice EpDeviceImpl::Device() const { return ConstHardwareDevice(GetApi().EpDevice_Device(this->p_)); } + +template +inline ConstMemoryInfo EpDeviceImpl::GetMemoryInfo(OrtDeviceMemoryType memory_type) const { + const auto* mem_info = GetApi().EpDevice_MemoryInfo(this->p_, memory_type); + return ConstMemoryInfo{mem_info}; +} + +template +inline SyncStream EpDeviceImpl::CreateSyncStream(ConstKeyValuePairs stream_options) const { + OrtSyncStream* stream = nullptr; + ThrowOnError(GetApi().CreateSyncStreamForEpDevice(this->p_, stream_options, &stream)); + return SyncStream{stream}; +} } // namespace detail inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device, @@ -676,6 +696,16 @@ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, return *this; } +inline Env& Env::RegisterAllocator(OrtAllocator* allocator) { + ThrowOnError(GetApi().RegisterAllocator(p_, allocator)); + return *this; +} + +inline Env& Env::UnregisterAllocator(const OrtMemoryInfo* mem_info) { + ThrowOnError(GetApi().UnregisterAllocator(p_, mem_info)); + return *this; +} + inline Env& Env::RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string& path) { ThrowOnError(GetApi().RegisterExecutionProviderLibrary(p_, registration_name, path.c_str())); @@ -703,6 +733,41 @@ inline std::vector Env::GetEpDevices() const { return devices; } +inline Status Env::CopyTensors(const std::vector& src_tensors, + const std::vector& dst_tensors, + OrtSyncStream* stream) const { + if (src_tensors.size() != dst_tensors.size()) { + return Status("Source and destination tensor vectors must have the same size", ORT_INVALID_ARGUMENT); + } + if (src_tensors.empty()) { + return Status(); + } + + const OrtValue* const* src_tensors_ptr = reinterpret_cast(src_tensors.data()); + OrtValue* const* dst_tensors_ptr = reinterpret_cast(dst_tensors.data()); + OrtStatus* status = GetApi().CopyTensors(p_, src_tensors_ptr, dst_tensors_ptr, stream, src_tensors.size()); + return Status(status); +} + +inline UnownedAllocator Env::CreateSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type, + OrtAllocatorType allocator_type, + const OrtKeyValuePairs* allocator_options) { + OrtAllocator* p; + ThrowOnError(GetApi().CreateSharedAllocator(p_, ep_device, mem_type, allocator_type, allocator_options, &p)); + return UnownedAllocator{p}; +} + +inline UnownedAllocator Env::GetSharedAllocator(const OrtMemoryInfo* mem_info) { + OrtAllocator* p; + ThrowOnError(GetApi().GetSharedAllocator(p_, mem_info, &p)); + return UnownedAllocator{p}; +} + +inline void Env::ReleaseSharedAllocator(const OrtEpDevice* ep_device, + OrtDeviceMemoryType mem_type) { + ThrowOnError(GetApi().ReleaseSharedAllocator(p_, ep_device, mem_type)); +} + inline CustomOpDomain::CustomOpDomain(const char* domain) { ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); } @@ -1298,9 +1363,9 @@ inline std::vector ConstSessionImpl::GetInputNames() const { input_names.reserve(num_inputs); for (size_t i = 0; i < num_inputs; ++i) { - char* name = nullptr; + char* name; ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name)); - input_names.push_back(name); + input_names.emplace_back(name); allocator.Free(name); 
} @@ -1316,9 +1381,9 @@ inline std::vector ConstSessionImpl::GetOutputNames() const { output_names.reserve(num_inputs); for (size_t i = 0; i < num_inputs; ++i) { - char* name = nullptr; + char* name; ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name)); - output_names.push_back(name); + output_names.emplace_back(name); allocator.Free(name); } @@ -1334,14 +1399,44 @@ inline std::vector ConstSessionImpl::GetOverridableInitializerNa initializer_names.reserve(num_initializers); for (size_t i = 0; i < num_initializers; ++i) { - char* name = nullptr; + char* name; ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name)); - initializer_names.push_back(name); + initializer_names.emplace_back(name); } return initializer_names; } +template +inline std::vector ConstSessionImpl::GetMemoryInfoForInputs() const { + AllocatorWithDefaultOptions allocator; + + auto num_inputs = GetInputCount(); + std::vector mem_infos; + mem_infos.resize(num_inputs); + + ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, reinterpret_cast(&mem_infos[0]), + num_inputs)); + + return mem_infos; +} + +template +inline std::vector ConstSessionImpl::GetMemoryInfoForOutputs() const { + AllocatorWithDefaultOptions allocator; + auto num_outputs = GetOutputCount(); + std::vector mem_infos; + mem_infos.reserve(num_outputs); + + const OrtMemoryInfo* mem_info_ptrs; + ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, &mem_info_ptrs, num_outputs)); + for (size_t i = 0; i < num_outputs; ++i) { + mem_infos.emplace_back(mem_info_ptrs[i]); + } + + return mem_infos; +} + template inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { char* out; @@ -1363,6 +1458,22 @@ inline AllocatedStringPtr ConstSessionImpl::GetOverridableInitializerNameAllo return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); } +template +inline std::vector ConstSessionImpl::GetEpDeviceForInputs() const { + auto num_inputs = GetInputCount(); + std::vector input_devices; + input_devices.reserve(num_inputs); + + const OrtEpDevice* device_ptrs; + ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, &device_ptrs, num_inputs)); + + for (size_t i = 0; i < num_inputs; ++i) { + input_devices.emplace_back(device_ptrs[i]); + } + + return input_devices; +} + template inline uint64_t ConstSessionImpl::GetProfilingStartTimeNs() const { uint64_t out; @@ -1857,15 +1968,15 @@ inline size_t ConstValueImpl::GetTensorSizeInBytes() const { template template inline const R* ConstValueImpl::GetTensorData() const { - R* out; - ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), (void**)&out)); + const R* out; + ThrowOnError(GetApi().GetTensorData(this->p_, reinterpret_cast(&out))); return out; } template inline const void* ConstValueImpl::GetTensorRawData() const { - void* out; - ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), &out)); + const void* out; + ThrowOnError(GetApi().GetTensorData(this->p_, &out)); return out; } diff --git a/onnxruntime/core/session/abi_key_value_pairs.h b/onnxruntime/core/session/abi_key_value_pairs.h index 7d739439b7a27..72d5007b84e6f 100644 --- a/onnxruntime/core/session/abi_key_value_pairs.h +++ b/onnxruntime/core/session/abi_key_value_pairs.h @@ -18,11 +18,11 @@ struct OrtKeyValuePairs { CopyFromMap(other.entries_); } - OrtKeyValuePairs(OrtKeyValuePairs&& other) : OrtKeyValuePairs{} { + OrtKeyValuePairs(OrtKeyValuePairs&& other) noexcept : OrtKeyValuePairs{} { swap(*this, other); 
} - OrtKeyValuePairs& operator=(OrtKeyValuePairs other) { // handles copy and move assignment + OrtKeyValuePairs& operator=(OrtKeyValuePairs other) noexcept { // handles copy and move assignment swap(*this, other); return *this; } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 88d84e95b406c..bb6ca0c5d1d93 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1094,7 +1094,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorMutableData, _Inout_ OrtValue* value, _Out API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetTensorData, _Inout_ const OrtValue* value, _Outptr_ const void** output) { +ORT_API_STATUS_IMPL(OrtApis::GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** output) { TENSOR_READ_API_BEGIN *output = tensor.DataRaw(); return nullptr; @@ -2761,7 +2761,8 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetNodes, _In_ const OrtGraph* graph, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_ const OrtNode** node) { +ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, + _Outptr_result_maybenull_ const OrtNode** node) { API_IMPL_BEGIN if (node == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'node' argument is NULL"); @@ -3209,7 +3210,7 @@ ORT_API(void, OrtApis::GetKeyValuePairs, _In_ const OrtKeyValuePairs* kvps, *num_entries = kvps->Entries().size(); } -ORT_API(void, OrtApis::RemoveKeyValuePair, _Frees_ptr_opt_ OrtKeyValuePairs* kvps, _In_ const char* key) { +ORT_API(void, OrtApis::RemoveKeyValuePair, _In_ OrtKeyValuePairs* kvps, _In_ const char* key) { kvps->Remove(key); } @@ -3218,7 +3219,7 @@ ORT_API(void, OrtApis::ReleaseKeyValuePairs, _Frees_ptr_opt_ OrtKeyValuePairs* k } #if !defined(ORT_MINIMAL_BUILD) -ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name, +ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name, const ORTCHAR_T* path) { API_IMPL_BEGIN ORT_API_RETURN_IF_STATUS_NOT_OK(env->GetEnvironment().RegisterExecutionProviderLibrary(registration_name, path)); @@ -3226,7 +3227,7 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* env, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name) { +ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name) { API_IMPL_BEGIN ORT_API_RETURN_IF_STATUS_NOT_OK(env->GetEnvironment().UnregisterExecutionProviderLibrary(registration_name)); return nullptr; @@ -3413,7 +3414,7 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env, } #else // defined(ORT_MINIMAL_BUILD) -ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, const char* /*registration_name*/, +ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { API_IMPL_BEGIN return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); @@ -3421,7 +3422,7 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*en } ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, - const char* /*registration_name*/) { + _In_ const char* /*registration_name*/) { API_IMPL_BEGIN return 
OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); API_IMPL_END diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 3eee174ff81f4..de8a51a0aedb5 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -562,7 +562,7 @@ ORT_API(void, GetKeyValuePairs, _In_ const OrtKeyValuePairs* kvps, ORT_API(void, RemoveKeyValuePair, _In_ OrtKeyValuePairs* kvps, _In_ const char* key); ORT_API(void, ReleaseKeyValuePairs, _Frees_ptr_opt_ OrtKeyValuePairs*); -ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name, +ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name, const ORTCHAR_T* path); ORT_API_STATUS_IMPL(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name); @@ -652,7 +652,7 @@ ORT_API_STATUS_IMPL(Graph_GetInitializers, _In_ const OrtGraph* graph, _Out_writes_(num_initializers) const OrtValueInfo** initializers, _In_ size_t num_initializers); ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* num_nodes); -ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, +ORT_API_STATUS_IMPL(Graph_GetNodes, _In_ const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, diff --git a/onnxruntime/test/autoep/test_allocators.cc b/onnxruntime/test/autoep/test_allocators.cc index 77d2bb24b7d35..88b522eb10dca 100644 --- a/onnxruntime/test/autoep/test_allocators.cc +++ b/onnxruntime/test/autoep/test_allocators.cc @@ -60,66 +60,58 @@ struct DummyAllocator : OrtAllocator { // validate CreateSharedAllocator allows adding an arena to the shared allocator TEST(SharedAllocators, AddArenaToSharedAllocator) { - const OrtApi& c_api = Ort::GetApi(); RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const auto* ep_memory_info = c_api.EpDevice_MemoryInfo(example_ep.get(), OrtDeviceMemoryType_DEFAULT); + Ort::ConstEpDevice example_ep_device{example_ep.get()}; + + auto ep_memory_info = example_ep_device.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); // validate there is a shared allocator - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, ep_memory_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(ep_memory_info); ASSERT_NE(allocator, nullptr); // call CreateSharedAllocator to replace with arena based allocator. arena is configured with kvps - OrtKeyValuePairs allocator_options; + Ort::KeyValuePairs allocator_options; auto initial_chunk_size = "25600"; // arena allocates in 256 byte amounts allocator_options.Add(OrtArenaCfg::ConfigKeyNames::InitialChunkSizeBytes, initial_chunk_size); - ASSERT_ORTSTATUS_OK(c_api.CreateSharedAllocator(*ort_env, example_ep.get(), OrtDeviceMemoryType_DEFAULT, - // allocator is internally added by EP. - // OrtArenaAllocator can only be used for the internal BFCArena - OrtDeviceAllocator, - &allocator_options, &allocator)); + allocator = ort_env->CreateSharedAllocator(example_ep.get(), OrtDeviceMemoryType_DEFAULT, + // allocator is internally added by EP. 
+ // OrtArenaAllocator can only be used for the internal BFCArena + OrtDeviceAllocator, + allocator_options); // first allocation should init the arena to the initial chunk size - void* mem = allocator->Alloc(allocator, 16); - allocator->Free(allocator, mem); + void* mem = allocator.Alloc(16); + allocator.Free(mem); // stats should prove the arena was used - OrtKeyValuePairs* allocator_stats = nullptr; - ASSERT_ORTSTATUS_OK(allocator->GetStats(allocator, &allocator_stats)); + auto allocator_stats = allocator.GetStats(); using ::testing::Contains; using ::testing::Pair; - const auto& stats = allocator_stats->Entries(); + const auto& stats = static_cast(allocator_stats)->Entries(); EXPECT_THAT(stats, Contains(Pair("NumAllocs", "1"))); EXPECT_THAT(stats, Contains(Pair("NumArenaExtensions", "1"))); EXPECT_THAT(stats, Contains(Pair("TotalAllocated", initial_chunk_size))); // optional. ORT owns the allocator but we want to test the release implementation - ASSERT_ORTSTATUS_OK(c_api.ReleaseSharedAllocator(*ort_env, example_ep.get(), OrtDeviceMemoryType_DEFAULT)); + ort_env->ReleaseSharedAllocator(example_ep.get(), OrtDeviceMemoryType_DEFAULT); } TEST(SharedAllocators, GetSharedAllocator) { - const OrtApi& c_api = Ort::GetApi(); - // default CPU allocator should be available. // create a memory info with a different name to validate the shared allocator lookup ignores the name - OrtMemoryInfo* test_cpu_memory_info = nullptr; - ASSERT_ORTSTATUS_OK(c_api.CreateMemoryInfo_V2("dummy", OrtMemoryInfoDeviceType_CPU, 0, 0, - OrtDeviceMemoryType_DEFAULT, 0, OrtDeviceAllocator, - &test_cpu_memory_info)); + auto test_cpu_memory_info = Ort::MemoryInfo("dummy", OrtMemoryInfoDeviceType_CPU, 0, 0, + OrtDeviceMemoryType_DEFAULT, 0, OrtDeviceAllocator); const auto get_allocator_and_check_name = [&](const std::string& expected_name) { - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, test_cpu_memory_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(test_cpu_memory_info); ASSERT_NE(allocator, nullptr); - const OrtMemoryInfo* ort_cpu_memory_info = nullptr; - ASSERT_ORTSTATUS_OK(c_api.AllocatorGetInfo(allocator, &ort_cpu_memory_info)); - const char* allocator_name; - ASSERT_ORTSTATUS_OK(c_api.MemoryInfoGetName(ort_cpu_memory_info, &allocator_name)); + auto ort_cpu_memory_info = allocator.GetInfo(); + auto allocator_name = ort_cpu_memory_info.GetAllocatorName(); ASSERT_EQ(expected_name, allocator_name); // Default ORT CPU allocator }; @@ -128,18 +120,16 @@ TEST(SharedAllocators, GetSharedAllocator) { // register custom allocator and make sure that is accessible by exact match DummyAllocator dummy_alloc{test_cpu_memory_info}; - c_api.RegisterAllocator(*ort_env, &dummy_alloc); + ort_env->RegisterAllocator(&dummy_alloc); // GetSharedAllocator should now match the custom allocator get_allocator_and_check_name("dummy"); // unregister custom allocator - ASSERT_ORTSTATUS_OK(c_api.UnregisterAllocator(*ort_env, test_cpu_memory_info)); + ort_env->UnregisterAllocator(test_cpu_memory_info); // there should always be a CPU allocator available get_allocator_and_check_name(onnxruntime::CPU); - - c_api.ReleaseMemoryInfo(test_cpu_memory_info); } } // namespace test diff --git a/onnxruntime/test/autoep/test_data_transfer.cc b/onnxruntime/test/autoep/test_data_transfer.cc index cc09699b754b6..71c69698ed386 100644 --- a/onnxruntime/test/autoep/test_data_transfer.cc +++ b/onnxruntime/test/autoep/test_data_transfer.cc @@ -23,16 +23,15 @@ namespace onnxruntime { namespace test { 
TEST(OrtEpLibrary, DataTransfer) { - const OrtApi& c_api = Ort::GetApi(); RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const OrtEpDevice* ep_device = example_ep.get(); + Ort::ConstEpDevice ep_device(example_ep.get()); - const OrtMemoryInfo* device_memory_info = c_api.EpDevice_MemoryInfo(ep_device, OrtDeviceMemoryType_DEFAULT); + auto device_memory_info = ep_device.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); // create a tensor using the default CPU allocator Ort::AllocatorWithDefaultOptions cpu_allocator; - std::vector shape{2, 3, 4}; // shape doesn't matter + constexpr const std::array shape{2, 3, 4}; // shape doesn't matter const size_t num_elements = 2 * 3 * 4; RandomValueGenerator random{}; @@ -44,24 +43,21 @@ TEST(OrtEpLibrary, DataTransfer) { // create an on-device Tensor using the example EPs alternative CPU allocator. // it has a different vendor to the default ORT CPU allocator so we can copy between them even though both are // really CPU based. - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, device_memory_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(device_memory_info); ASSERT_NE(allocator, nullptr); Ort::Value device_tensor = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); - std::vector src_tensor_ptrs{cpu_tensor}; - std::vector dst_tensor_ptrs{device_tensor}; + std::vector src_tensor; + src_tensor.push_back(std::move(cpu_tensor)); + std::vector dst_tensor; + dst_tensor.push_back(std::move(device_tensor)); - ASSERT_ORTSTATUS_OK(c_api.CopyTensors(*ort_env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), nullptr, - src_tensor_ptrs.size())); + ASSERT_CXX_ORTSTATUS_OK(ort_env->CopyTensors(src_tensor, dst_tensor, nullptr)); - const float* src_data = nullptr; - const float* dst_data = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetTensorData(cpu_tensor, reinterpret_cast(&src_data))); - ASSERT_ORTSTATUS_OK(c_api.GetTensorData(device_tensor, reinterpret_cast(&dst_data))); + const float* src_data = src_tensor[0].GetTensorData(); + const float* dst_data = dst_tensor[0].GetTensorData(); - size_t bytes; - ASSERT_ORTSTATUS_OK(c_api.GetTensorSizeInBytes(cpu_tensor, &bytes)); + size_t bytes = src_tensor[0].GetTensorSizeInBytes(); ASSERT_EQ(bytes, num_elements * sizeof(float)); ASSERT_NE(src_data, dst_data) << "Should have copied between two different memory locations"; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index f1ef67e1f6ba4..0f4a654f116c4 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -54,12 +54,12 @@ void RunModelWithPluginEp(Ort::SessionOptions& session_options) { TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const OrtEpDevice* plugin_ep_device = example_ep.get(); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); // Create session with example plugin EP Ort::SessionOptions session_options; std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options); + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); RunModelWithPluginEp(session_options); } @@ -83,7 +83,7 @@ TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { RegisteredEpDeviceUniquePtr example_ep; 
Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const OrtEpDevice* plugin_ep_device = example_ep.get(); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); { const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); @@ -94,7 +94,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { Ort::SessionOptions session_options; std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options); + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, session_options); @@ -102,9 +102,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { compile_options.SetOutputModelPath(output_model_file); // Compile the model. - Ort::Status status = Ort::CompileModel(*ort_env, compile_options); - ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); - + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); // Make sure the compiled model was generated. ASSERT_TRUE(std::filesystem::exists(output_model_file)); } diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index d5f6f1ddf700e..ed9429d5eec54 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -61,7 +61,7 @@ extern std::unique_ptr ort_env; // asserts that the OrtStatus* result of `status_expr` does not indicate an error // note: this takes ownership of the OrtStatus* result -#define ASSERT_ORT_STATUS_OK(status_expr) \ +#define ASSERT_CXX_ORTSTATUS_OK(status_expr) \ do { \ if (OrtStatus* _status = (status_expr); _status != nullptr) { \ std::unique_ptr _rel_status{ \ @@ -180,7 +180,7 @@ TEST_P(ModelTest, Run) { ortso.SetLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR); if (provider_name == "cuda") { OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); std::unique_ptr rel_cuda_options( cuda_options, &OrtApis::ReleaseCUDAProviderOptions); @@ -189,7 +189,7 @@ TEST_P(ModelTest, Run) { std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); values.push_back(device_id.empty() ? 
"0" : device_id.c_str()); values.push_back("0"); - ASSERT_ORT_STATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2)); ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); } else if (provider_name == "rocm") { @@ -199,11 +199,11 @@ TEST_P(ModelTest, Run) { #ifdef USE_DNNL else if (provider_name == "dnnl") { OrtDnnlProviderOptions* ep_option; - ASSERT_ORT_STATUS_OK(OrtApis::CreateDnnlProviderOptions(&ep_option)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::CreateDnnlProviderOptions(&ep_option)); std::unique_ptr rel_dnnl_options(ep_option, &OrtApis::ReleaseDnnlProviderOptions); ep_option->use_arena = 0; - ASSERT_ORT_STATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_Dnnl(ortso, ep_option)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_Dnnl(ortso, ep_option)); } #endif else if (provider_name == "tensorrt") { @@ -212,14 +212,14 @@ TEST_P(ModelTest, Run) { ortso.AppendExecutionProvider_TensorRT_V2(params); } else { OrtTensorRTProviderOptionsV2* ep_option = nullptr; - ASSERT_ORT_STATUS_OK(OrtApis::CreateTensorRTProviderOptions(&ep_option)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::CreateTensorRTProviderOptions(&ep_option)); std::unique_ptr rel_cuda_options(ep_option, &OrtApis::ReleaseTensorRTProviderOptions); ortso.AppendExecutionProvider_TensorRT_V2(*ep_option); } // Enable CUDA fallback OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); std::unique_ptr rel_cuda_options( cuda_options, &OrtApis::ReleaseCUDAProviderOptions); @@ -228,7 +228,7 @@ TEST_P(ModelTest, Run) { std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); values.push_back(device_id.empty() ? 
"0" : device_id.c_str()); values.push_back("0"); - ASSERT_ORT_STATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2)); ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); } else if (provider_name == "migraphx") { @@ -240,27 +240,27 @@ TEST_P(ModelTest, Run) { } #ifdef USE_NNAPI else if (provider_name == "nnapi") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Nnapi(ortso, 0)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_Nnapi(ortso, 0)); } #endif #ifdef USE_VSINPU else if (provider_name == "vsinpu") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_VSINPU(ortso)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_VSINPU(ortso)); } #endif #ifdef USE_RKNPU else if (provider_name == "rknpu") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Rknpu(ortso)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_Rknpu(ortso)); } #endif #ifdef USE_ACL else if (provider_name == "acl") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, false)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, false)); } #endif #ifdef USE_ARMNN else if (provider_name == "armnn") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ArmNN(ortso)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_ArmNN(ortso)); } #endif #ifdef USE_XNNPACK @@ -300,11 +300,11 @@ TEST_P(ModelTest, Run) { std::unordered_map feeds; l->LoadTestData(task_id, holder, feeds, true); size_t output_count; - ASSERT_ORT_STATUS_OK(OrtApis::SessionGetOutputCount(ort_session, &output_count)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::SessionGetOutputCount(ort_session, &output_count)); // Create output feed std::vector output_names(output_count); for (size_t i = 0; i != output_count; ++i) { - ASSERT_ORT_STATUS_OK( + ASSERT_CXX_ORTSTATUS_OK( OrtApis::SessionGetOutputName(ort_session, i, default_allocator.get(), &output_names[i])); } diff --git a/onnxruntime/test/shared_lib/test_allocator.cc b/onnxruntime/test/shared_lib/test_allocator.cc index 29f3dfad0f11d..bf9e54e8b3c7b 100644 --- a/onnxruntime/test/shared_lib/test_allocator.cc +++ b/onnxruntime/test/shared_lib/test_allocator.cc @@ -45,12 +45,10 @@ TEST(CApiTest, DefaultAllocator) { TEST(CApiTest, CustomAllocator) { constexpr PATH_TYPE model_path = TSTR("testdata/mul_1.onnx"); - const auto& api = Ort::GetApi(); - // Case 1: Register a custom allocator. { MockedOrtAllocator mocked_allocator; - ASSERT_TRUE(api.RegisterAllocator(*ort_env, &mocked_allocator) == nullptr); + ort_env->RegisterAllocator(&mocked_allocator); Ort::SessionOptions session_options; session_options.AddConfigEntry("session.use_env_allocators", "1"); @@ -62,14 +60,14 @@ TEST(CApiTest, CustomAllocator) { ASSERT_EQ(mocked_allocator.NumAllocations(), std::stoll(stats.GetValue("NumAllocs"))); ASSERT_EQ(mocked_allocator.NumReserveAllocations(), std::stoll(stats.GetValue("NumReserves"))); - ASSERT_TRUE(api.UnregisterAllocator(*ort_env, mocked_allocator.Info()) == nullptr); + ort_env->UnregisterAllocator(mocked_allocator.Info()); } // Case 2: Register a custom allocator with an older API version which does not support GetStats. 
{ MockedOrtAllocator mocked_allocator; mocked_allocator.version = 22; - ASSERT_TRUE(api.RegisterAllocator(*ort_env, &mocked_allocator) == nullptr); + ort_env->RegisterAllocator(&mocked_allocator); Ort::SessionOptions session_options; session_options.AddConfigEntry("session.use_env_allocators", "1"); @@ -81,7 +79,7 @@ TEST(CApiTest, CustomAllocator) { auto stats = allocator.GetStats(); ASSERT_EQ(0, stats.GetKeyValuePairs().size()); - ASSERT_TRUE(api.UnregisterAllocator(*ort_env, mocked_allocator.Info()) == nullptr); + ort_env->UnregisterAllocator(mocked_allocator.Info()); } } #endif diff --git a/onnxruntime/test/shared_lib/test_data_copy.cc b/onnxruntime/test/shared_lib/test_data_copy.cc index 2294bb8d6fdff..872671135fc6a 100644 --- a/onnxruntime/test/shared_lib/test_data_copy.cc +++ b/onnxruntime/test/shared_lib/test_data_copy.cc @@ -31,9 +31,6 @@ using StreamUniquePtr = std::unique_ptrGetApi(ORT_API_VERSION); - #ifdef _WIN32 std::string cuda_lib = "onnxruntime_providers_cuda.dll"; #else @@ -47,10 +44,10 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { // register the provider bridge based CUDA EP so allocator and data transfer is available // not all the CIs have the provider library in the expected place so we allow for that const char* ep_registration_name = "ORT CUDA"; - ASSERT_ORTSTATUS_OK(api->RegisterExecutionProviderLibrary(env, ep_registration_name, - ORT_TSTR("onnxruntime_providers_cuda"))); + ort_env->RegisterExecutionProviderLibrary(ep_registration_name, + ORT_TSTR("onnxruntime_providers_cuda")); - const OrtEpDevice* cuda_device = nullptr; + Ort::ConstEpDevice cuda_device{nullptr}; for (const auto& ep_device : ort_env->GetEpDevices()) { std::string vendor{ep_device.EpVendor()}; std::string name = {ep_device.EpName()}; @@ -70,13 +67,11 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { // we pass in the CUDA cudaStream_t from the OrtSyncStream via provider options so need to create it upfront. // in the future the stream should be an input to the Session Run. - OrtSyncStream* stream = nullptr; - StreamUniquePtr stream_ptr; + Ort::SyncStream stream{nullptr}; if (use_streams) { - ASSERT_ORTSTATUS_OK(api->CreateSyncStreamForEpDevice(cuda_device, /*options*/ nullptr, &stream)); - stream_ptr = StreamUniquePtr(stream, [api](OrtSyncStream* stream) { api->ReleaseSyncStream(stream); }); + stream = cuda_device.CreateSyncStream(); - size_t stream_addr = reinterpret_cast(api->SyncStream_GetHandle(stream)); + size_t stream_addr = reinterpret_cast(stream.GetHandle()); options.AddConfigEntry("ep.cudaexecutionprovider.user_compute_stream", std::to_string(stream_addr).c_str()); // we explicitly specify user_compute_stream, so why do we also need to set has_user_compute_stream? options.AddConfigEntry("ep.cudaexecutionprovider.has_user_compute_stream", "1"); @@ -87,24 +82,18 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { size_t num_inputs = session.GetInputCount(); // find the input location so we know which inputs can be provided on device. - std::vector input_locations; - input_locations.resize(num_inputs, nullptr); - ASSERT_ORTSTATUS_OK(api->SessionGetMemoryInfoForInputs(session, input_locations.data(), num_inputs)); + std::vector input_locations = session.GetMemoryInfoForInputs(); std::vector cpu_tensors; // info for device copy - std::vector src_tensor_ptrs; - std::vector dst_tensor_ptrs; - - // values we'll call Run with - std::vector input_tensors; + std::vector device_tensors; ASSERT_EQ(num_inputs, 1); // create cpu based input data. 
Ort::AllocatorWithDefaultOptions cpu_allocator; - std::vector shape{1, 1, 28, 28}; + constexpr const std::array shape{1, 1, 28, 28}; std::vector input_data(28 * 28, 0.5f); Ort::Value input_value = Ort::Value::CreateTensor(cpu_allocator.GetInfo(), input_data.data(), input_data.size(), @@ -112,15 +101,13 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { cpu_tensors.push_back(std::move(input_value)); for (size_t idx = 0; idx < num_inputs; ++idx) { - const OrtMemoryInfo* mem_info = input_locations[idx]; - OrtDeviceMemoryType mem_type = api->MemoryInfoGetDeviceMemType(mem_info); - OrtMemoryInfoDeviceType device_type; - api->MemoryInfoGetDeviceType(mem_info, &device_type); + auto mem_info = input_locations[idx]; + OrtDeviceMemoryType mem_type = mem_info.GetDeviceMemoryType(); + OrtMemoryInfoDeviceType device_type = mem_info.GetDeviceType(); if (device_type == OrtMemoryInfoDeviceType_GPU && mem_type == OrtDeviceMemoryType_DEFAULT) { // copy to device - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(api->GetSharedAllocator(env, mem_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(mem_info); // allocate new on-device memory auto src_shape = cpu_tensors[idx].GetTensorTypeAndShapeInfo().GetShape(); @@ -137,18 +124,12 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &value); */ - src_tensor_ptrs.push_back(cpu_tensors[idx]); - dst_tensor_ptrs.push_back(device_value); - input_tensors.push_back(std::move(device_value)); - } else { - // input is on CPU accessible memory. move to input_tensors - input_tensors.push_back(std::move(cpu_tensors[idx])); + device_tensors.push_back(std::move(device_value)); } } - if (!src_tensor_ptrs.empty()) { - ASSERT_ORTSTATUS_OK(api->CopyTensors(env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), stream, - src_tensor_ptrs.size())); + if (!device_tensors.empty()) { + ASSERT_CXX_ORTSTATUS_OK(ort_env->CopyTensors(cpu_tensors, device_tensors, stream)); // Stream support is still a work in progress. // @@ -160,18 +141,19 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { // iobinding.SynchronizeInputs(); // this doesn't actually require any bound inputs } - std::vector input_names = {"Input3"}; - std::vector output_names = {"Plus214_Output_0"}; + const auto& input_tensors = (!device_tensors.empty()) ? device_tensors : cpu_tensors; + + constexpr const std::array input_names = {"Input3"}; + constexpr const std::array output_names = {"Plus214_Output_0"}; Ort::Value output; session.Run(Ort::RunOptions{}, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), &output, 1); - const float* results = nullptr; - ASSERT_ORTSTATUS_OK(api->GetTensorData(output, reinterpret_cast(&results))); + const float* results = output.GetTensorData(); // expected results from the CPU EP. can check/re-create by running with PREFER_CPU. 
- std::vector expected = { + constexpr const std::array expected = { -0.701670527f, -0.583666623f, 0.0480501056f, @@ -192,7 +174,7 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { run_test(/*use_streams*/ true); run_test(/*use_streams*/ false); - ASSERT_ORTSTATUS_OK(api->UnregisterExecutionProviderLibrary(env, ep_registration_name)); + ort_env->UnregisterExecutionProviderLibrary(ep_registration_name); } #endif // USE_CUDA diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 56cc234a63832..4b3ee5c182534 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3316,13 +3316,9 @@ TEST(CApiTest, model_metadata) { } TEST(CApiTest, get_available_providers) { - const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); - int len = 0; - char** providers; - ASSERT_EQ(g_ort->GetAvailableProviders(&providers, &len), nullptr); - ASSERT_GT(len, 0); - ASSERT_STREQ(providers[len - 1], "CPUExecutionProvider"); - ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr); + std::vector providers = Ort::GetAvailableProviders(); + ASSERT_GT(providers.size(), 0); + ASSERT_STREQ(providers.back().c_str(), "CPUExecutionProvider"); } TEST(CApiTest, get_available_providers_cpp) { @@ -3348,8 +3344,6 @@ TEST(CApiTest, get_build_info_string) { } TEST(CApiTest, TestSharedAllocators) { - OrtEnv* env_ptr = (OrtEnv*)(*ort_env); - // prepare inputs std::vector> inputs(1); auto& input = inputs.back(); @@ -3372,9 +3366,7 @@ TEST(CApiTest, TestSharedAllocators) { // CASE 1: We test creating and registering an ORT-internal allocator implementation instance // for sharing between sessions { - OrtMemoryInfo* mem_info = nullptr; - ASSERT_TRUE(api.CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info) == nullptr); - std::unique_ptr rel_info(mem_info, api.ReleaseMemoryInfo); + auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); OrtArenaCfg* arena_cfg = nullptr; ASSERT_TRUE(api.CreateArenaCfg(0, -1, -1, -1, &arena_cfg) == nullptr); @@ -3382,13 +3374,9 @@ TEST(CApiTest, TestSharedAllocators) { // This creates an ORT-internal allocator instance and registers it in the environment for sharing // NOTE: On x86 builds arenas are not supported and will default to using non-arena based allocator - ASSERT_TRUE(api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg) == nullptr); - + ort_env->CreateAndRegisterAllocator(mem_info, arena_cfg); // Registration is always a replace operation - std::unique_ptr status_releaser( - api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg), - api.ReleaseStatus); - ASSERT_TRUE(status_releaser.get() == nullptr); + ort_env->CreateAndRegisterAllocator(mem_info, arena_cfg); { // create session 1 @@ -3414,7 +3402,7 @@ TEST(CApiTest, TestSharedAllocators) { // Remove the registered shared allocator for part 2 of this test // where-in we will register a custom allocator for the same device. - ASSERT_TRUE(api.UnregisterAllocator(env_ptr, mem_info) == nullptr); + ort_env->UnregisterAllocator(mem_info); } // CASE 2: We test registering a custom allocator implementation @@ -3425,15 +3413,9 @@ TEST(CApiTest, TestSharedAllocators) { // need to be aligned for certain devices/build configurations/math libraries. // See docs/C_API.md for details. 
MockedOrtAllocator custom_allocator; - ASSERT_TRUE(api.RegisterAllocator(env_ptr, &custom_allocator) == nullptr); - + ort_env->RegisterAllocator(&custom_allocator); // Registration is always a replace operation - std::unique_ptr - status_releaser( - api.RegisterAllocator(env_ptr, &custom_allocator), - api.ReleaseStatus); - ASSERT_TRUE(status_releaser.get() == nullptr); - + ort_env->RegisterAllocator(&custom_allocator); { // Keep this scoped to destroy the underlying sessions after use // This should trigger frees in our custom allocator @@ -3472,7 +3454,7 @@ TEST(CApiTest, TestSharedAllocators) { // Remove the registered shared allocator from the global environment // (common to all tests) to prevent its accidental usage elsewhere - ASSERT_TRUE(api.UnregisterAllocator(env_ptr, custom_allocator.Info()) == nullptr); + ort_env->UnregisterAllocator(custom_allocator.Info()); // Ensure that the registered custom allocator was indeed used for both sessions // We should have seen 2 allocations per session (one for the sole initializer @@ -3488,22 +3470,18 @@ TEST(CApiTest, TestSharedAllocators) { } #ifdef USE_CUDA { - OrtMemoryInfo* cuda_meminfo = nullptr; - ASSERT_TRUE(api.CreateMemoryInfo("Cuda", OrtArenaAllocator, 0, OrtMemTypeDefault, &cuda_meminfo) == nullptr); - std::unique_ptr rel_info(cuda_meminfo, api.ReleaseMemoryInfo); + auto cuda_meminfo = Ort::MemoryInfo("Cuda", OrtArenaAllocator, 0, OrtMemTypeDefault); OrtArenaCfg* arena_cfg = nullptr; ASSERT_TRUE(api.CreateArenaCfg(0, -1, -1, -1, &arena_cfg) == nullptr); std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); - std::vector keys, values; - ASSERT_TRUE(api.CreateAndRegisterAllocatorV2(env_ptr, onnxruntime::kCudaExecutionProvider, cuda_meminfo, arena_cfg, keys.data(), values.data(), 0) == nullptr); + ort_env->CreateAndRegisterAllocatorV2(onnxruntime::kCudaExecutionProvider, + cuda_meminfo, {}, arena_cfg); // Registration is always a replace operation - std::unique_ptr status_releaser( - api.CreateAndRegisterAllocatorV2(env_ptr, onnxruntime::kCudaExecutionProvider, cuda_meminfo, arena_cfg, keys.data(), values.data(), 0), - api.ReleaseStatus); - ASSERT_TRUE(status_releaser.get() == nullptr); + ort_env->CreateAndRegisterAllocatorV2(onnxruntime::kCudaExecutionProvider, + cuda_meminfo, {}, arena_cfg); { // create session 1 @@ -3530,7 +3508,7 @@ TEST(CApiTest, TestSharedAllocators) { nullptr); } - ASSERT_TRUE(api.UnregisterAllocator(env_ptr, cuda_meminfo) == nullptr); + ort_env->UnregisterAllocator(cuda_meminfo); } #endif } @@ -3998,8 +3976,7 @@ TEST(CApiTest, TestConfigureCUDAProviderOptions) { ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 6) == nullptr); - OrtAllocator* allocator; - ASSERT_TRUE(api.GetAllocatorWithDefaultOptions(&allocator) == nullptr); + auto allocator = Ort::AllocatorWithDefaultOptions(); char* cuda_options_str = nullptr; ASSERT_TRUE(api.GetCUDAProviderOptionsAsString(rel_cuda_options.get(), allocator, &cuda_options_str) == nullptr); @@ -4015,10 +3992,10 @@ TEST(CApiTest, TestConfigureCUDAProviderOptions) { ASSERT_TRUE(s.find("cudnn_conv_use_max_workspace=1") != std::string::npos); ASSERT_TRUE(s.find("cudnn_conv1d_pad_to_nc1d") != std::string::npos); - ASSERT_TRUE(api.AllocatorFree(allocator, (void*)cuda_options_str) == nullptr); + allocator.Free(cuda_options_str); Ort::SessionOptions session_options; - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2(static_cast(session_options), rel_cuda_options.get()) == nullptr); + 
session_options.AppendExecutionProvider_CUDA_V2(*rel_cuda_options);

   // if session creation passes, model loads fine
   std::basic_string<ORTCHAR_T> model_uri = MODEL_URI;

diff --git a/onnxruntime/test/util/include/api_asserts.h b/onnxruntime/test/util/include/api_asserts.h
index 423135f96fbcd..946782752e4bd 100644
--- a/onnxruntime/test/util/include/api_asserts.h
+++ b/onnxruntime/test/util/include/api_asserts.h
@@ -10,36 +10,29 @@
 #include "core/session/onnxruntime_cxx_api.h"
 
 // asserts for the public API
-#define ASSERT_ORTSTATUS_OK(function)                                              \
-  do {                                                                             \
-    OrtStatusPtr _tmp_status = (function);                                         \
-    ASSERT_EQ(_tmp_status, nullptr) << Ort::GetApi().GetErrorMessage(_tmp_status); \
-    if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status);                     \
+#define ASSERT_ORTSTATUS_OK(function)                                 \
+  do {                                                                \
+    Ort::Status _tmp_status = (function);                             \
+    ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \
   } while (false)
 
-#define EXPECT_ORTSTATUS_OK(api, function)                                         \
-  do {                                                                             \
-    OrtStatusPtr _tmp_status = (api->function);                                    \
-    EXPECT_EQ(_tmp_status, nullptr) << Ort::GetApi().GetErrorMessage(_tmp_status); \
-    if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status);                     \
+#define EXPECT_ORTSTATUS_OK(api, function)                            \
+  do {                                                                \
+    Ort::Status _tmp_status = (api->function);                        \
+    EXPECT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \
   } while (false)
 
-#define ASSERT_ORTSTATUS_NOT_OK(api, function)                 \
-  do {                                                         \
-    OrtStatusPtr _tmp_status = (api->function);                \
-    ASSERT_NE(_tmp_status, nullptr);                           \
-    if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \
+#define ASSERT_ORTSTATUS_NOT_OK(api, function) \
+  do {                                         \
+    Ort::Status _tmp_status = (api->function); \
+    ASSERT_FALSE(_tmp_status.IsOK());          \
   } while (false)
 
-#define EXPECT_ORTSTATUS_NOT_OK(api, function)                 \
-  do {                                                         \
-    OrtStatusPtr _tmp_status = (api->function);                \
-    EXPECT_NE(_tmp_status, nullptr);                           \
-    if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \
+#define EXPECT_ORTSTATUS_NOT_OK(api, function) \
+  do {                                         \
+    Ort::Status _tmp_status = (api->function); \
+    EXPECT_FALSE(_tmp_status.IsOK());          \
   } while (false)
 
-#define ASSERT_CXX_ORTSTATUS_OK(function)                             \
-  do {                                                                \
-    Ort::Status _tmp_status = (function);                             \
-    ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \
-  } while (false)
+#define ASSERT_CXX_ORTSTATUS_OK(function) \
+  ASSERT_ORTSTATUS_OK(function)

From 2da60c0ef104a2b55cfc7e1afa84c5715d765be2 Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Thu, 14 Aug 2025 10:18:01 -0700
Subject: [PATCH 02/12] XX

---
 .../core/session/onnxruntime_cxx_api.h        |  28 ++++-
 .../core/session/onnxruntime_cxx_inline.h     |  34 +++++
 .../tools/tensorrt/perf/mem_test/main.cpp     |  18 +--
 onnxruntime/test/perftest/ort_test_session.cc |  25 +---
 onnxruntime/test/providers/cpu/model_tests.cc |   9 +-
 onnxruntime/test/shared_lib/test_inference.cc | 118 ++++++------------
 6 files changed, 111 insertions(+), 121 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index cee2300daae54..e01da54a7c053 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -574,10 +574,14 @@ ORT_DEFINE_RELEASE(ValueInfo);
 ORT_DEFINE_RELEASE(Node);
 ORT_DEFINE_RELEASE(Graph);
 ORT_DEFINE_RELEASE(Model);
-ORT_DEFINE_RELEASE(KeyValuePairs)
+ORT_DEFINE_RELEASE(KeyValuePairs);
 ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi);
 ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi);
 
+// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type,
+// but the struct has V2 in its name to indicate that it is the second version of the options.
+inline void OrtRelease(OrtTensorRTProviderOptionsV2* ptr) { GetApi().ReleaseTensorRTProviderOptions(ptr); }
+
 #undef ORT_DEFINE_RELEASE
 #undef ORT_DEFINE_RELEASE_FROM_API_STRUCT
 
@@ -629,6 +633,7 @@ struct Base {
   }
 
   constexpr operator contained_type*() const noexcept { return p_; }
+  constexpr contained_type& operator*() const noexcept { return *p_; }
 
   /// \brief Relinquishes ownership of the contained C object pointer
   /// The underlying object is not destroyed
@@ -673,6 +678,7 @@ struct Base<Unowned<T>> {
   }
 
   constexpr operator contained_type*() const noexcept { return p_; }
+  constexpr contained_type& operator*() const noexcept { return *p_; }
 
 protected:
   contained_type* p_{};
@@ -701,6 +707,7 @@ struct TypeInfo;
 struct Session;
 struct SessionOptions;
 struct SyncStream;
+struct TensorRTProviderOptions;
 struct Value;
 struct ValueInfo;
 
@@ -758,6 +765,25 @@ struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
   ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
 };
 
+/** \brief The TensorRTProviderOptions (V2)
+ *
+ * Used to pass options to TRT EP
+ */
+struct TensorRTProviderOptions : detail::Base<OrtTensorRTProviderOptionsV2> {
+  TensorRTProviderOptions(std::nullptr_t) {}
+  /// \brief Wraps OrtApi::CreateTensorRTProviderOptionsV2
+  TensorRTProviderOptions();
+  ///< Wrapper around OrtApi::UpdateTensorRTProviderOptions
+  void Update(const std::unordered_map<std::string, std::string>& options);
+  ///< Wrapper around OrtApi::UpdateTensorRTProviderOptionsWithValue
+  void UpdateWithValue(const char* key, void* value);
+
+  ///< Wrapper around OrtApi::GetTensorRTProviderOptionsByName
+  void* GetptionByName(const char* name) const;
+  ///< Wrapper around OrtApi::GetTensorRTProviderOptionsAsString
+  std::string GetTensorRTProviderOptionsAsString() const;
+};
+
 namespace detail {
 template <typename T>
 struct KeyValuePairsImpl : Ort::detail::Base<T> {
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 201278c5c132b..eef6cc7972001 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -488,6 +488,40 @@ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustom
   return *this;
 }
 
+inline TensorRTProviderOptions::TensorRTProviderOptions() {
+  ThrowOnError(GetApi().CreateTensorRTProviderOptions(&this->p_));
+}
+
+inline void TensorRTProviderOptions::Update(const std::unordered_map<std::string, std::string>& options) {
+  std::vector<const char*> keys;
+  std::vector<const char*> values;
+  keys.reserve(options.size());
+  values.reserve(options.size());
+  for (const auto& kv : options) {
+    keys.push_back(kv.first.c_str());
+    values.push_back(kv.second.c_str());
+  }
+  ThrowOnError(GetApi().UpdateTensorRTProviderOptions(p_, keys.data(), values.data(), options.size()));
+}
+
+inline void TensorRTProviderOptions::UpdateWithValue(const char* key, void* value) {
+  ThrowOnError(GetApi().UpdateTensorRTProviderOptionsWithValue(p_, key, value));
+}
+
+inline void* TensorRTProviderOptions::GetptionByName(const char* name) const {
+  void* value = nullptr;
+  ThrowOnError(GetApi().GetTensorRTProviderOptionsByName(p_, name, &value));
+  return value;
+}
+
+inline std::string TensorRTProviderOptions::GetTensorRTProviderOptionsAsString() const {
+  AllocatorWithDefaultOptions allocator;
+  char* options_str = nullptr;
+  ThrowOnError(GetApi().GetTensorRTProviderOptionsAsString(p_, allocator, &options_str));
+  std::unique_ptr
options_str_g(options_str, detail::AllocatedFree(allocator)); + return std::string(options_str); +} + namespace detail { template inline const char* KeyValuePairsImpl::GetValue(const char* key) const { diff --git a/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp b/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp index ec30b8ba0985d..2550fde338cd5 100644 --- a/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp +++ b/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp @@ -10,21 +10,15 @@ void run_ort_trt2() { Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); - const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* tensorrt_options; Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(1); - session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); const char* model_path = "squeezenet.onnx"; - Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); - std::unique_ptr rel_trt_options( - tensorrt_options, api.ReleaseTensorRTProviderOptions); - Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), - rel_trt_options.get())); + Ort::TensorRTProviderOptions tensorrt_options; + session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); std::cout << "Running ORT TRT EP with default provider options" << std::endl; @@ -127,7 +121,7 @@ void run_ort_trt2() { void run_ort_trt() { Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* tensorrt_options; + Ort::TensorRTProviderOptions tensorrt_options; Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(1); @@ -136,11 +130,7 @@ void run_ort_trt() { const char* model_path = "/data/ep-perf-models/onnx-zoo-models/squeezenet1.0-7/squeezenet/model.onnx"; - Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); - std::unique_ptr rel_trt_options( - tensorrt_options, api.ReleaseTensorRTProviderOptions); - Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), - rel_trt_options.get())); + session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); std::cout << "Running ORT TRT EP with default provider options" << std::endl; diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 7156a1eb5c347..2da6f5e6b9a04 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -235,12 +235,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif } else if (provider_name_ == onnxruntime::kTensorrtExecutionProvider) { #ifdef USE_TENSORRT - const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* tensorrt_options; - Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); - std::unique_ptr rel_trt_options( - tensorrt_options, api.ReleaseTensorRTProviderOptions); - std::vector option_keys, option_values; + Ort::TensorRTProviderOptions tensorrt_options; // used to keep all option keys and value strings alive std::list buffer; @@ -250,25 +245,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif ParseSessionConfigs(ov_string, provider_options); - for (const auto& provider_option : provider_options) { - option_keys.push_back(provider_option.first.c_str()); - 
option_values.push_back(provider_option.second.c_str()); - } - Ort::Status status(api.UpdateTensorRTProviderOptions(tensorrt_options, - option_keys.data(), option_values.data(), option_keys.size())); - if (!status.IsOK()) { - OrtAllocator* allocator; - char* options; - Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); - Ort::ThrowOnError(api.GetTensorRTProviderOptionsAsString(tensorrt_options, allocator, &options)); - ORT_THROW("[ERROR] [TensorRT] Configuring the CUDA options failed with message: ", status.GetErrorMessage(), - "\nSupported options are:\n", options); - } - + tensorrt_options.Update(provider_options); session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); OrtCUDAProviderOptions cuda_options; - cuda_options.device_id = tensorrt_options->device_id; + cuda_options.device_id = static_cast(tensorrt_options)->device_id; cuda_options.cudnn_conv_algo_search = static_cast(performance_test_config.run_config.cudnn_conv_algo); cuda_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream; // TODO: Support arena configuration for users of perf test diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index ed9429d5eec54..6dd971fd84f82 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -208,13 +208,10 @@ TEST_P(ModelTest, Run) { #endif else if (provider_name == "tensorrt") { if (test_case_name.find(ORT_TSTR("FLOAT16")) != std::string::npos) { - OrtTensorRTProviderOptionsV2 params; - ortso.AppendExecutionProvider_TensorRT_V2(params); + Ort::TensorRTProviderOptions params; + ortso.AppendExecutionProvider_TensorRT_V2(*params); } else { - OrtTensorRTProviderOptionsV2* ep_option = nullptr; - ASSERT_CXX_ORTSTATUS_OK(OrtApis::CreateTensorRTProviderOptions(&ep_option)); - std::unique_ptr - rel_cuda_options(ep_option, &OrtApis::ReleaseTensorRTProviderOptions); + Ort::TensorRTProviderOptions ep_option; ortso.AppendExecutionProvider_TensorRT_V2(*ep_option); } // Enable CUDA fallback diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 4b3ee5c182534..ef8a60d8ad240 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -2175,7 +2175,7 @@ TEST(CApiTest, io_binding) { TEST(CApiTest, io_binding_cuda) { Ort::SessionOptions session_options; #ifdef USE_TENSORRT - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, 0)); + session_options.AppendExecutionProvider_TensorRT({}); #else OrtCUDAProviderOptionsV2* options; Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); @@ -2376,22 +2376,17 @@ TEST(CApiTest, io_binding_qnn_htp_shared) { #if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) || defined(USE_DML) TEST(CApiTest, basic_cuda_graph) { - const auto& api = Ort::GetApi(); + [[maybe_unused]] const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; #if defined(USE_TENSORRT) // Enable cuda graph in TRT provider option. 
- OrtTensorRTProviderOptionsV2* trt_options; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr - rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); - std::vector keys{"trt_cuda_graph_enable"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); + Ort::TensorRTProviderOptions trt_options; + std::unordered_map trt_options_map = {{"trt_cuda_graph_enable", + "1"}}; + trt_options.Update(trt_options_map); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options), - rel_trt_options.get()) == nullptr); #elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. OrtCUDAProviderOptionsV2* cuda_options = nullptr; @@ -3368,10 +3363,7 @@ TEST(CApiTest, TestSharedAllocators) { { auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - OrtArenaCfg* arena_cfg = nullptr; - ASSERT_TRUE(api.CreateArenaCfg(0, -1, -1, -1, &arena_cfg) == nullptr); - std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); - + Ort::ArenaCfg arena_cfg(0, -1, -1, -1); // This creates an ORT-internal allocator instance and registers it in the environment for sharing // NOTE: On x86 builds arenas are not supported and will default to using non-arena based allocator ort_env->CreateAndRegisterAllocator(mem_info, arena_cfg); @@ -3472,10 +3464,7 @@ TEST(CApiTest, TestSharedAllocators) { { auto cuda_meminfo = Ort::MemoryInfo("Cuda", OrtArenaAllocator, 0, OrtMemTypeDefault); - OrtArenaCfg* arena_cfg = nullptr; - ASSERT_TRUE(api.CreateArenaCfg(0, -1, -1, -1, &arena_cfg) == nullptr); - std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); - + Ort::ArenaCfg arena_cfg(0, -1, -1, -1); ort_env->CreateAndRegisterAllocatorV2(onnxruntime::kCudaExecutionProvider, cuda_meminfo, {}, arena_cfg); @@ -3696,24 +3685,16 @@ TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) { #ifdef USE_TENSORRT TEST(TensorrtExecutionProviderTest, ShapeTensorTest) { - const auto& api = Ort::GetApi(); - // Test input tensor which is shape tensor with explicit trt profile shapes Ort::SessionOptions session_options; - OrtTensorRTProviderOptionsV2* trt_options; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr - rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); - - const char* trt_profile_min_shapes = "data:2x2,shape:4x1"; - const char* trt_profile_max_shapes = "data:2x2,shape:4x1"; - const char* trt_profile_opt_shapes = "data:2x2,shape:4x1"; - std::vector keys{"trt_profile_min_shapes", "trt_profile_max_shapes", "trt_profile_opt_shapes"}; - std::vector values{trt_profile_min_shapes, trt_profile_max_shapes, trt_profile_opt_shapes}; - ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options), - rel_trt_options.get()) == nullptr); + Ort::TensorRTProviderOptions trt_options; + + std::unordered_map trt_options_map = { + {"trt_profile_min_shapes", "data:2x2,shape:4x1"}, + {"trt_profile_max_shapes", "data:2x2,shape:4x1"}, + {"trt_profile_opt_shapes", "data:2x2,shape:4x1"}}; + trt_options.Update(trt_options_map); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); auto model_path = 
ORT_TSTR("testdata/trt_reshape.onnx"); @@ -3736,37 +3717,24 @@ TEST(TensorrtExecutionProviderTest, ShapeTensorTest) { // Test input tensor which is shape tensor with implicit trt profile shapes Ort::SessionOptions session_options_2; - OrtTensorRTProviderOptionsV2* trt_options_2; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options_2) == nullptr); - std::unique_ptr - rel_trt_options_2(trt_options_2, api.ReleaseTensorRTProviderOptions); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options_2), - rel_trt_options_2.get()) == nullptr); + Ort::TensorRTProviderOptions trt_options_2; + session_options_2.AppendExecutionProvider_TensorRT_V2(*trt_options_2); Ort::Session session_2(*ort_env, model_path, session_options_2); - session_2.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names)); + session_2.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, std::size(output_names)); } TEST(CApiTest, TestExternalCUDAStreamWithIOBinding) { - const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; - - OrtTensorRTProviderOptionsV2* trt_options; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr - rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); + Ort::TensorRTProviderOptions trt_options; // updating provider option with user provided compute stream cudaStream_t compute_stream = nullptr; - void* user_compute_stream = nullptr; cudaStreamCreate(&compute_stream); - ASSERT_TRUE(api.UpdateTensorRTProviderOptionsWithValue(rel_trt_options.get(), "user_compute_stream", compute_stream) == nullptr); - ASSERT_TRUE(api.GetTensorRTProviderOptionsByName(rel_trt_options.get(), "user_compute_stream", &user_compute_stream) == nullptr); + trt_options.UpdateWithValue("user_compute_stream", compute_stream); + void* user_compute_stream = trt_options.GetptionByName("user_compute_stream"); ASSERT_TRUE(user_compute_stream == (void*)compute_stream); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options), - rel_trt_options.get()) == nullptr); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); Ort::Session session(*ort_env, MODEL_URI, session_options); Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); @@ -3879,36 +3847,30 @@ class CApiTensorRTTest : public testing::Test, public ::testing::WithParamInterf TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) { std::string param = GetParam(); size_t pos = param.find("="); - std::string option_name = param.substr(0, pos); - std::string option_value = param.substr(pos + 1); + const std::string option_name = param.substr(0, pos); + const std::string option_value = param.substr(pos + 1); ASSERT_NE(pos, std::string::npos); - const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* trt_options; - OrtAllocator* allocator; - char* trt_options_str; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); - const char* engine_cache_path = "./trt_engine_folder"; - std::vector keys{"device_id", "has_user_compute_stream", "trt_fp16_enable", "trt_int8_enable", "trt_engine_cache_enable", - "trt_engine_cache_path", option_name.c_str()}; - - std::vector values{"0", "0", "1", "0", "1", - engine_cache_path, option_value.c_str()}; + 
Ort::TensorRTProviderOptions trt_options; + std::unordered_map trt_options_map = { + {"device_id", "0"}, + {"has_user_compute_stream", "0"}, + {"trt_fp16_enable", "1"}, + {"trt_int8_enable", "0"}, + {"trt_engine_cache_enable", "1"}, + {"trt_engine_cache_path", engine_cache_path}, + {option_name, option_value}}; - ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); + trt_options.Update(trt_options_map); - ASSERT_TRUE(api.GetAllocatorWithDefaultOptions(&allocator) == nullptr); - ASSERT_TRUE(api.GetTensorRTProviderOptionsAsString(rel_trt_options.get(), allocator, &trt_options_str) == nullptr); - std::string s(trt_options_str); - ASSERT_TRUE(s.find(engine_cache_path) != std::string::npos); - ASSERT_TRUE(s.find(param.c_str()) != std::string::npos); - ASSERT_TRUE(api.AllocatorFree(allocator, (void*)trt_options_str) == nullptr); + std::string trt_options_str = trt_options.GetTensorRTProviderOptionsAsString(); + ASSERT_NE(trt_options_str.find(engine_cache_path), std::string::npos); + ASSERT_NE(trt_options_str.find(param), std::string::npos); Ort::SessionOptions session_options; - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), rel_trt_options.get()) == nullptr); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); // simple inference test // prepare inputs From bd82a8eb81996dd01d121e40e435edc3feb7f1da Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 14 Aug 2025 17:33:28 -0700 Subject: [PATCH 03/12] CPU build passes --- .../core/session/onnxruntime_cxx_api.h | 31 +++++ .../core/session/onnxruntime_cxx_inline.h | 44 +++++++ .../test/contrib_ops/skiplayernorm_op_test.cc | 13 +- .../global_thread_pools/test_inference.cc | 3 +- onnxruntime/test/perftest/ort_test_session.cc | 46 ++----- onnxruntime/test/providers/cpu/model_tests.cc | 33 ++--- .../test/providers/qnn/qnn_basic_test.cc | 4 +- onnxruntime/test/shared_lib/test_inference.cc | 121 ++++++------------ .../test/shared_lib/test_model_loading.cc | 3 +- .../test/shared_lib/test_session_options.cc | 21 ++- 10 files changed, 157 insertions(+), 162 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index e01da54a7c053..aea344036983b 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -575,12 +575,14 @@ ORT_DEFINE_RELEASE(Node); ORT_DEFINE_RELEASE(Graph); ORT_DEFINE_RELEASE(Model); ORT_DEFINE_RELEASE(KeyValuePairs); +ORT_DEFINE_RELEASE(PrepackedWeightsContainer); ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); // This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type, // but the struct has V2 in its name to indicate that it is the second version of the options. 
inline void OrtRelease(OrtTensorRTProviderOptionsV2* ptr) { GetApi().ReleaseTensorRTProviderOptions(ptr); } +inline void OrtRelease(OrtCUDAProviderOptionsV2* ptr) { GetApi().ReleaseCUDAProviderOptions(ptr); } #undef ORT_DEFINE_RELEASE #undef ORT_DEFINE_RELEASE_FROM_API_STRUCT @@ -704,6 +706,7 @@ struct Model; struct Node; struct ModelMetadata; struct TypeInfo; +struct PrepackedWeightsContainer; struct Session; struct SessionOptions; struct SyncStream; @@ -784,6 +787,33 @@ struct TensorRTProviderOptions : detail::Base { std::string GetTensorRTProviderOptionsAsString() const; }; +/** \brief The CUDAProviderOptions (V2) + * + * Used to pass options to CUDA EP + */ + +struct CUDAProviderOptions : detail::Base { + CUDAProviderOptions(std::nullptr_t) {} + /// \brief Wraps OrtApi::CreateCUDAProviderOptions + CUDAProviderOptions(); + ///< Wrapper around OrtApi::UpdateCUDAProviderOptions + void Update(const std::unordered_map& options); + ///< Wrapper around OrtApi::GetCUDAProviderOptionsAsString + std::string GetCUDAProviderOptionsAsString() const; + ///< Wrapper around OrtApi::UpdateCUDAProviderOptionsWithValue + void UpdateWithValue(const char* key, void* value); + ///< Wrapper around OrtApi::GetCUDAProviderOptionsByName + void* GetOptionByName(const char* name) const; +}; + +struct PrepackedWeightsContainer : detail::Base { + using Base = detail::Base; + explicit PrepackedWeightsContainer(std::nullptr_t) {} ///< No instance is created + explicit PrepackedWeightsContainer(OrtPrepackedWeightsContainer* p) : Base{p} {} ///< Take ownership of a pointer created by C API + /// \brief Wraps OrtApi::CreatePrepackedWeightsContainer + PrepackedWeightsContainer(); +}; + namespace detail { template struct KeyValuePairsImpl : Ort::detail::Base { @@ -1230,6 +1260,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { const std::vector& external_initializer_file_buffer_array, const std::vector& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory + SessionOptionsImpl& AppendExecutionProvider_CPU(int use_arena); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CPU SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index eef6cc7972001..be392d8518d5f 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -522,6 +522,44 @@ inline std::string TensorRTProviderOptions::GetTensorRTProviderOptionsAsString() return std::string(options_str); } +inline CUDAProviderOptions::CUDAProviderOptions() { + ThrowOnError(GetApi().CreateCUDAProviderOptions(&this->p_)); +} + +inline void CUDAProviderOptions::Update(const std::unordered_map& options) { + std::vector keys; + std::vector values; + keys.reserve(options.size()); + values.reserve(options.size()); + for (const auto& kv : options) { + keys.push_back(kv.first.c_str()); + values.push_back(kv.second.c_str()); + } + ThrowOnError(GetApi().UpdateCUDAProviderOptions(p_, keys.data(), 
values.data(), options.size())); +} + +inline std::string CUDAProviderOptions::GetCUDAProviderOptionsAsString() const { + AllocatorWithDefaultOptions allocator; + char* options_str = nullptr; + ThrowOnError(GetApi().GetCUDAProviderOptionsAsString(p_, allocator, &options_str)); + std::unique_ptr options_str_g(options_str, detail::AllocatedFree(allocator)); + return std::string(options_str); +} + +inline void CUDAProviderOptions::UpdateWithValue(const char* key, void* value) { + ThrowOnError(GetApi().UpdateCUDAProviderOptionsWithValue(p_, key, value)); +} + +inline void* CUDAProviderOptions::GetOptionByName(const char* name) const { + void* value = nullptr; + ThrowOnError(GetApi().GetCUDAProviderOptionsByName(p_, name, &value)); + return value; +} + +inline PrepackedWeightsContainer::PrepackedWeightsContainer() { + ThrowOnError(GetApi().CreatePrepackedWeightsContainer(&this->p_)); +} + namespace detail { template inline const char* KeyValuePairsImpl::GetValue(const char* key) const { @@ -1155,6 +1193,12 @@ inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializersFrom return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CPU(int use_arena) { + ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(this->p_, use_arena)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) { ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options)); diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index 4e8d1b9f016f0..df83815cc29ea 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -146,15 +146,10 @@ static void RunOneTest( execution_providers.push_back(DefaultRocmExecutionProvider()); } else { if (strict) { - const auto& api = Ort::GetApi(); - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_skip_layer_norm_strict_mode"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); - execution_providers.push_back(CudaExecutionProviderWithOptions(std::move(rel_cuda_options.get()))); + Ort::CUDAProviderOptions cuda_options; + std::unordered_map options = {{"enable_skip_layer_norm_strict_mode", "1"}}; + cuda_options.Update(options); + execution_providers.push_back(CudaExecutionProviderWithOptions(std::move(cuda_options))); } else { execution_providers.push_back(DefaultCudaExecutionProvider()); } diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index c6d958536f488..324394798863c 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -74,8 +74,7 @@ static Ort::Session GetSessionObj(Ort::Env& env, T model_uri, int provider_type) if (provider_type == 1) { #ifdef USE_CUDA - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); std::cout << "Running simple inference with cuda provider" << std::endl; #else diff --git 
a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 2da6f5e6b9a04..f1a40b1da8651 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -179,53 +179,29 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif } else if (provider_name_ == onnxruntime::kCudaExecutionProvider) { #ifdef USE_CUDA - const auto& api = Ort::GetApi(); - OrtCUDAProviderOptionsV2* cuda_options; - Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options)); - std::vector option_keys, option_values; - // used to keep all option keys and value strings alive - std::list buffer; - buffer.emplace_back("cudnn_conv_algo_search"); - option_keys.push_back(buffer.back().c_str()); + Ort::CUDAProviderOptions cuda_options; + + const char* config_val = nullptr; switch (performance_test_config.run_config.cudnn_conv_algo) { case 0: - buffer.emplace_back("EXHAUSTIVE"); + config_val = "EXHAUSTIVE"; break; case 1: - buffer.emplace_back("HEURISTIC"); + config_val = "HEURISTIC"; break; default: - buffer.emplace_back("DEFAULT"); + config_val = "DEFAULT"; break; } - option_values.push_back(buffer.back().c_str()); + provider_options.emplace("cudnn_conv_algo_search", config_val); + provider_options.emplace("do_copy_in_default_stream", + (!performance_test_config.run_config.do_cuda_copy_in_separate_stream ? "1" : "0")); - buffer.emplace_back("do_copy_in_default_stream"); - option_keys.push_back(buffer.back().c_str()); - buffer.emplace_back(!performance_test_config.run_config.do_cuda_copy_in_separate_stream ? "1" : "0"); - option_values.push_back(buffer.back().c_str()); - -#ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); -#else - std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; -#endif + ParseSessionConfigs(ov_string, provider_options); - for (const auto& provider_option : provider_options) { - option_keys.push_back(provider_option.first.c_str()); - option_values.push_back(provider_option.second.c_str()); - } + cuda_options.Update(provider_options); - Ort::Status status(api.UpdateCUDAProviderOptions(cuda_options, - option_keys.data(), option_values.data(), option_keys.size())); - if (!status.IsOK()) { - OrtAllocator* allocator; - char* options; - Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); - Ort::ThrowOnError(api.GetCUDAProviderOptionsAsString(cuda_options, allocator, &options)); - ORT_THROW("[ERROR] [CUDA] Configuring the CUDA options failed with message: ", status.GetErrorMessage(), - "\nSupported options are:\n", options); - } session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); if (performance_test_config.run_config.enable_cuda_io_binding) { device_memory_name_ = CUDA; diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 6dd971fd84f82..635c97291d4bd 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -179,17 +179,14 @@ TEST_P(ModelTest, Run) { ortso.SetLogId(ToUTF8String(test_case_name).c_str()); ortso.SetLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR); if (provider_name == "cuda") { - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_CXX_ORTSTATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); - std::unique_ptr rel_cuda_options( - cuda_options, &OrtApis::ReleaseCUDAProviderOptions); + Ort::CUDAProviderOptions cuda_options; - 
std::vector keys{"device_id", "use_tf32"}; - std::vector values; std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); - values.push_back(device_id.empty() ? "0" : device_id.c_str()); - values.push_back("0"); - ASSERT_CXX_ORTSTATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2)); + + std::unordered_map options; + options["device_id"] = (device_id.empty() ? "0" : device_id.c_str()); + options["use_tf32"] = "0"; // Disable TF32 for CUDA provider + cuda_options.Update(options); ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); } else if (provider_name == "rocm") { @@ -208,24 +205,20 @@ TEST_P(ModelTest, Run) { #endif else if (provider_name == "tensorrt") { if (test_case_name.find(ORT_TSTR("FLOAT16")) != std::string::npos) { - Ort::TensorRTProviderOptions params; - ortso.AppendExecutionProvider_TensorRT_V2(*params); + OrtTensorRTProviderOptionsV2 params; + ortso.AppendExecutionProvider_TensorRT_V2(params); } else { Ort::TensorRTProviderOptions ep_option; ortso.AppendExecutionProvider_TensorRT_V2(*ep_option); } // Enable CUDA fallback - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_CXX_ORTSTATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); - std::unique_ptr rel_cuda_options( - cuda_options, &OrtApis::ReleaseCUDAProviderOptions); + Ort::CUDAProviderOptions cuda_options; - std::vector keys{"device_id", "use_tf32"}; - std::vector values; std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); - values.push_back(device_id.empty() ? "0" : device_id.c_str()); - values.push_back("0"); - ASSERT_CXX_ORTSTATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2)); + std::unordered_map options; + options["device_id"] = (device_id.empty() ? "0" : device_id.c_str()); + options["use_tf32"] = "0"; // Disable TF32 for CUDA provider + cuda_options.Update(options); ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); } else if (provider_name == "migraphx") { diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index fbe6b0569202d..e3cc5e7c227b8 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -232,7 +232,7 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) { so.AppendExecutionProvider("QNN", options); // Invalid! Adds CPU EP to session, but also disables CPU fallback. 
- Ort::Status status(OrtSessionOptionsAppendExecutionProvider_CPU(so, 1)); + so.AppendExecutionProvider_CPU(1); const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "constant_floats.onnx"; @@ -285,7 +285,7 @@ TEST_F(QnnHTPBackendTests, TestConvWithExternalData) { so.AppendExecutionProvider("QNN", options); - Ort::Status status(OrtSessionOptionsAppendExecutionProvider_CPU(so, 1)); + so.AppendExecutionProvider_CPU(1); const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "conv_qdq_external_ini.onnx"; diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index ef8a60d8ad240..4b98937e849cd 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1973,7 +1973,7 @@ static bool CreateSessionWithQnnEpAndQnnHtpSharedMemoryAllocator(PATH_TYPE model TEST(CApiTest, get_allocator_cpu) { Ort::SessionOptions session_options; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); + session_options.AppendExecutionProvider_CPU(1); Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault); Ort::Allocator cpu_allocator(session, info_cpu); @@ -2018,8 +2018,7 @@ TEST(CApiTest, get_allocator_cpu) { #ifdef USE_CUDA TEST(CApiTest, get_allocator_cuda) { Ort::SessionOptions session_options; - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); @@ -2101,7 +2100,7 @@ TEST(CApiTest, get_allocator_qnn_htp_shared) { TEST(CApiTest, io_binding) { Ort::SessionOptions session_options; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); + session_options.AppendExecutionProvider_CPU(1); Ort::Session session(*ort_env, MODEL_URI, session_options); Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault); @@ -2177,9 +2176,8 @@ TEST(CApiTest, io_binding_cuda) { #ifdef USE_TENSORRT session_options.AppendExecutionProvider_TensorRT({}); #else - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); - session_options.AppendExecutionProvider_CUDA_V2(*options); + Ort::CUDAProviderOptions cuda_options; + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); #endif Ort::Session session(*ort_env, MODEL_URI, session_options); @@ -2389,17 +2387,12 @@ TEST(CApiTest, basic_cuda_graph) { #elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. 
- OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_cuda_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); + Ort::CUDAProviderOptions cuda_options; + std::unordered_map options_map = {{"enable_cuda_graph", + "1"}}; + cuda_options.Update(options_map); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( - static_cast(session_options), - rel_cuda_options.get()) == nullptr); #elif defined(USE_ROCM) // Enable hip graph in rocm provider option. OrtROCMProviderOptions* rocm_options = nullptr; @@ -2694,7 +2687,7 @@ static void RunWithCudaGraphAnnotation(T& cg_data, } TEST(CApiTest, basic_cuda_graph_with_annotation) { - const auto& api = Ort::GetApi(); + [[maybe_unused]] const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; #ifdef USE_DML @@ -2707,17 +2700,11 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { Ort::MemoryInfo info_mem("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault); #elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_cuda_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); + Ort::CUDAProviderOptions cuda_options; + std::unordered_map options_map = {{"enable_cuda_graph", "1"}}; + cuda_options.Update(options_map); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( - static_cast(session_options), - rel_cuda_options.get()) == nullptr); Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); #elif defined(USE_ROCM) // Enable hip graph in rocm provider option. @@ -2766,21 +2753,15 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { #ifndef REDUCED_OPS_BUILD #if defined(USE_CUDA) || defined(USE_TENSORRT) TEST(CApiTest, cuda_graph_with_shape_nodes) { - const auto& api = Ort::GetApi(); + [[maybe_unused]] const auto& api = Ort::GetApi(); // Enable cuda graph in cuda provider option. 
- OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_cuda_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); + Ort::CUDAProviderOptions cuda_options; + const std::unordered_map options_map = {{"enable_cuda_graph", "1"}}; + cuda_options.Update(options_map); Ort::SessionOptions session_options; - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( - static_cast(session_options), - rel_cuda_options.get()) == nullptr); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); // Successful loading of the ONNX model with shape nodes with cuda graph feature enabled Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); @@ -3356,8 +3337,6 @@ TEST(CApiTest, TestSharedAllocators) { // Turn on sharing of the allocator between sessions session_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1"); - const auto& api = Ort::GetApi(); - // CASE 1: We test creating and registering an ORT-internal allocator implementation instance // for sharing between sessions { @@ -3525,16 +3504,10 @@ TEST(CApiTest, TestSharingOfInitializerAndItsPrepackedVersion) { Ort::Value val = Ort::Value::CreateTensor(mem_info, data, data_len, shape, shape_len); session_options.AddInitializer("W", val); - const auto& api = Ort::GetApi(); - - OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr; - ASSERT_TRUE(api.CreatePrepackedWeightsContainer(&prepacked_weights_container) == nullptr); - std::unique_ptr - rel_prepacked_weights_container(prepacked_weights_container, api.ReleasePrepackedWeightsContainer); - auto default_allocator = std::make_unique(); // create session 1 (using model path) + Ort::PrepackedWeightsContainer prepacked_weights_container; Ort::Session session1(*ort_env, MATMUL_MODEL_URI, session_options, prepacked_weights_container); RunSession(default_allocator.get(), session1, @@ -3631,12 +3604,11 @@ TEST(CApiTest, AllocateInitializersFromNonArenaMemory) { Ort::SessionOptions session_options; #ifdef USE_CUDA - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); #else // arena is enabled but the sole initializer will still be allocated from non-arena memory - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); + session_options.AppendExecutionProvider_CPU(1); #endif // disable using arena for the sole initializer in the model @@ -3913,39 +3885,32 @@ INSTANTIATE_TEST_SUITE_P(CApiTensorRTTest, CApiTensorRTTest, // This test uses CreateCUDAProviderOptions/UpdateCUDAProviderOptions/UpdateCUDAProviderOptionsWithValue APIs to configure and create a CUDA Execution Provider instance TEST(CApiTest, TestConfigureCUDAProviderOptions) { - const auto& api = Ort::GetApi(); - - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); + Ort::CUDAProviderOptions cuda_options; // Only test updating OrtCUDAProviderOptionsV2 instance with user provided compute stream not running the inference cudaStream_t compute_stream = nullptr; void* 
user_compute_stream = nullptr; cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking); - ASSERT_TRUE(api.UpdateCUDAProviderOptionsWithValue(rel_cuda_options.get(), "user_compute_stream", compute_stream) == nullptr); - ASSERT_TRUE(api.GetCUDAProviderOptionsByName(rel_cuda_options.get(), "user_compute_stream", &user_compute_stream) == nullptr); + cuda_options.UpdateWithValue("user_compute_stream", compute_stream); + user_compute_stream = cuda_options.GetOptionByName("user_compute_stream"); ASSERT_TRUE(user_compute_stream == (void*)compute_stream); cudaStreamDestroy(compute_stream); - std::vector keys{ - "device_id", "has_user_compute_stream", "gpu_mem_limit", "arena_extend_strategy", - "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d"}; - - std::vector values{ - "0", "0", "1024", "kSameAsRequested", - "DEFAULT", "1", "1"}; + std::unordered_map cuda_options_map = { + {"device_id", "0"}, + {"has_user_compute_stream", "0"}, + {"gpu_mem_limit", "1024"}, + {"arena_extend_strategy", "kSameAsRequested"}, + {"cudnn_conv_algo_search", "DEFAULT"}, + {"do_copy_in_default_stream", "1"}, + {"cudnn_conv_use_max_workspace", "1"}, + {"cudnn_conv1d_pad_to_nc1d", "1"}}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 6) == nullptr); + cuda_options.Update(cuda_options_map); auto allocator = Ort::AllocatorWithDefaultOptions(); - char* cuda_options_str = nullptr; - ASSERT_TRUE(api.GetCUDAProviderOptionsAsString(rel_cuda_options.get(), allocator, &cuda_options_str) == nullptr); - std::string s; - if (cuda_options_str != nullptr) { - s = std::string(cuda_options_str, strnlen(cuda_options_str, 2048)); - } + std::string s = cuda_options.GetCUDAProviderOptionsAsString(); ASSERT_TRUE(s.find("device_id=0") != std::string::npos); ASSERT_TRUE(s.find("gpu_mem_limit=1024") != std::string::npos); ASSERT_TRUE(s.find("arena_extend_strategy=kSameAsRequested") != std::string::npos); @@ -3954,10 +3919,8 @@ TEST(CApiTest, TestConfigureCUDAProviderOptions) { ASSERT_TRUE(s.find("cudnn_conv_use_max_workspace=1") != std::string::npos); ASSERT_TRUE(s.find("cudnn_conv1d_pad_to_nc1d") != std::string::npos); - allocator.Free(cuda_options_str); - Ort::SessionOptions session_options; - session_options.AppendExecutionProvider_CUDA_V2(*rel_cuda_options); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); // if session creation passes, model loads fine std::basic_string model_uri = MODEL_URI; @@ -4056,9 +4019,8 @@ TEST(CApiTest, GitHubIssue10179) { auto load_model_thread_fn = []() { try { const auto* model_path = MODEL_URI; - Ort::SessionOptions session_options{}; - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::SessionOptions session_options; + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; } catch (const std::exception& e) { @@ -4089,8 +4051,7 @@ TEST(CApiTest, GitHubIssue10179) { TEST(CApiTest, TestCudaMemcpyToHostWithSequenceTensors) { const auto* model_path = SEQUENCE_MODEL_URI_2; Ort::SessionOptions session_options{}; - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; diff --git 
a/onnxruntime/test/shared_lib/test_model_loading.cc b/onnxruntime/test/shared_lib/test_model_loading.cc index 89b12ec61649e..7268c351877f3 100644 --- a/onnxruntime/test/shared_lib/test_model_loading.cc +++ b/onnxruntime/test/shared_lib/test_model_loading.cc @@ -60,8 +60,7 @@ TEST(CApiTest, model_from_array) { create_session(so); #ifdef USE_CUDA - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; so.AppendExecutionProvider_CUDA_V2(*options); create_session(so); #endif diff --git a/onnxruntime/test/shared_lib/test_session_options.cc b/onnxruntime/test/shared_lib/test_session_options.cc index 3fbb294e1af49..de7115ac5189c 100644 --- a/onnxruntime/test/shared_lib/test_session_options.cc +++ b/onnxruntime/test/shared_lib/test_session_options.cc @@ -54,20 +54,17 @@ TEST(CApiTest, session_options_provider_interface_fail_add_openvino) { #if defined(USE_CUDA_PROVIDER_INTERFACE) // Test that loading CUDA EP when only the interface is built (but not the full EP) fails. TEST(CApiTest, session_options_provider_interface_fail_add_cuda) { - const OrtApi& api = Ort::GetApi(); Ort::SessionOptions session_options; - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - Ort::Status status1 = Ort::Status{api.CreateCUDAProviderOptions(&cuda_options)}; - ASSERT_TRUE(status1.IsOK()); - - Ort::Status status2 = Ort::Status{api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, - cuda_options)}; - ASSERT_FALSE(status2.IsOK()); - EXPECT_EQ(status2.GetErrorCode(), ORT_FAIL); - EXPECT_THAT(status2.GetErrorMessage(), testing::HasSubstr("Failed to load")); - - api.ReleaseCUDAProviderOptions(cuda_options); + Ort::CUDAProviderOptions cuda_options; + bool thrown = false; + try { + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); + ASSERT_TRUE(false) << "Appending CUDA options have thrown exception"; + } catch (const Ort::Exception& ex) { + ASSERT_THAT(ex.what(), testing::HasSubstr("Failed to load")); + thrown = true; + } } #endif // defined(USE_CUDA_PROVIDER_INTERFACE) From d9dafb7189bb08fdc239aec6b0711255553c582f Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 15 Aug 2025 12:28:32 -0700 Subject: [PATCH 04/12] Add C++ wrapper for CreateArenaCfgV2 --- .../core/session/onnxruntime_cxx_api.h | 18 +++++++++++++++--- .../core/session/onnxruntime_cxx_inline.h | 12 ++++++++++++ onnxruntime/test/shared_lib/test_inference.cc | 16 +++++++++------- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index aea344036983b..8b51b15a02f87 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -791,7 +791,6 @@ struct TensorRTProviderOptions : detail::Base { * * Used to pass options to CUDA EP */ - struct CUDAProviderOptions : detail::Base { CUDAProviderOptions(std::nullptr_t) {} /// \brief Wraps OrtApi::CreateCUDAProviderOptions @@ -806,10 +805,17 @@ struct CUDAProviderOptions : detail::Base { void* GetOptionByName(const char* name) const; }; +/** \brief The PrepackedWeightsContainer + * + * Create only and pass to Ort::Session constructor for multiple sessions + * to share pre-packed weights. 
+ */ struct PrepackedWeightsContainer : detail::Base { using Base = detail::Base; - explicit PrepackedWeightsContainer(std::nullptr_t) {} ///< No instance is created - explicit PrepackedWeightsContainer(OrtPrepackedWeightsContainer* p) : Base{p} {} ///< Take ownership of a pointer created by C API + ///< No instance is created + explicit PrepackedWeightsContainer(std::nullptr_t) {} + ///< Take ownership of a pointer created by C API + explicit PrepackedWeightsContainer(OrtPrepackedWeightsContainer* p) : Base{p} {} /// \brief Wraps OrtApi::CreatePrepackedWeightsContainer PrepackedWeightsContainer(); }; @@ -2340,6 +2346,12 @@ struct ArenaCfg : detail::Base { * See docs/C_API.md for details on what the following parameters mean and how to choose these values */ ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); + + /** + * Wraps Ort::CreateArenaCfgV2 + * See C API for details on what the following parameters mean and how to choose these values + */ + explicit ArenaCfg(const std::unordered_map& arena_config); }; // diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index be392d8518d5f..b640fc12e252b 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -449,6 +449,18 @@ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_)); } +inline ArenaCfg::ArenaCfg(const std::unordered_map& arena_config) { + std::vector keys; + std::vector values; + keys.reserve(arena_config.size()); + values.reserve(arena_config.size()); + for (const auto& kv : arena_config) { + keys.push_back(kv.first.c_str()); + values.push_back(kv.second); + } + ThrowOnError(GetApi().CreateArenaCfgV2(keys.data(), values.data(), arena_config.size(), &p_)); +} + inline ThreadingOptions::ThreadingOptions() { ThrowOnError(GetApi().CreateThreadingOptions(&p_)); } diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 4b98937e849cd..c1f1fc062faf8 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3624,16 +3624,18 @@ TEST(CApiTest, AllocateInitializersFromNonArenaMemory) { // Usage example showing how to use CreateArenaCfgV2() API to configure the default memory CUDA arena allocator TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) { - const auto& api = Ort::GetApi(); - Ort::SessionOptions session_options; - const char* keys[] = {"max_mem", "arena_extend_strategy", "initial_chunk_size_bytes", "max_dead_bytes_per_chunk", "initial_growth_chunk_size_bytes", "max_power_of_two_extend_bytes"}; - const size_t values[] = {0 /*let ort pick default max memory*/, 0, 1024, 0, 256, 1L << 24}; + const std::unordered_map config_map = { + {"max_mem", 0}, // let ort pick default max memory + {"arena_extend_strategy", 0}, // use default extend strategy + {"initial_chunk_size_bytes", 1024}, // initial chunk size in bytes + {"max_dead_bytes_per_chunk", 0}, // no dead bytes per chunk + {"initial_growth_chunk_size_bytes", 256}, // initial growth chunk size in bytes + {"max_power_of_two_extend_bytes", 1L << 24} // max power of two extend bytes + }; - OrtArenaCfg* arena_cfg = nullptr; - ASSERT_TRUE(api.CreateArenaCfgV2(keys, values, 5, &arena_cfg) == 
nullptr); - std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); + Ort::ArenaCfg arena_cfg(config_map); OrtCUDAProviderOptions cuda_provider_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(nullptr); cuda_provider_options.default_memory_arena_cfg = arena_cfg; From 1f0c4fadca08eeca04d4c077b1ef5789b61f415a Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 15 Aug 2025 13:00:37 -0700 Subject: [PATCH 05/12] Fix up subscript --- .../core/session/onnxruntime_cxx_inline.h | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index b640fc12e252b..04bef6d8600d2 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1516,13 +1516,10 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForOutputs AllocatorWithDefaultOptions allocator; auto num_outputs = GetOutputCount(); std::vector mem_infos; - mem_infos.reserve(num_outputs); + mem_infos.resize(num_outputs); - const OrtMemoryInfo* mem_info_ptrs; - ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, &mem_info_ptrs, num_outputs)); - for (size_t i = 0; i < num_outputs; ++i) { - mem_infos.emplace_back(mem_info_ptrs[i]); - } + ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, reinterpret_cast(&mem_infos[0], + num_outputs)); return mem_infos; } @@ -1552,14 +1549,11 @@ template inline std::vector ConstSessionImpl::GetEpDeviceForInputs() const { auto num_inputs = GetInputCount(); std::vector input_devices; - input_devices.reserve(num_inputs); - - const OrtEpDevice* device_ptrs; - ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, &device_ptrs, num_inputs)); + input_devices.resize(num_inputs); - for (size_t i = 0; i < num_inputs; ++i) { - input_devices.emplace_back(device_ptrs[i]); - } + ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, + reinterpret_cast(&input_devices[0]), + num_inputs)); return input_devices; } From c9adcf8f966347d059ff078ef78328483d62cc49 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 15 Aug 2025 14:35:33 -0700 Subject: [PATCH 06/12] Address build errors, add coverage --- .../core/session/onnxruntime_cxx_inline.h | 15 +++++++++------ onnxruntime/test/shared_lib/test_data_copy.cc | 11 ++++++++++- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 04bef6d8600d2..33b7b31b28b39 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1499,13 +1499,15 @@ inline std::vector ConstSessionImpl::GetOverridableInitializerNa template inline std::vector ConstSessionImpl::GetMemoryInfoForInputs() const { - AllocatorWithDefaultOptions allocator; + static_assert(sizeof(ConstMemoryInfo) == sizeof(OrtMemoryInfo*), + "ConstMemoryInfo must be compatible with OrtMemoryInfo*"); auto num_inputs = GetInputCount(); std::vector mem_infos; mem_infos.resize(num_inputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, reinterpret_cast(&mem_infos[0]), + ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, + reinterpret_cast(&mem_infos[0]), num_inputs)); return mem_infos; @@ -1513,14 +1515,15 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForInputs( template inline std::vector 
ConstSessionImpl::GetMemoryInfoForOutputs() const { - AllocatorWithDefaultOptions allocator; + static_assert(sizeof(ConstMemoryInfo) == sizeof(OrtMemoryInfo*), + "ConstMemoryInfo must be compatible with OrtMemoryInfo*"); + auto num_outputs = GetOutputCount(); std::vector mem_infos; mem_infos.resize(num_outputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, reinterpret_cast(&mem_infos[0], - num_outputs)); - + ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, reinterpret_cast(&mem_infos[0]), + num_outputs)); return mem_infos; } diff --git a/onnxruntime/test/shared_lib/test_data_copy.cc b/onnxruntime/test/shared_lib/test_data_copy.cc index 872671135fc6a..e7d9d7715092b 100644 --- a/onnxruntime/test/shared_lib/test_data_copy.cc +++ b/onnxruntime/test/shared_lib/test_data_copy.cc @@ -82,7 +82,16 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { size_t num_inputs = session.GetInputCount(); // find the input location so we know which inputs can be provided on device. - std::vector input_locations = session.GetMemoryInfoForInputs(); + auto input_locations = session.GetMemoryInfoForInputs(); + ASSERT_EQ(session.GetInputCount(), input_locations.size()); + + // Testing coverage + auto input_ep_devices = session.GetEpDeviceForInputs(); + ASSERT_EQ(session.GetInputCount(), input_ep_devices.size()); + + // This is for testing + auto output_locations = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(session.GetOutputCount(), output_locations.size()); std::vector cpu_tensors; From dad43098fcf38a2a358535a25d935db8db4bdd82 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 15 Aug 2025 14:50:42 -0700 Subject: [PATCH 07/12] Fix function name typo --- include/onnxruntime/core/session/onnxruntime_cxx_api.h | 2 +- include/onnxruntime/core/session/onnxruntime_cxx_inline.h | 2 +- onnxruntime/test/shared_lib/test_inference.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 8b51b15a02f87..0b83dac25c4c2 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -782,7 +782,7 @@ struct TensorRTProviderOptions : detail::Base { void UpdateWithValue(const char* key, void* value); ///< Wrapper around OrtApi::GetTensorRTProviderOptionsByName - void* GetptionByName(const char* name) const; + void* GetOptionByName(const char* name) const; ///< Wrapper around OrtApi::GetTensorRTProviderOptionsAsString std::string GetTensorRTProviderOptionsAsString() const; }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 33b7b31b28b39..c3e801fd438b1 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -520,7 +520,7 @@ inline void TensorRTProviderOptions::UpdateWithValue(const char* key, void* valu ThrowOnError(GetApi().UpdateTensorRTProviderOptionsWithValue(p_, key, value)); } -inline void* TensorRTProviderOptions::GetptionByName(const char* name) const { +inline void* TensorRTProviderOptions::GetOptionByName(const char* name) const { void* value = nullptr; ThrowOnError(GetApi().GetTensorRTProviderOptionsByName(p_, name, &value)); return value; diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index c1f1fc062faf8..cb4e783125a61 100644 --- 
a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3705,7 +3705,7 @@ TEST(CApiTest, TestExternalCUDAStreamWithIOBinding) { cudaStream_t compute_stream = nullptr; cudaStreamCreate(&compute_stream); trt_options.UpdateWithValue("user_compute_stream", compute_stream); - void* user_compute_stream = trt_options.GetptionByName("user_compute_stream"); + void* user_compute_stream = trt_options.GetOptionByName("user_compute_stream"); ASSERT_TRUE(user_compute_stream == (void*)compute_stream); session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); From d0e47ef1c222d19012bf6f1cc9dd98a195b792c3 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 15 Aug 2025 15:20:54 -0700 Subject: [PATCH 08/12] Remove unused --- onnxruntime/test/shared_lib/test_session_options.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/test/shared_lib/test_session_options.cc b/onnxruntime/test/shared_lib/test_session_options.cc index de7115ac5189c..399777623d83e 100644 --- a/onnxruntime/test/shared_lib/test_session_options.cc +++ b/onnxruntime/test/shared_lib/test_session_options.cc @@ -57,13 +57,11 @@ TEST(CApiTest, session_options_provider_interface_fail_add_cuda) { Ort::SessionOptions session_options; Ort::CUDAProviderOptions cuda_options; - bool thrown = false; try { session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); ASSERT_TRUE(false) << "Appending CUDA options have thrown exception"; } catch (const Ort::Exception& ex) { ASSERT_THAT(ex.what(), testing::HasSubstr("Failed to load")); - thrown = true; } } #endif // defined(USE_CUDA_PROVIDER_INTERFACE) From 4eb14de6b92f508bbbf2db49f8dbb584153d751c Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 15 Aug 2025 15:29:39 -0700 Subject: [PATCH 09/12] Remove stray include introduced by AI --- include/onnxruntime/core/session/onnxruntime_cxx_inline.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index c3e801fd438b1..59ad5f92a097e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -13,7 +13,6 @@ #include #include #include -#include "onnxruntime_cxx_api.h" // Convert OrtStatus to Ort::Status and return // instead of throwing From bd7cc7633a7896316e6d937b665815521f74d084 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 18 Aug 2025 10:56:36 -0700 Subject: [PATCH 10/12] Remove unused var --- onnxruntime/test/shared_lib/test_inference.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index cb4e783125a61..786c0ba713b85 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3910,8 +3910,6 @@ TEST(CApiTest, TestConfigureCUDAProviderOptions) { cuda_options.Update(cuda_options_map); - auto allocator = Ort::AllocatorWithDefaultOptions(); - std::string s = cuda_options.GetCUDAProviderOptionsAsString(); ASSERT_TRUE(s.find("device_id=0") != std::string::npos); ASSERT_TRUE(s.find("gpu_mem_limit=1024") != std::string::npos); From bf379998b27f01eefb9fa31877eeb1901e84994e Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 21 Aug 2025 15:42:52 -0700 Subject: [PATCH 11/12] Address review comments --- .../core/session/onnxruntime_cxx_api.h | 10 +++++----- .../core/session/onnxruntime_cxx_inline.h | 13 +++++++------ 
onnxruntime/test/providers/cpu/model_tests.cc | 16 +++------------- .../test/shared_lib/test_session_options.cc | 2 +- 4 files changed, 16 insertions(+), 25 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 0b83dac25c4c2..9cb566ab2d352 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -728,11 +728,11 @@ struct Status : detail::Base { using Base = detail::Base; using Base::Base; - explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used - Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. - explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception - explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception - Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message. + explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used + explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. + explicit Status(const Exception&); ///< Creates status instance out of exception + explicit Status(const std::exception&); ///< Creates status instance out of exception + Status(const char* message, OrtErrorCode code); ///< Creates status instance out of null-terminated string message. std::string GetErrorMessage() const; OrtErrorCode GetErrorCode() const; bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status. 
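Note on the Status changes above: with the OrtStatus* constructor now explicit and the exception-based constructors allowed to throw, call sites must wrap raw C API return values deliberately. A minimal sketch of such a call site, assuming an existing OrtSession* handle; the helper name below is illustrative and not part of this patch:

    #include <onnxruntime_cxx_api.h>

    // Sketch only: wraps a raw OrtStatus* returned by the C API in Ort::Status explicitly.
    // Ort::Status takes ownership and releases the status when it goes out of scope.
    size_t GetOutputCountOrThrow(const OrtSession* session) {
      size_t count = 0;
      Ort::Status status{Ort::GetApi().SessionGetOutputCount(session, &count)};
      if (!status.IsOK()) {
        // The throwing constructors shown above convert the other way: exception -> status.
        throw Ort::Exception(status.GetErrorMessage(), status.GetErrorCode());
      }
      return count;
    }
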
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 59ad5f92a097e..c19c43d2a854c 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -56,15 +56,15 @@ inline void ThrowOnError(const Status& st) { inline Status::Status(OrtStatus* status) noexcept : detail::Base{status} { } -inline Status::Status(const std::exception& e) noexcept { +inline Status::Status(const std::exception& e) { p_ = GetApi().CreateStatus(ORT_FAIL, e.what()); } -inline Status::Status(const Exception& e) noexcept { +inline Status::Status(const Exception& e) { p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what()); } -inline Status::Status(const char* message, OrtErrorCode code) noexcept { +inline Status::Status(const char* message, OrtErrorCode code) { p_ = GetApi().CreateStatus(code, message); } @@ -1506,7 +1506,7 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForInputs( mem_infos.resize(num_inputs); ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, - reinterpret_cast(&mem_infos[0]), + reinterpret_cast(mem_infos.data()), num_inputs)); return mem_infos; @@ -1521,7 +1521,8 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForOutputs std::vector mem_infos; mem_infos.resize(num_outputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, reinterpret_cast(&mem_infos[0]), + ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, + reinterpret_cast(mem_infos.data()), num_outputs)); return mem_infos; } @@ -1554,7 +1555,7 @@ inline std::vector ConstSessionImpl::GetEpDeviceForInputs() co input_devices.resize(num_inputs); ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, - reinterpret_cast(&input_devices[0]), + reinterpret_cast(input_devices.data()), num_inputs)); return input_devices; diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 635c97291d4bd..a17982ecb5eab 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -22,6 +22,7 @@ #include #include "default_providers.h" #include "test/onnx/TestCase.h" +#include "test/util/include/api_asserts.h" #ifdef USE_DNNL #include "core/providers/dnnl/dnnl_provider_factory.h" @@ -59,17 +60,6 @@ extern std::unique_ptr ort_env; -// asserts that the OrtStatus* result of `status_expr` does not indicate an error -// note: this takes ownership of the OrtStatus* result -#define ASSERT_CXX_ORTSTATUS_OK(status_expr) \ - do { \ - if (OrtStatus* _status = (status_expr); _status != nullptr) { \ - std::unique_ptr _rel_status{ \ - _status, &OrtApis::ReleaseStatus}; \ - FAIL() << "OrtStatus error: " << OrtApis::GetErrorMessage(_rel_status.get()); \ - } \ - } while (false) - using namespace onnxruntime::common; namespace onnxruntime { @@ -290,11 +280,11 @@ TEST_P(ModelTest, Run) { std::unordered_map feeds; l->LoadTestData(task_id, holder, feeds, true); size_t output_count; - ASSERT_CXX_ORTSTATUS_OK(OrtApis::SessionGetOutputCount(ort_session, &output_count)); + ASSERT_ORTSTATUS_OK(OrtApis::SessionGetOutputCount(ort_session, &output_count)); // Create output feed std::vector output_names(output_count); for (size_t i = 0; i != output_count; ++i) { - ASSERT_CXX_ORTSTATUS_OK( + ASSERT_ORTSTATUS_OK( OrtApis::SessionGetOutputName(ort_session, i, default_allocator.get(), &output_names[i])); } diff --git a/onnxruntime/test/shared_lib/test_session_options.cc 
b/onnxruntime/test/shared_lib/test_session_options.cc
index 399777623d83e..d12a586f662ac 100644
--- a/onnxruntime/test/shared_lib/test_session_options.cc
+++ b/onnxruntime/test/shared_lib/test_session_options.cc
@@ -59,7 +59,7 @@ TEST(CApiTest, session_options_provider_interface_fail_add_cuda) {
   Ort::CUDAProviderOptions cuda_options;
 
   try {
     session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
-    ASSERT_TRUE(false) << "Appending CUDA options have thrown exception";
+    FAIL() << "Appending CUDA options should have thrown an exception";
   } catch (const Ort::Exception& ex) {
     ASSERT_THAT(ex.what(), testing::HasSubstr("Failed to load"));
   }

From fa986795c0a63f59889e4c2d002e73775d643a28 Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Thu, 21 Aug 2025 16:01:06 -0700
Subject: [PATCH 12/12] build errors

---
 include/onnxruntime/core/session/onnxruntime_cxx_api.h | 2 +-
 onnxruntime/test/util/include/api_asserts.h            | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 1dcf8b23df4a8..2f4fd36c8115f 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -729,7 +729,7 @@ struct Status : detail::Base<OrtStatus> {
   using Base::Base;
 
   explicit Status(std::nullptr_t) noexcept {}      ///< Create an empty object, must be assigned a valid one to be used
-  explicit Status(OrtStatus* status) noexcept;     ///< Takes ownership of OrtStatus instance returned from the C API.
+  explicit Status(OrtStatus* status) noexcept;     ///< Takes ownership of OrtStatus instance returned from the C API.
   explicit Status(const Exception&);               ///< Creates status instance out of exception
   explicit Status(const std::exception&);          ///< Creates status instance out of exception
   Status(const char* message, OrtErrorCode code);  ///< Creates status instance out of null-terminated string message.
diff --git a/onnxruntime/test/util/include/api_asserts.h b/onnxruntime/test/util/include/api_asserts.h
index 946782752e4bd..0be3b8bbb0764 100644
--- a/onnxruntime/test/util/include/api_asserts.h
+++ b/onnxruntime/test/util/include/api_asserts.h
@@ -12,25 +12,25 @@
 // asserts for the public API
 #define ASSERT_ORTSTATUS_OK(function)                                  \
   do {                                                                 \
-    Ort::Status _tmp_status = (function);                              \
+    Ort::Status _tmp_status{(function)};                               \
     ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage();  \
   } while (false)
 
 #define EXPECT_ORTSTATUS_OK(api, function)                             \
   do {                                                                 \
-    Ort::Status _tmp_status = (api->function);                         \
+    Ort::Status _tmp_status{(api->function)};                          \
     EXPECT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage();  \
   } while (false)
 
 #define ASSERT_ORTSTATUS_NOT_OK(api, function)                         \
   do {                                                                 \
-    Ort::Status _tmp_status = (api->function);                         \
+    Ort::Status _tmp_status{(api->function)};                          \
     ASSERT_FALSE(_tmp_status.IsOK());                                  \
   } while (false)
 
 #define EXPECT_ORTSTATUS_NOT_OK(api, function)                         \
   do {                                                                 \
-    Ort::Status _tmp_status = (api->function);                         \
+    Ort::Status _tmp_status{(api->function)};                          \
     EXPECT_FALSE(_tmp_status.IsOK());                                  \
   } while (false)
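With the reworked macros in api_asserts.h, C API calls in tests are checked through Ort::Status, so the OrtStatus* is released automatically and the error message is reported on failure. A minimal usage sketch, assuming gtest and the test util include path; the test name and the specific API call are illustrative only:

    #include <gtest/gtest.h>
    #include <onnxruntime_cxx_api.h>
    #include "test/util/include/api_asserts.h"

    TEST(ApiAssertsExample, CreateAndReleaseEnv) {
      const OrtApi& api = Ort::GetApi();
      OrtEnv* env = nullptr;
      // ASSERT_ORTSTATUS_OK wraps the returned OrtStatus* in an Ort::Status (brace-initialized,
      // since the constructor is explicit) and prints GetErrorMessage() if the call failed.
      ASSERT_ORTSTATUS_OK(api.CreateEnv(ORT_LOGGING_LEVEL_WARNING, "api_asserts_example", &env));
      ASSERT_NE(env, nullptr);
      api.ReleaseEnv(env);
    }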