diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 9cc1604d71e68..0a2f379dad999 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -53,7 +53,8 @@ jobs: runs-on: macos-15 env: - xcode_version: 16 + xcode_version: 16.4 + simulator_runtime_version: 18.5 strategy: matrix: @@ -90,6 +91,8 @@ jobs: --apple_deploy_target=15.1 \ --apple_sysroot=iphonesimulator \ --osx_arch=${{ matrix.target_arch }} + env: + ORT_GET_SIMULATOR_DEVICE_INFO_REQUESTED_RUNTIME_VERSION: ${{ env.simulator_runtime_version }} Objective-C-StaticAnalysis: runs-on: macos-14 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index f76ad642447ba..2f1532d0643ae 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -336,7 +336,13 @@ if (onnxruntime_ENABLE_CPUINFO) set(CPUINFO_SUPPORTED TRUE) endif() if (WIN32) - set(CPUINFO_SUPPORTED TRUE) + # There's an error when linking with cpuinfo on arm64ec with a vcpkg build (--use_vcpkg). + # TODO Fix it and then re-enable cpuinfo on arm64ec. + if (onnxruntime_target_platform STREQUAL "ARM64EC") + set(CPUINFO_SUPPORTED FALSE) + else() + set(CPUINFO_SUPPORTED TRUE) + endif() elseif (NOT ${onnxruntime_target_platform} MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") message(WARNING "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo. " diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 5dcc2b2628bf4..d927489372e7c 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -14,7 +14,7 @@ set(onnxruntime_common_src_patterns "${ONNXRUNTIME_ROOT}/core/platform/check_intel.h" "${ONNXRUNTIME_ROOT}/core/platform/check_intel.cc" "${ONNXRUNTIME_ROOT}/core/platform/device_discovery.h" - "${ONNXRUNTIME_ROOT}/core/platform/device_discovery.cc" + "${ONNXRUNTIME_ROOT}/core/platform/device_discovery_common.cc" "${ONNXRUNTIME_ROOT}/core/platform/env.h" "${ONNXRUNTIME_ROOT}/core/platform/env.cc" "${ONNXRUNTIME_ROOT}/core/platform/env_time.h" @@ -32,18 +32,30 @@ set(onnxruntime_common_src_patterns if(WIN32) list(APPEND onnxruntime_common_src_patterns - "${ONNXRUNTIME_ROOT}/core/platform/windows/*.h" - "${ONNXRUNTIME_ROOT}/core/platform/windows/*.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/debug_alloc.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/debug_alloc.h" + "${ONNXRUNTIME_ROOT}/core/platform/windows/dll_load_error.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/dll_load_error.h" + "${ONNXRUNTIME_ROOT}/core/platform/windows/env_time.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/env.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/env.h" + "${ONNXRUNTIME_ROOT}/core/platform/windows/hardware_core_enumerator.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/hardware_core_enumerator.h" + "${ONNXRUNTIME_ROOT}/core/platform/windows/stacktrace.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/telemetry.cc" + "${ONNXRUNTIME_ROOT}/core/platform/windows/telemetry.h" "${ONNXRUNTIME_ROOT}/core/platform/windows/logging/*.h" "${ONNXRUNTIME_ROOT}/core/platform/windows/logging/*.cc" ) else() list(APPEND onnxruntime_common_src_patterns - "${ONNXRUNTIME_ROOT}/core/platform/posix/*.h" - "${ONNXRUNTIME_ROOT}/core/platform/posix/*.cc" + "${ONNXRUNTIME_ROOT}/core/platform/posix/env_time.cc" + "${ONNXRUNTIME_ROOT}/core/platform/posix/env.cc" + "${ONNXRUNTIME_ROOT}/core/platform/posix/stacktrace.cc" ) + # logging files if 
(onnxruntime_USE_SYSLOG) list(APPEND onnxruntime_common_src_patterns "${ONNXRUNTIME_ROOT}/core/platform/posix/logging/*.h" @@ -51,7 +63,7 @@ else() ) endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") + if (ANDROID) list(APPEND onnxruntime_common_src_patterns "${ONNXRUNTIME_ROOT}/core/platform/android/logging/*.h" "${ONNXRUNTIME_ROOT}/core/platform/android/logging/*.cc" @@ -66,6 +78,21 @@ else() endif() endif() +# platform-specific device discovery files +if (WIN32) + list(APPEND onnxruntime_common_src_patterns + "${ONNXRUNTIME_ROOT}/core/platform/windows/device_discovery.cc") +elseif (LINUX) + list(APPEND onnxruntime_common_src_patterns + "${ONNXRUNTIME_ROOT}/core/platform/linux/device_discovery.cc") +elseif (APPLE) + list(APPEND onnxruntime_common_src_patterns + "${ONNXRUNTIME_ROOT}/core/platform/apple/device_discovery.cc") +else() + list(APPEND onnxruntime_common_src_patterns + "${ONNXRUNTIME_ROOT}/core/platform/device_discovery_default.cc") +endif() + if(onnxruntime_target_platform STREQUAL "ARM64EC") if (MSVC) link_directories("$ENV{VCINSTALLDIR}/Tools/MSVC/$ENV{VCToolsVersion}/lib/ARM64EC") @@ -216,8 +243,6 @@ endif() if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) # Link cpuinfo if supported - # Using it mainly in ARM with Android. - # Its functionality in detecting x86 cpu features are lacking, so is support for Windows. if (CPUINFO_SUPPORTED) onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo ${ONNXRUNTIME_CLOG_TARGET_NAME}) diff --git a/cmake/vcpkg-ports/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch b/cmake/vcpkg-ports/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch new file mode 100644 index 0000000000000..23ceeb8f758cc --- /dev/null +++ b/cmake/vcpkg-ports/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch @@ -0,0 +1,22 @@ +diff --git a/include/cpuinfo.h b/include/cpuinfo.h +index f1d35d4..9e454d2 100644 +--- a/include/cpuinfo.h ++++ b/include/cpuinfo.h +@@ -18,7 +18,7 @@ + #define CPUINFO_ARCH_X86 1 + #endif + +-#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64) ++#if defined(__x86_64__) || defined(__x86_64) || (defined(_M_X64) && !defined(_M_ARM64EC)) || (defined(_M_AMD64) && !defined(_M_ARM64EC)) + #define CPUINFO_ARCH_X86_64 1 + #endif + +@@ -26,7 +26,7 @@ + #define CPUINFO_ARCH_ARM 1 + #endif + +-#if defined(__aarch64__) || defined(_M_ARM64) ++#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC) + #define CPUINFO_ARCH_ARM64 1 + #endif + diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index e61308bf643b4..917fd29a8d28b 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -9,6 +9,8 @@ vcpkg_from_github( REF 8a1772a0c5c447df2d18edf33ec4603a8c9c04a6 SHA512 b94ccbfa886221d6bb16513d074675af0a72928a9dd9485dcacdc1124a8a60aacbbe91913a1579e766dfb024f0be1d52eeead40342004ff0238a8b94a095ed08 HEAD_REF master + PATCHES + patch_cpuinfo_h_for_arm64ec.patch ) vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS diff --git a/include/onnxruntime/core/common/parse_string.h b/include/onnxruntime/core/common/parse_string.h index 6345b2a55490d..5f88d490b3415 100644 --- a/include/onnxruntime/core/common/parse_string.h +++ b/include/onnxruntime/core/common/parse_string.h @@ -35,13 +35,30 @@ template std::enable_if_t, bool> TryParseStringWithClassicLocale(std::string_view str, T& value) { T parsed_value{}; - const auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), 
parsed_value); - if (ec != std::errc{}) { + std::from_chars_result conversion_result{}; + if constexpr (std::is_integral_v && std::is_unsigned_v) { + // For unsigned integral types, also handle hex values, i.e., those beginning with "0x". + // std::from_chars() does not accept the "0x" prefix. + const bool has_hex_prefix = str.size() >= 2 && + str[0] == '0' && + (str[1] == 'x' || str[1] == 'X'); + + if (has_hex_prefix) { + str = str.substr(2); + } + + const int base = has_hex_prefix ? 16 : 10; + conversion_result = std::from_chars(str.data(), str.data() + str.size(), parsed_value, base); + } else { + conversion_result = std::from_chars(str.data(), str.data() + str.size(), parsed_value); + } + + if (conversion_result.ec != std::errc{}) { return false; } - if (ptr != str.data() + str.size()) { + if (conversion_result.ptr != str.data() + str.size()) { return false; } diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 1bb7f219c9a45..f54f4a5a6f1ef 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -36,6 +36,7 @@ class GraphOptimizerRegistry; #include "core/framework/framework_provider_common.h" #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +#include "core/session/onnxruntime_c_api.h" struct OrtEpDevice; struct OrtRunOptions; @@ -322,6 +323,29 @@ class IExecutionProvider { virtual common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs); + /** + * Get the compatibility info for a compiled model. + * + * The execution provider determines this value, which denotes the compatibility of the compiled model with the EP. + * This is stored in the model metadata under a key associated with the EP type. + */ + virtual std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const { + // graph_viewer and model_metadata are not used in the default implementation. + ORT_UNUSED_PARAMETER(graph_viewer); + // Default implementation returns empty string + return std::string(); + } + + /** + * Validate the compatibility of a compiled model with this execution provider. + */ + virtual common::Status ValidateCompiledModelCompatibilityInfo(const std::string& /*compatibility_info*/, + OrtCompiledModelCompatibility& model_compatibility) const { + // Default implementation indicates this EP does not support model compatibility validation + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return Status::OK(); + } + #endif void SetLogger(const logging::Logger* logger) { diff --git a/include/onnxruntime/core/graph/indexed_sub_graph.h b/include/onnxruntime/core/graph/indexed_sub_graph.h index 088db79a7e005..8ef4fdb66e1e6 100644 --- a/include/onnxruntime/core/graph/indexed_sub_graph.h +++ b/include/onnxruntime/core/graph/indexed_sub_graph.h @@ -31,7 +31,7 @@ struct IndexedSubGraph { std::string domain; ///< Domain of customized SubGraph/FunctionProto int since_version; ///< Since version of customized SubGraph/FunctionProto. - ONNX_NAMESPACE::OperatorStatus status; ///< Status of customized SubGraph/FunctionProto. + ONNX_NAMESPACE::OperatorStatus status{ONNX_NAMESPACE::OperatorStatus::STABLE}; ///< Status of customized SubGraph/FunctionProto. std::vector inputs; ///< Inputs of customized SubGraph/FunctionProto. std::vector outputs; ///< Outputs of customized SubGraph/FunctionProto. 
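For reference (not part of the patch): a minimal sketch of how the extended TryParseStringWithClassicLocale overload above behaves once unsigned integral parsing accepts a "0x"/"0X" prefix. The include path and the onnxruntime namespace are assumed from the header's location in the diff.

// Sketch only: exercises the "0x" handling added to TryParseStringWithClassicLocale.
// Assumes the template is declared in namespace onnxruntime in
// include/onnxruntime/core/common/parse_string.h (included here by its public path).
#include <cstdint>
#include <iostream>
#include "core/common/parse_string.h"

int main() {
  uint32_t value = 0;

  // Decimal input parses exactly as before.
  const bool ok_dec = onnxruntime::TryParseStringWithClassicLocale("4096", value);
  std::cout << "decimal: ok=" << ok_dec << " value=" << value << "\n";  // ok=1, value=4096

  // Hex input is now accepted for unsigned integral types.
  const bool ok_hex = onnxruntime::TryParseStringWithClassicLocale("0x1000", value);
  std::cout << "hex: ok=" << ok_hex << " value=" << value << "\n";      // ok=1, value=4096

  // Signed types still use plain base-10 std::from_chars, so the 'x' stops the
  // conversion early and the full-consumption check makes the call return false.
  int32_t signed_value = 0;
  const bool ok_signed = onnxruntime::TryParseStringWithClassicLocale("0x10", signed_value);
  std::cout << "signed hex: ok=" << ok_signed << "\n";                  // ok=0

  return 0;
}

Routing only unsigned integral types through base 16 keeps the existing signed and floating-point parsing behavior unchanged.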
diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h index 11cc6f131dab3..dc27204017caa 100644 --- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h +++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h @@ -32,9 +32,8 @@ constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes"; constexpr const char* kProfilesMaxShapes = "nv_profile_max_shapes"; constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes"; constexpr const char* kCudaGraphEnable = "nv_cuda_graph_enable"; -constexpr const char* kONNXBytestream = "nv_onnx_bytestream"; -constexpr const char* kONNXBytestreamSize = "nv_onnx_bytestream_size"; constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable"; +constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer"; } // namespace provider_option_names namespace run_option_names { diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 89467f5238fa9..59ca1a1df762e 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -199,23 +199,6 @@ class Environment { using OrtAllocatorUniquePtr = std::unique_ptr>; - // if the user calls CreateSharedAllocator and wraps the plugin EP's allocator with an arena we end up with - // OrtAllocator from EP -> wrapped in IAllocatorImplWrappingOrtAllocator -> inside a BFCArena IAllocator. - // we can put that in shared_allocators_ for sessions to use, but to have an OrtAllocator available in - // shared_ort_allocators_ that can be used outside of a session we need to additionally wrap that in an - // OrtAllocatorImplWrappingIAllocator. way too many levels of indirection but that is what it is currently. - // we need something to own that final OrtAllocator, so we add it to arena_ort_allocators_. - // - // TODO: we could split out the BFCArena implementation so it can be plugged into either an IAllocator - // or an OrtAllocator instance to reduce the indirection a little. - // with that we get an OrtAllocator from the EP, wrap it with an OrtAllocator based BFCArena, and wrap that with the - // IAllocatorImplWrappingOrtAllocator which takes ownership of the OrtAllocator and is in shared_allocators_. - // - // Alternatively we can disable wrapping an EP's allocator with a BFCArena and say the EP should provide the arena - // implementation directly. They're free to copy BFCArena as it came from TF originally. Or we could provide a - // cut-and-paste BFCArena implementation that works using the EP API that can be included in the EP source. - std::unordered_map> arena_ort_allocators_; - #if !defined(ORT_MINIMAL_BUILD) // register EPs that are built into the ORT binary so they can take part in AutoEP selection // added to ep_libraries diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6eb15280a4aa4..bedeeb972c3a7 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5829,7 +5829,7 @@ struct OrtApi { * * \since Version 1.23. */ - ORT_API2_STATUS(Graph_GetNodes, const OrtGraph* graph, + ORT_API2_STATUS(Graph_GetNodes, _In_ const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); /** \brief Get the parent node for the given graph, if any exists. 
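As an illustration of the Graph_GetNodes signature annotated above, here is a hedged sketch of a caller filling a node array through the C API. It assumes a companion count query named Graph_GetNumNodes; if the actual name differs, substitute it — the point is only the count-then-fill pattern implied by the _Out_writes_(num_nodes) annotation.

// Sketch only: the count query name (Graph_GetNumNodes) is an assumption here;
// the fill call matches the Graph_GetNodes signature shown above.
#include <vector>
#include "onnxruntime_c_api.h"

// Returns nullptr on success; a non-null OrtStatus* must be released by the caller.
static OrtStatus* GetAllNodes(const OrtApi* ort, const OrtGraph* graph,
                              std::vector<const OrtNode*>& nodes) {
  size_t num_nodes = 0;
  if (OrtStatus* status = ort->Graph_GetNumNodes(graph, &num_nodes); status != nullptr) {
    return status;
  }

  nodes.resize(num_nodes);
  if (num_nodes == 0) {
    return nullptr;  // nothing to fetch
  }

  // Graph_GetNodes fills the caller-provided array of exactly num_nodes entries.
  return ort->Graph_GetNodes(graph, nodes.data(), num_nodes);
}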
@@ -6469,6 +6469,17 @@ struct OrtApi { _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); + + /** \brief Get ::OrtModelMetadata from an ::OrtGraph + * + * \param[in] graph The OrtGraph instance. + * \param[out] out Newly created ::OrtModelMetadata. Must be freed using OrtApi::ReleaseModelMetadata. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d1b08f127fa2a..2f4fd36c8115f 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -566,6 +566,7 @@ ORT_DEFINE_RELEASE(ModelMetadata); ORT_DEFINE_RELEASE(IoBinding); ORT_DEFINE_RELEASE(ArenaCfg); ORT_DEFINE_RELEASE(Status); +ORT_DEFINE_RELEASE(SyncStream); ORT_DEFINE_RELEASE(OpAttr); ORT_DEFINE_RELEASE(Op); ORT_DEFINE_RELEASE(KernelInfo); @@ -573,10 +574,16 @@ ORT_DEFINE_RELEASE(ValueInfo); ORT_DEFINE_RELEASE(Node); ORT_DEFINE_RELEASE(Graph); ORT_DEFINE_RELEASE(Model); -ORT_DEFINE_RELEASE(KeyValuePairs) +ORT_DEFINE_RELEASE(KeyValuePairs); +ORT_DEFINE_RELEASE(PrepackedWeightsContainer); ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); +// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type, +// but the struct has V2 in its name to indicate that it is the second version of the options. +inline void OrtRelease(OrtTensorRTProviderOptionsV2* ptr) { GetApi().ReleaseTensorRTProviderOptions(ptr); } +inline void OrtRelease(OrtCUDAProviderOptionsV2* ptr) { GetApi().ReleaseCUDAProviderOptions(ptr); } + #undef ORT_DEFINE_RELEASE #undef ORT_DEFINE_RELEASE_FROM_API_STRUCT @@ -628,6 +635,7 @@ struct Base { } constexpr operator contained_type*() const noexcept { return p_; } + constexpr contained_type& operator*() const noexcept { return *p_; } /// \brief Relinquishes ownership of the contained C object pointer /// The underlying object is not destroyed @@ -672,6 +680,7 @@ struct Base> { } constexpr operator contained_type*() const noexcept { return p_; } + constexpr contained_type& operator*() const noexcept { return *p_; } protected: contained_type* p_{}; @@ -697,6 +706,11 @@ struct Model; struct Node; struct ModelMetadata; struct TypeInfo; +struct PrepackedWeightsContainer; +struct Session; +struct SessionOptions; +struct SyncStream; +struct TensorRTProviderOptions; struct Value; struct ValueInfo; @@ -714,11 +728,11 @@ struct Status : detail::Base { using Base = detail::Base; using Base::Base; - explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used - explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. - explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception - explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception - Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message. 
+ explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used + explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. + explicit Status(const Exception&); ///< Creates status instance out of exception + explicit Status(const std::exception&); ///< Creates status instance out of exception + Status(const char* message, OrtErrorCode code); ///< Creates status instance out of null-terminated string message. std::string GetErrorMessage() const; OrtErrorCode GetErrorCode() const; bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status. @@ -754,6 +768,58 @@ struct ThreadingOptions : detail::Base { ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); }; +/** \brief The TensorRTProviderOptions (V2) + * + * Used to pass options to TRT EP + */ +struct TensorRTProviderOptions : detail::Base { + TensorRTProviderOptions(std::nullptr_t) {} + /// \brief Wraps OrtApi::CreateTensorRTProviderOptions + TensorRTProviderOptions(); + ///< Wrapper around OrtApi::UpdateTensorRTProviderOptions + void Update(const std::unordered_map& options); + ///< Wrapper around OrtApi::UpdateTensorRTProviderOptionsWithValue + void UpdateWithValue(const char* key, void* value); + + ///< Wrapper around OrtApi::GetTensorRTProviderOptionsByName + void* GetOptionByName(const char* name) const; + ///< Wrapper around OrtApi::GetTensorRTProviderOptionsAsString + std::string GetTensorRTProviderOptionsAsString() const; +}; + +/** \brief The CUDAProviderOptions (V2) + * + * Used to pass options to CUDA EP + */ +struct CUDAProviderOptions : detail::Base { + CUDAProviderOptions(std::nullptr_t) {} + /// \brief Wraps OrtApi::CreateCUDAProviderOptions + CUDAProviderOptions(); + ///< Wrapper around OrtApi::UpdateCUDAProviderOptions + void Update(const std::unordered_map& options); + ///< Wrapper around OrtApi::GetCUDAProviderOptionsAsString + std::string GetCUDAProviderOptionsAsString() const; + ///< Wrapper around OrtApi::UpdateCUDAProviderOptionsWithValue + void UpdateWithValue(const char* key, void* value); + ///< Wrapper around OrtApi::GetCUDAProviderOptionsByName + void* GetOptionByName(const char* name) const; +}; + +/** \brief The PrepackedWeightsContainer + * + * Create only and pass to Ort::Session constructor for multiple sessions + * to share pre-packed weights.
+ */ +struct PrepackedWeightsContainer : detail::Base { + using Base = detail::Base; + ///< No instance is created + explicit PrepackedWeightsContainer(std::nullptr_t) {} + ///< Take ownership of a pointer created by C API + explicit PrepackedWeightsContainer(OrtPrepackedWeightsContainer* p) : Base{p} {} + /// \brief Wraps OrtApi::CreatePrepackedWeightsContainer + PrepackedWeightsContainer(); +}; + namespace detail { template struct KeyValuePairsImpl : Ort::detail::Base { @@ -793,6 +859,111 @@ struct KeyValuePairs : detail::KeyValuePairsImpl { ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; } }; +namespace detail { +template +struct MemoryInfoImpl : Base { + using B = Base; + using B::B; + + std::string GetAllocatorName() const; ///< Wrapper MemoryInfoGetName + OrtAllocatorType GetAllocatorType() const; ///< Wrapper MemoryInfoGetType + int GetDeviceId() const; ///< Wrapper MemoryInfoGetId + OrtMemoryInfoDeviceType GetDeviceType() const; ///< Wrapper MemoryInfoGetDeviceType + OrtMemType GetMemoryType() const; ///< Wrapper MemoryInfoGetMemType + OrtDeviceMemoryType GetDeviceMemoryType() const; ///< Wrapper MemoryInfoGetDeviceMemType + uint32_t GetVendorId() const; ///< Wrapper MemoryInfoGetVendorId + + template + bool operator==(const MemoryInfoImpl& o) const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstMemoryInfo = detail::MemoryInfoImpl>; + +/** \brief Wrapper around ::OrtMemoryInfo + * + */ +struct MemoryInfo : detail::MemoryInfoImpl { + static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); + explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created + explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API + MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); + MemoryInfo(const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id, + OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type); ///< Wrapper around CreateMemoryInfo_V2 + ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } +}; + +/// +/// Represents native memory allocation coming from one of the +/// OrtAllocators registered with OnnxRuntime. +/// Use it to wrap an allocation made by an allocator +/// so it can be automatically released when no longer needed. +/// +struct MemoryAllocation { + MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); + ~MemoryAllocation(); + MemoryAllocation(const MemoryAllocation&) = delete; + MemoryAllocation& operator=(const MemoryAllocation&) = delete; + MemoryAllocation(MemoryAllocation&&) noexcept; + MemoryAllocation& operator=(MemoryAllocation&&) noexcept; + + void* get() { return p_; } + size_t size() const { return size_; } + + private: + OrtAllocator* allocator_; + void* p_; + size_t size_; +}; + +namespace detail { +template +struct AllocatorImpl : Base { + using B = Base; + using B::B; + + void* Alloc(size_t size); + MemoryAllocation GetAllocation(size_t size); + void Free(void* p); + ConstMemoryInfo GetInfo() const; + + /** \brief Function that returns the statistics of the allocator. + * + * \return A pointer to a KeyValuePairs object that will be filled with the allocator statistics. 
+ */ + KeyValuePairs GetStats() const; +}; +} // namespace detail + +/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime + * + */ +struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { + explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + AllocatorWithDefaultOptions(); +}; + +/** \brief Wrapper around ::OrtAllocator + * + */ + +struct Allocator : detail::AllocatorImpl { + explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance + Allocator(const Session& session, const OrtMemoryInfo*); +}; + +using UnownedAllocator = detail::AllocatorImpl>; + +/** \brief Wrapper around ::OrtSyncStream + * + */ +struct SyncStream : detail::Base { + explicit SyncStream(std::nullptr_t) {} ///< Create an empty SyncStream object, must be assigned a valid one to be used + explicit SyncStream(OrtSyncStream* p) : Base{p} {} ///< Take ownership of a pointer created by C API + void* GetHandle() const; ///< Wraps SyncStream_GetHandle +}; + namespace detail { template struct HardwareDeviceImpl : Ort::detail::Base { @@ -823,6 +994,8 @@ struct EpDeviceImpl : Ort::detail::Base { ConstKeyValuePairs EpMetadata() const; ConstKeyValuePairs EpOptions() const; ConstHardwareDevice Device() const; + ConstMemoryInfo GetMemoryInfo(OrtDeviceMemoryType memory_type) const; ///< Wraps EpDevice_MemoryInfo + SyncStream CreateSyncStream(ConstKeyValuePairs stream_options = {}) const; /// Wraps EpDevice_CreateSyncStream }; } // namespace detail @@ -877,10 +1050,28 @@ struct Env : detail::Base { const std::unordered_map& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 + Env& RegisterAllocator(OrtAllocator* allocator); ///< Wraps OrtApi::RegisterAllocator + + Env& UnregisterAllocator(const OrtMemoryInfo* mem_info); ///< Wraps OrtApi::UnregisterAllocator + + UnownedAllocator CreateSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type, + OrtAllocatorType allocator_type, + const OrtKeyValuePairs* allocator_options); ///< Wraps OrtApi::CreateSharedAllocator + + // Result may be nullptr + UnownedAllocator GetSharedAllocator(const OrtMemoryInfo* mem_info); ///< Wraps OrtApi::GetSharedAllocator + + void ReleaseSharedAllocator(const OrtEpDevice* ep_device, + OrtDeviceMemoryType mem_type); ///< Wraps OrtApi::ReleaseSharedAllocator + Env& RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string& path); ///< Wraps OrtApi::RegisterExecutionProviderLibrary Env& UnregisterExecutionProviderLibrary(const char* registration_name); ///< Wraps OrtApi::UnregisterExecutionProviderLibrary std::vector GetEpDevices() const; + + Status CopyTensors(const std::vector& src_tensors, + const std::vector& dst_tensors, + OrtSyncStream* stream) const; ///< Wraps OrtApi::CopyTensors }; /** \brief Custom Op Domain @@ -1018,8 +1209,6 @@ struct CustomOpConfigs { * Wraps ::OrtSessionOptions object and methods */ -struct SessionOptions; - namespace detail { // we separate const-only methods because passing const ptr to non-const methods // is only discovered when inline methods are compiled which is counter-intuitive @@ -1077,6 +1266,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { const std::vector& external_initializer_file_buffer_array, const std::vector& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory + SessionOptionsImpl& 
AppendExecutionProvider_CPU(int use_arena); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CPU SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM @@ -1264,6 +1454,10 @@ struct ConstSessionImpl : Base { std::vector GetOutputNames() const; std::vector GetOverridableInitializerNames() const; + std::vector GetMemoryInfoForInputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForInputs + std::vector GetMemoryInfoForOutputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForOutputs + std::vector GetEpDeviceForInputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForInputs + /** \brief Returns a copy of input name at the specified index. * * \param index must less than the value returned by GetInputCount() @@ -1427,37 +1621,6 @@ struct Session : detail::SessionImpl { UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } }; -namespace detail { -template -struct MemoryInfoImpl : Base { - using B = Base; - using B::B; - - std::string GetAllocatorName() const; - OrtAllocatorType GetAllocatorType() const; - int GetDeviceId() const; - OrtMemoryInfoDeviceType GetDeviceType() const; - OrtMemType GetMemoryType() const; - - template - bool operator==(const MemoryInfoImpl& o) const; -}; -} // namespace detail - -// Const object holder that does not own the underlying object -using ConstMemoryInfo = detail::MemoryInfoImpl>; - -/** \brief Wrapper around ::OrtMemoryInfo - * - */ -struct MemoryInfo : detail::MemoryInfoImpl { - static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); - explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created - explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API - MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); - ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } -}; - namespace detail { template struct TensorTypeAndShapeInfoImpl : Base { @@ -1686,7 +1849,7 @@ struct ConstValueImpl : Base { /// /// const pointer to data, no copies made template - const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// + const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorData /// /// /// Returns a non-typed pointer to a tensor contained data. @@ -1956,7 +2119,7 @@ struct Value : detail::ValueImpl { using OrtSparseValuesParam = detail::OrtSparseValuesParam; using Shape = detail::Shape; - explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used + Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used Value(Value&&) = default; Value& operator=(Value&&) = default; @@ -2121,67 +2284,6 @@ struct Value : detail::ValueImpl { #endif // !defined(DISABLE_SPARSE_TENSORS) }; -/// -/// Represents native memory allocation coming from one of the -/// OrtAllocators registered with OnnxRuntime. -/// Use it to wrap an allocation made by an allocator -/// so it can be automatically released when no longer needed. 
-/// -struct MemoryAllocation { - MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); - ~MemoryAllocation(); - MemoryAllocation(const MemoryAllocation&) = delete; - MemoryAllocation& operator=(const MemoryAllocation&) = delete; - MemoryAllocation(MemoryAllocation&&) noexcept; - MemoryAllocation& operator=(MemoryAllocation&&) noexcept; - - void* get() { return p_; } - size_t size() const { return size_; } - - private: - OrtAllocator* allocator_; - void* p_; - size_t size_; -}; - -namespace detail { -template -struct AllocatorImpl : Base { - using B = Base; - using B::B; - - void* Alloc(size_t size); - MemoryAllocation GetAllocation(size_t size); - void Free(void* p); - ConstMemoryInfo GetInfo() const; - - /** \brief Function that returns the statistics of the allocator. - * - * \return A pointer to a KeyValuePairs object that will be filled with the allocator statistics. - */ - KeyValuePairs GetStats() const; -}; - -} // namespace detail - -/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime - * - */ -struct AllocatorWithDefaultOptions : detail::AllocatorImpl> { - explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance - AllocatorWithDefaultOptions(); -}; - -/** \brief Wrapper around ::OrtAllocator - * - */ -struct Allocator : detail::AllocatorImpl { - explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance - Allocator(const Session& session, const OrtMemoryInfo*); -}; - -using UnownedAllocator = detail::AllocatorImpl>; - namespace detail { namespace binding_utils { // Bring these out of template @@ -2244,6 +2346,12 @@ struct ArenaCfg : detail::Base { * See docs/C_API.md for details on what the following parameters mean and how to choose these values */ ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); + + /** + * Wraps Ort::CreateArenaCfgV2 + * See C API for details on what the following parameters mean and how to choose these values + */ + explicit ArenaCfg(const std::unordered_map& arena_config); }; // @@ -2834,6 +2942,7 @@ struct GraphImpl : Ort::detail::Base { void SetOutputs(std::vector& outputs); void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value void AddNode(Node& node); // Graph takes ownership of Node + ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::Graph_GetModelMetadata #endif // !defined(ORT_MINIMAL_BUILD) }; } // namespace detail @@ -2848,6 +2957,7 @@ struct Graph : detail::GraphImpl { Graph(); #endif }; +using ConstGraph = detail::GraphImpl>; namespace detail { template diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 705f17c5d6f43..73200d8852223 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -56,15 +56,15 @@ inline void ThrowOnError(const Status& st) { inline Status::Status(OrtStatus* status) noexcept : detail::Base{status} { } -inline Status::Status(const std::exception& e) noexcept { +inline Status::Status(const std::exception& e) { p_ = GetApi().CreateStatus(ORT_FAIL, e.what()); } -inline Status::Status(const Exception& e) noexcept { +inline Status::Status(const Exception& e) { p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what()); } -inline Status::Status(const char* 
message, OrtErrorCode code) noexcept { +inline Status::Status(const char* message, OrtErrorCode code) { p_ = GetApi().CreateStatus(code, message); } @@ -296,6 +296,16 @@ inline OrtMemType MemoryInfoImpl::GetMemoryType() const { return type; } +template +inline OrtDeviceMemoryType MemoryInfoImpl::GetDeviceMemoryType() const { + return GetApi().MemoryInfoGetDeviceMemType(this->p_); +} + +template +inline uint32_t MemoryInfoImpl::GetVendorId() const { + return GetApi().MemoryInfoGetVendorId(this->p_); +} + template template inline bool MemoryInfoImpl::operator==(const MemoryInfoImpl& o) const { @@ -316,6 +326,12 @@ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, O ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_)); } +inline MemoryInfo::MemoryInfo(const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, uint32_t device_id, + OrtDeviceMemoryType mem_type, size_t alignment, OrtAllocatorType allocator_type) { + ThrowOnError(GetApi().CreateMemoryInfo_V2(name, device_type, vendor_id, device_id, mem_type, alignment, + allocator_type, &this->p_)); +} + namespace detail { template inline std::vector ConstIoBindingImpl::GetOutputNames() const { @@ -404,20 +420,7 @@ inline std::vector GetOutputNamesHelper(const OrtIoBinding* binding inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) { std::vector result; - size_t owned = 0; size_t output_count = 0; - // Lambda to release the buffer when no longer needed and - // make sure that we destroy all instances on exception - auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) { - if (buffer) { - while (owned < output_count) { - auto* p = buffer + owned++; - GetApi().ReleaseValue(*p); - } - allocator->Free(allocator, buffer); - } - }; - using Ptr = std::unique_ptr; OrtValue** output_buffer = nullptr; ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count)); @@ -425,12 +428,11 @@ inline std::vector GetOutputValuesHelper(const OrtIoBinding* binding, Ort return result; } - Ptr buffer_g(output_buffer, free_fn); + std::unique_ptr buffer_g(output_buffer, AllocatedFree(allocator)); result.reserve(output_count); for (size_t i = 0; i < output_count; ++i) { result.emplace_back(output_buffer[i]); - ++owned; } return result; } @@ -446,6 +448,18 @@ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_)); } +inline ArenaCfg::ArenaCfg(const std::unordered_map& arena_config) { + std::vector keys; + std::vector values; + keys.reserve(arena_config.size()); + values.reserve(arena_config.size()); + for (const auto& kv : arena_config) { + keys.push_back(kv.first.c_str()); + values.push_back(kv.second); + } + ThrowOnError(GetApi().CreateArenaCfgV2(keys.data(), values.data(), arena_config.size(), &p_)); +} + inline ThreadingOptions::ThreadingOptions() { ThrowOnError(GetApi().CreateThreadingOptions(&p_)); } @@ -485,6 +499,78 @@ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustom return *this; } +inline TensorRTProviderOptions::TensorRTProviderOptions() { + ThrowOnError(GetApi().CreateTensorRTProviderOptions(&this->p_)); +} + +inline void TensorRTProviderOptions::Update(const std::unordered_map& options) { + std::vector keys; + std::vector values; + keys.reserve(options.size()); + values.reserve(options.size()); + for (const auto& kv : 
options) { + keys.push_back(kv.first.c_str()); + values.push_back(kv.second.c_str()); + } + ThrowOnError(GetApi().UpdateTensorRTProviderOptions(p_, keys.data(), values.data(), options.size())); +} + +inline void TensorRTProviderOptions::UpdateWithValue(const char* key, void* value) { + ThrowOnError(GetApi().UpdateTensorRTProviderOptionsWithValue(p_, key, value)); +} + +inline void* TensorRTProviderOptions::GetOptionByName(const char* name) const { + void* value = nullptr; + ThrowOnError(GetApi().GetTensorRTProviderOptionsByName(p_, name, &value)); + return value; +} + +inline std::string TensorRTProviderOptions::GetTensorRTProviderOptionsAsString() const { + AllocatorWithDefaultOptions allocator; + char* options_str = nullptr; + ThrowOnError(GetApi().GetTensorRTProviderOptionsAsString(p_, allocator, &options_str)); + std::unique_ptr options_str_g(options_str, detail::AllocatedFree(allocator)); + return std::string(options_str); +} + +inline CUDAProviderOptions::CUDAProviderOptions() { + ThrowOnError(GetApi().CreateCUDAProviderOptions(&this->p_)); +} + +inline void CUDAProviderOptions::Update(const std::unordered_map& options) { + std::vector keys; + std::vector values; + keys.reserve(options.size()); + values.reserve(options.size()); + for (const auto& kv : options) { + keys.push_back(kv.first.c_str()); + values.push_back(kv.second.c_str()); + } + ThrowOnError(GetApi().UpdateCUDAProviderOptions(p_, keys.data(), values.data(), options.size())); +} + +inline std::string CUDAProviderOptions::GetCUDAProviderOptionsAsString() const { + AllocatorWithDefaultOptions allocator; + char* options_str = nullptr; + ThrowOnError(GetApi().GetCUDAProviderOptionsAsString(p_, allocator, &options_str)); + std::unique_ptr options_str_g(options_str, detail::AllocatedFree(allocator)); + return std::string(options_str); +} + +inline void CUDAProviderOptions::UpdateWithValue(const char* key, void* value) { + ThrowOnError(GetApi().UpdateCUDAProviderOptionsWithValue(p_, key, value)); +} + +inline void* CUDAProviderOptions::GetOptionByName(const char* name) const { + void* value = nullptr; + ThrowOnError(GetApi().GetCUDAProviderOptionsByName(p_, name, &value)); + return value; +} + +inline PrepackedWeightsContainer::PrepackedWeightsContainer() { + ThrowOnError(GetApi().CreatePrepackedWeightsContainer(&this->p_)); +} + namespace detail { template inline const char* KeyValuePairsImpl::GetValue(const char* key) const { @@ -547,6 +633,10 @@ inline void KeyValuePairs::Remove(const char* key) { GetApi().RemoveKeyValuePair(this->p_, key); } +inline void* SyncStream::GetHandle() const { + return GetApi().SyncStream_GetHandle(this->p_); +} + namespace detail { template inline OrtHardwareDeviceType HardwareDeviceImpl::Type() const { @@ -597,6 +687,19 @@ template inline ConstHardwareDevice EpDeviceImpl::Device() const { return ConstHardwareDevice(GetApi().EpDevice_Device(this->p_)); } + +template +inline ConstMemoryInfo EpDeviceImpl::GetMemoryInfo(OrtDeviceMemoryType memory_type) const { + const auto* mem_info = GetApi().EpDevice_MemoryInfo(this->p_, memory_type); + return ConstMemoryInfo{mem_info}; +} + +template +inline SyncStream EpDeviceImpl::CreateSyncStream(ConstKeyValuePairs stream_options) const { + OrtSyncStream* stream = nullptr; + ThrowOnError(GetApi().CreateSyncStreamForEpDevice(this->p_, stream_options, &stream)); + return SyncStream{stream}; +} } // namespace detail inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device, @@ -676,6 +779,16 @@ inline Env& 
Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, return *this; } +inline Env& Env::RegisterAllocator(OrtAllocator* allocator) { + ThrowOnError(GetApi().RegisterAllocator(p_, allocator)); + return *this; +} + +inline Env& Env::UnregisterAllocator(const OrtMemoryInfo* mem_info) { + ThrowOnError(GetApi().UnregisterAllocator(p_, mem_info)); + return *this; +} + inline Env& Env::RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string& path) { ThrowOnError(GetApi().RegisterExecutionProviderLibrary(p_, registration_name, path.c_str())); @@ -703,6 +816,41 @@ inline std::vector Env::GetEpDevices() const { return devices; } +inline Status Env::CopyTensors(const std::vector& src_tensors, + const std::vector& dst_tensors, + OrtSyncStream* stream) const { + if (src_tensors.size() != dst_tensors.size()) { + return Status("Source and destination tensor vectors must have the same size", ORT_INVALID_ARGUMENT); + } + if (src_tensors.empty()) { + return Status(); + } + + const OrtValue* const* src_tensors_ptr = reinterpret_cast(src_tensors.data()); + OrtValue* const* dst_tensors_ptr = reinterpret_cast(dst_tensors.data()); + OrtStatus* status = GetApi().CopyTensors(p_, src_tensors_ptr, dst_tensors_ptr, stream, src_tensors.size()); + return Status(status); +} + +inline UnownedAllocator Env::CreateSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type, + OrtAllocatorType allocator_type, + const OrtKeyValuePairs* allocator_options) { + OrtAllocator* p; + ThrowOnError(GetApi().CreateSharedAllocator(p_, ep_device, mem_type, allocator_type, allocator_options, &p)); + return UnownedAllocator{p}; +} + +inline UnownedAllocator Env::GetSharedAllocator(const OrtMemoryInfo* mem_info) { + OrtAllocator* p; + ThrowOnError(GetApi().GetSharedAllocator(p_, mem_info, &p)); + return UnownedAllocator{p}; +} + +inline void Env::ReleaseSharedAllocator(const OrtEpDevice* ep_device, + OrtDeviceMemoryType mem_type) { + ThrowOnError(GetApi().ReleaseSharedAllocator(p_, ep_device, mem_type)); +} + inline CustomOpDomain::CustomOpDomain(const char* domain) { ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_)); } @@ -1056,6 +1204,12 @@ inline SessionOptionsImpl& SessionOptionsImpl::AddExternalInitializersFrom return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CPU(int use_arena) { + ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(this->p_, use_arena)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) { ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options)); @@ -1298,9 +1452,9 @@ inline std::vector ConstSessionImpl::GetInputNames() const { input_names.reserve(num_inputs); for (size_t i = 0; i < num_inputs; ++i) { - char* name = nullptr; + char* name; ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name)); - input_names.push_back(name); + input_names.emplace_back(name); allocator.Free(name); } @@ -1316,9 +1470,9 @@ inline std::vector ConstSessionImpl::GetOutputNames() const { output_names.reserve(num_inputs); for (size_t i = 0; i < num_inputs; ++i) { - char* name = nullptr; + char* name; ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name)); - output_names.push_back(name); + output_names.emplace_back(name); allocator.Free(name); } @@ -1334,14 +1488,45 @@ inline std::vector ConstSessionImpl::GetOverridableInitializerNa 
initializer_names.reserve(num_initializers); for (size_t i = 0; i < num_initializers; ++i) { - char* name = nullptr; + char* name; ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name)); - initializer_names.push_back(name); + initializer_names.emplace_back(name); } return initializer_names; } +template +inline std::vector ConstSessionImpl::GetMemoryInfoForInputs() const { + static_assert(sizeof(ConstMemoryInfo) == sizeof(OrtMemoryInfo*), + "ConstMemoryInfo must be compatible with OrtMemoryInfo*"); + + auto num_inputs = GetInputCount(); + std::vector mem_infos; + mem_infos.resize(num_inputs); + + ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_inputs)); + + return mem_infos; +} + +template +inline std::vector ConstSessionImpl::GetMemoryInfoForOutputs() const { + static_assert(sizeof(ConstMemoryInfo) == sizeof(OrtMemoryInfo*), + "ConstMemoryInfo must be compatible with OrtMemoryInfo*"); + + auto num_outputs = GetOutputCount(); + std::vector mem_infos; + mem_infos.resize(num_outputs); + + ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_outputs)); + return mem_infos; +} + template inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { char* out; @@ -1363,6 +1548,19 @@ inline AllocatedStringPtr ConstSessionImpl::GetOverridableInitializerNameAllo return AllocatedStringPtr(out, detail::AllocatedFree(allocator)); } +template +inline std::vector ConstSessionImpl::GetEpDeviceForInputs() const { + auto num_inputs = GetInputCount(); + std::vector input_devices; + input_devices.resize(num_inputs); + + ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, + reinterpret_cast(input_devices.data()), + num_inputs)); + + return input_devices; +} + template inline uint64_t ConstSessionImpl::GetProfilingStartTimeNs() const { uint64_t out; @@ -1857,15 +2055,15 @@ inline size_t ConstValueImpl::GetTensorSizeInBytes() const { template template inline const R* ConstValueImpl::GetTensorData() const { - R* out; - ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), (void**)&out)); + const R* out; + ThrowOnError(GetApi().GetTensorData(this->p_, reinterpret_cast(&out))); return out; } template inline const void* ConstValueImpl::GetTensorRawData() const { - void* out; - ThrowOnError(GetApi().GetTensorMutableData(const_cast(this->p_), &out)); + const void* out; + ThrowOnError(GetApi().GetTensorData(this->p_, &out)); return out; } @@ -2798,6 +2996,13 @@ inline void GraphImpl::AddNode(Node& node) { ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release())); } +template +inline ModelMetadata GraphImpl::GetModelMetadata() const { + OrtModelMetadata* out; + ThrowOnError(GetApi().Graph_GetModelMetadata(this->p_, &out)); + return ModelMetadata{out}; +} + template <> inline void ModelImpl::AddGraph(Graph& graph) { // Model takes ownership of `graph` diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index f0992f05f31e5..672103bedc437 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -8,3 +8,8 @@ // Key for the execution provider version string. This should be available for all plugin EPs. 
static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; + +// Prefix for execution provider compatibility information stored in model metadata. +// Used when generating EP context models to store compatibility strings for each EP. +// Full key format: "ep_compatibility_info." +static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; \ No newline at end of file diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 314cf76cc8044..7eb5f7659a365 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -382,8 +382,8 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio // THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME // Meant to be used with SetEpDynamicOptions // Specify the type of workload for this session. -// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default] -// “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. +// "Default": OS determines the scheduling priority and processor performance to service this workload. [Default] +// "Efficient": OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type"; // Disables model compilation during session initialization. @@ -401,3 +401,10 @@ static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload // - "0": EP compile is not disabled. [DEFAULT] // - "1": EP compile is disabled. static const char* const kOrtSessionOptionsDisableModelCompile = "session.disable_model_compile"; + +// Controls behavior when compiled model compatibility is SUPPORTED_PREFER_RECOMPILATION. +// "0": Allow execution with suboptimal performance. [DEFAULT] +// "1": Fail session creation to require recompilation for optimal performance. +// Note: UNSUPPORTED models always fail regardless of this setting. +static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel = + "session.fail_on_suboptimal_compiled_model"; diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index dccfdbda8971b..6c66047b4b36a 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -1,6 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#include "core/common/cpuid_info.h" + +#include +#include + #include "core/common/logging/logging.h" #include "core/common/logging/severity.h" #include "core/platform/check_intel.h" @@ -51,6 +55,14 @@ #endif // _WIN32 +#if defined(__APPLE__) +#if defined(CPUIDINFO_ARCH_ARM) + +#include + +#endif // defined(CPUIDINFO_ARCH_ARM) +#endif // defined(__APPLE__) + #if defined(CPUINFO_SUPPORTED) #include #if defined(CPUIDINFO_ARCH_ARM) @@ -74,6 +86,14 @@ void decodeMIDR(uint32_t midr, uint32_t uarch[1]); namespace onnxruntime { +void CPUIDInfo::LogEarlyWarning(std::string_view message) { + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(WARNING) << message; + } else { + std::cerr << "onnxruntime cpuid_info warning: " << message << "\n"; + } +} + #if defined(CPUIDINFO_ARCH_X86) static inline void GetCPUID(int function_id, int data[4]) { // NOLINT @@ -108,9 +128,6 @@ void CPUIDInfo::X86Init() { int data[4] = {-1}; GetCPUID(0, data); - vendor_ = GetX86Vendor(data); - vendor_id_ = GetVendorId(vendor_); - int num_IDs = data[0]; if (num_IDs >= 1) { GetCPUID(1, data); @@ -158,24 +175,8 @@ void CPUIDInfo::X86Init() { } } -std::string CPUIDInfo::GetX86Vendor(int32_t* data) { - char vendor[sizeof(int32_t) * 3 + 1]{}; - *reinterpret_cast(vendor + 0) = data[1]; - *reinterpret_cast(vendor + 4) = data[3]; - *reinterpret_cast(vendor + 8) = data[2]; - return vendor; -} - #endif // defined(CPUIDINFO_ARCH_X86) -uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { - if (vendor == "GenuineIntel") return 0x8086; - if (vendor == "AuthenticAMD") return 0x1022; - if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); - if (vendor.find("NV") == 0) return 0x10DE; - return 0; -} - #if defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) @@ -228,10 +229,6 @@ void CPUIDInfo::ArmLinuxInit() { #elif defined(_WIN32) // ^ defined(__linux__) void CPUIDInfo::ArmWindowsInit() { - // Get the ARM vendor string from the registry - vendor_ = GetArmWindowsVendor(); - vendor_id_ = GetVendorId(vendor_); - // Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry // There should be one per CPU std::vector midr_values{}, id_aa64isar1_el1_values{}; @@ -323,15 +320,6 @@ void CPUIDInfo::ArmWindowsInit() { #endif // defined(CPUINFO_SUPPORTED) } -std::string CPUIDInfo::GetArmWindowsVendor() { - const int MAX_VALUE_NAME = 256; - const CHAR vendorKey[] = "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"; - CHAR vendorVal[MAX_VALUE_NAME] = ""; - unsigned long vendorSize = sizeof(char) * MAX_VALUE_NAME; - ::RegGetValueA(HKEY_LOCAL_MACHINE, vendorKey, "Vendor Identifier", RRF_RT_REG_SZ | RRF_ZEROONFAILURE, nullptr, &vendorVal, &vendorSize); - return vendorVal; -} - #elif defined(__APPLE__) // ^ defined(_WIN32) void CPUIDInfo::ArmAppleInit() { @@ -376,16 +364,21 @@ uint32_t CPUIDInfo::GetCurrentCoreIdx() const { } CPUIDInfo::CPUIDInfo() { -#ifdef CPUIDINFO_ARCH_X86 - X86Init(); -#elif defined(CPUIDINFO_ARCH_ARM) #if defined(CPUINFO_SUPPORTED) pytorch_cpuinfo_init_ = cpuinfo_initialize(); if (!pytorch_cpuinfo_init_) { - LOGS_DEFAULT(WARNING) << "Failed to initialize PyTorch cpuinfo library. May cause CPU EP performance degradation " - "due to undetected CPU features."; + LogEarlyWarning( + "Failed to initialize PyTorch cpuinfo library. May cause CPU EP performance degradation due to undetected CPU " + "features."); } #endif // defined(CPUINFO_SUPPORTED) + + // Note: This should be run after cpuinfo initialization if cpuinfo is enabled. 
+ VendorInfoInit(); + +#ifdef CPUIDINFO_ARCH_X86 + X86Init(); +#elif defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) ArmLinuxInit(); #elif defined(_WIN32) diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 84571fa12e6ea..d49eca7e1d60c 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -103,7 +103,40 @@ class CPUIDInfo { } private: + // Log function that uses ORT logging if available or writes to stderr. + // This enables us to log even before ORT logging has been initialized. + static void LogEarlyWarning(std::string_view message); + CPUIDInfo(); + + void VendorInfoInit(); + +#if defined(CPUIDINFO_ARCH_X86) + + void X86Init(); + +#elif defined(CPUIDINFO_ARCH_ARM) + +#if defined(__linux__) + + void ArmLinuxInit(); + +#elif defined(_WIN32) + + void ArmWindowsInit(); + +#elif defined(__APPLE__) + + void ArmAppleInit(); + +#endif + +#endif // defined(CPUIDINFO_ARCH_ARM) + +#if defined(CPUINFO_SUPPORTED) + bool pytorch_cpuinfo_init_{false}; +#endif // defined(CPUINFO_SUPPORTED) + bool has_amx_bf16_{false}; bool has_avx_{false}; bool has_avx2_{false}; @@ -132,37 +165,6 @@ class CPUIDInfo { std::string vendor_; uint32_t vendor_id_; - - uint32_t GetVendorId(const std::string& vendor); - -#if defined(CPUIDINFO_ARCH_X86) - - void X86Init(); - std::string GetX86Vendor(int32_t* data); - -#elif defined(CPUIDINFO_ARCH_ARM) - -#if defined(CPUINFO_SUPPORTED) - // Now the following var is only used in ARM build, but later on we may expand the usage. - bool pytorch_cpuinfo_init_{false}; -#endif // defined(CPUINFO_SUPPORTED) - -#if defined(__linux__) - - void ArmLinuxInit(); - -#elif defined(_WIN32) - - void ArmWindowsInit(); - std::string GetArmWindowsVendor(); - -#elif defined(__APPLE__) - - void ArmAppleInit(); - -#endif - -#endif // defined(CPUIDINFO_ARCH_ARM) }; } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info_vendor.cc b/onnxruntime/core/common/cpuid_info_vendor.cc new file mode 100644 index 0000000000000..d4d940eedfe28 --- /dev/null +++ b/onnxruntime/core/common/cpuid_info_vendor.cc @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/cpuid_info.h" + +#include +#include +#include + +#if defined(CPUINFO_SUPPORTED) +#include "cpuinfo.h" +#endif + +namespace { + +#if !defined(CPUINFO_SUPPORTED) + +// The `cpuinfo_vendor` enum is defined by the cpuinfo library. +// In case we don't build with cpuinfo, we define our own copy. +// The enum was copied from here: +// https://github.com/pytorch/cpuinfo/blob/8a1772a0c5c447df2d18edf33ec4603a8c9c04a6/include/cpuinfo.h#L154-L307 + +/** Vendor of processor core design */ +enum cpuinfo_vendor { + /** Processor vendor is not known to the library, or the library failed + to get vendor information from the OS. */ + cpuinfo_vendor_unknown = 0, + + /* Active vendors of modern CPUs */ + + /** + * Intel Corporation. Vendor of x86, x86-64, IA64, and ARM processor + * microarchitectures. + * + * Sold its ARM design subsidiary in 2006. The last ARM processor design + * was released in 2004. + */ + cpuinfo_vendor_intel = 1, + /** Advanced Micro Devices, Inc. Vendor of x86 and x86-64 processor + microarchitectures. */ + cpuinfo_vendor_amd = 2, + /** ARM Holdings plc. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_arm = 3, + /** Qualcomm Incorporated. Vendor of ARM and ARM64 processor + microarchitectures. 
*/ + cpuinfo_vendor_qualcomm = 4, + /** Apple Inc. Vendor of ARM and ARM64 processor microarchitectures. */ + cpuinfo_vendor_apple = 5, + /** Samsung Electronics Co., Ltd. Vendor of ARM64 processor + microarchitectures. */ + cpuinfo_vendor_samsung = 6, + /** Nvidia Corporation. Vendor of ARM64-compatible processor + microarchitectures. */ + cpuinfo_vendor_nvidia = 7, + /** MIPS Technologies, Inc. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_mips = 8, + /** International Business Machines Corporation. Vendor of PowerPC + processor microarchitectures. */ + cpuinfo_vendor_ibm = 9, + /** Ingenic Semiconductor. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_ingenic = 10, + /** + * VIA Technologies, Inc. Vendor of x86 and x86-64 processor + * microarchitectures. + * + * Processors are designed by Centaur Technology, a subsidiary of VIA + * Technologies. + */ + cpuinfo_vendor_via = 11, + /** Cavium, Inc. Vendor of ARM64 processor microarchitectures. */ + cpuinfo_vendor_cavium = 12, + /** Broadcom, Inc. Vendor of ARM processor microarchitectures. */ + cpuinfo_vendor_broadcom = 13, + /** Applied Micro Circuits Corporation (APM). Vendor of ARM64 processor + microarchitectures. */ + cpuinfo_vendor_apm = 14, + /** + * Huawei Technologies Co., Ltd. Vendor of ARM64 processor + * microarchitectures. + * + * Processors are designed by HiSilicon, a subsidiary of Huawei. + */ + cpuinfo_vendor_huawei = 15, + /** + * Hygon (Chengdu Haiguang Integrated Circuit Design Co., Ltd), Vendor + * of x86-64 processor microarchitectures. + * + * Processors are variants of AMD cores. + */ + cpuinfo_vendor_hygon = 16, + /** SiFive, Inc. Vendor of RISC-V processor microarchitectures. */ + cpuinfo_vendor_sifive = 17, + + /* Active vendors of embedded CPUs */ + + /** Texas Instruments Inc. Vendor of ARM processor microarchitectures. + */ + cpuinfo_vendor_texas_instruments = 30, + /** Marvell Technology Group Ltd. Vendor of ARM processor + * microarchitectures. + */ + cpuinfo_vendor_marvell = 31, + /** RDC Semiconductor Co., Ltd. Vendor of x86 processor + microarchitectures. */ + cpuinfo_vendor_rdc = 32, + /** DM&P Electronics Inc. Vendor of x86 processor microarchitectures. */ + cpuinfo_vendor_dmp = 33, + /** Motorola, Inc. Vendor of PowerPC and ARM processor + microarchitectures. */ + cpuinfo_vendor_motorola = 34, + + /* Defunct CPU vendors */ + + /** + * Transmeta Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 2004. + * Transmeta processors implemented VLIW ISA and used binary translation + * to execute x86 code. + */ + cpuinfo_vendor_transmeta = 50, + /** + * Cyrix Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1996. + */ + cpuinfo_vendor_cyrix = 51, + /** + * Rise Technology. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1999. + */ + cpuinfo_vendor_rise = 52, + /** + * National Semiconductor. Vendor of x86 processor microarchitectures. + * + * Sold its x86 design subsidiary in 1999. The last processor design was + * released in 1998. + */ + cpuinfo_vendor_nsc = 53, + /** + * Silicon Integrated Systems. Vendor of x86 processor + * microarchitectures. + * + * Sold its x86 design subsidiary in 2001. The last processor design was + * released in 2001. + */ + cpuinfo_vendor_sis = 54, + /** + * NexGen. Vendor of x86 processor microarchitectures. + * + * Now defunct.
The last processor design was released in 1994. + * NexGen designed the first x86 microarchitecture which decomposed x86 + * instructions into simple microoperations. + */ + cpuinfo_vendor_nexgen = 55, + /** + * United Microelectronics Corporation. Vendor of x86 processor + * microarchitectures. + * + * Ceased x86 in the early 1990s. The last processor design was released + * in 1991. Designed U5C and U5D processors. Both are 486 level. + */ + cpuinfo_vendor_umc = 56, + /** + * Digital Equipment Corporation. Vendor of ARM processor + * microarchitecture. + * + * Sold its ARM designs in 1997. The last processor design was released + * in 1997. + */ + cpuinfo_vendor_dec = 57, +}; + +#endif // !defined(CPUINFO_SUPPORTED) + +} // namespace + +namespace onnxruntime { + +namespace { + +struct CpuVendorInfo { + cpuinfo_vendor vendor; + std::string_view name; + uint32_t id; +}; + +constexpr auto kUnknownCpuVendorInfo = CpuVendorInfo{cpuinfo_vendor_unknown, "unknown", 0x0000}; + +constexpr std::array kCpuVendorInfos{ + CpuVendorInfo{cpuinfo_vendor_amd, "AMD", 0x1022}, + CpuVendorInfo{cpuinfo_vendor_intel, "Intel", 0x8086}, + CpuVendorInfo{cpuinfo_vendor_qualcomm, "Qualcomm", uint32_t{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}}, + CpuVendorInfo{cpuinfo_vendor_nvidia, "Nvidia", 0x10DE}, + CpuVendorInfo{cpuinfo_vendor_apple, "Apple", 0x106B}, + CpuVendorInfo{cpuinfo_vendor_arm, "ARM", 0x13B5}, + + // TODO add more as needed +}; + +const CpuVendorInfo* FindCpuVendorInfo(cpuinfo_vendor vendor) { + const auto vendor_mapping_it = std::find_if(kCpuVendorInfos.begin(), kCpuVendorInfos.end(), + [vendor](const CpuVendorInfo& entry) { + return entry.vendor == vendor; + }); + + if (vendor_mapping_it != kCpuVendorInfos.end()) { + return &*vendor_mapping_it; + } + + return nullptr; +} + +} // namespace + +void CPUIDInfo::VendorInfoInit() { + const cpuinfo_vendor vendor = [&]() { + cpuinfo_vendor result = cpuinfo_vendor_unknown; +#if defined(CPUINFO_SUPPORTED) + if (pytorch_cpuinfo_init_) { + const auto* processor = cpuinfo_get_processor(0); + if (processor && processor->core) { + result = processor->core->vendor; + } + } +#endif // defined(CPUINFO_SUPPORTED) + return result; + }(); + + const auto* vendor_info = FindCpuVendorInfo(vendor); + if (vendor_info == nullptr) { + LogEarlyWarning(MakeString("Unknown CPU vendor. cpuinfo_vendor value: ", static_cast(vendor))); + vendor_info = &kUnknownCpuVendorInfo; + } + + vendor_ = vendor_info->name; + vendor_id_ = vendor_info->id; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index c2e26f629330f..d8d943d6e9a41 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -61,10 +61,11 @@ inline void TrimStringFromRight(std::string& s) { * @param s The string to trim. * @return The trimmed string. 
*/ -inline std::string TrimString(std::string s) { - TrimStringFromRight(s); - TrimStringFromLeft(s); - return s; +inline std::string TrimString(std::string_view s) { + std::string s_trimmed{s}; + TrimStringFromRight(s_trimmed); + TrimStringFromLeft(s_trimmed); + return s_trimmed; } /** diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index efc12ef8dd0e8..421e5a6db51b7 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -22,6 +22,7 @@ #include "core/graph/model.h" #include "core/graph/model_saving_options.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -909,6 +910,34 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } } + // Generate EP compatibility strings for OrtEp types and add to model metadata + // At this point, the graph has been populated with all the EPContext nodes + { + ORT_RETURN_IF_ERROR(ep_graph.Resolve()); + const GraphViewer graph_viewer(ep_graph); + for (const auto& ep : execution_providers) { + try { + // Generate the compatibility string for this EP + std::string compatibility_string = ep->GetCompiledModelCompatibilityInfo(graph_viewer); + if (!compatibility_string.empty()) { + // Create a unique key for this EP's compatibility info + // Use format: "ep_compatibility_info." + // All EPs in a session must have a unique Type() value, so this will be unique for the generated model + std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep->Type(); + auto& model_metadata = ep_context_model.MetaData(); + auto [it, was_inserted] = + model_metadata.insert_or_assign(metadata_key, compatibility_string); + if (!was_inserted) { + LOGS(logger, WARNING) << "Overwriting existing EP compatibility info for key: " << metadata_key << " (EP: " << ep->Type() << ")"; + } + LOGS(logger, VERBOSE) << "Added EP compatibility info for " << ep->Type() << " with key: " << metadata_key; + } + } catch (const std::exception& ex) { + LOGS(logger, WARNING) << "Failed to generate compatibility string for EP " << ep->Type() << ": " << ex.what(); + } + } + } + size_t ini_size_threshold = ep_context_gen_options.output_external_initializer_size_threshold; std::filesystem::path external_ini_path = ep_context_gen_options.output_external_initializers_file_path; bool force_embed_external_ini = false; diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 504b102e782fd..b99c22edb36c8 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -10,6 +10,7 @@ #include "core/framework/tensor_external_data_info.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" #define DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(external_type, internal_type, internal_api) \ external_type* ToExternal() { return static_cast(this); } \ @@ -301,6 +302,11 @@ struct OrtGraph { /// The graph's name. virtual const std::string& GetName() const = 0; + /// + /// Returns the model's metadata. + /// + /// The model metadata. + virtual std::unique_ptr GetModelMetadata() const = 0; /// /// Returns the model's path, which is empty if unknown. 
/// diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5511275239e45..46a52e042ba13 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3664,10 +3664,10 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h uint32_t components = (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) ? (8 / bits) : 1; for (int i = 0; i < r; ++i) { - if (!data_shape.dim(i).has_dim_value() || - !scales_shape.dim(i).has_dim_value() || - (i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) || - (i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value())) { + if (data_shape.dim(i).has_dim_value() && + scales_shape.dim(i).has_dim_value() && + ((i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) || + (i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value()))) { fail_shape_inference("data shape and scales shape do not match"); } } diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index eb7fb6937c29e..759a2998ace3a 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -20,6 +20,7 @@ #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/graph_viewer.h" #include "core/graph/graph.h" +#include "core/graph/model.h" namespace onnxruntime { @@ -769,6 +770,25 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); } +std::unique_ptr EpGraph::GetModelMetadata() const { +#if !defined(ORT_MINIMAL_BUILD) + const auto& model = graph_viewer_.GetGraph().GetModel(); + auto model_metadata = std::make_unique(); + + model_metadata->producer_name = model.ProducerName(); + model_metadata->producer_version = model.ProducerVersion(); + model_metadata->description = model.DocString(); + model_metadata->graph_description = model.GraphDocString(); + model_metadata->domain = model.Domain(); + model_metadata->version = model.ModelVersion(); + model_metadata->custom_metadata_map = model.MetaData(); + model_metadata->graph_name = model.MainGraph().Name(); + return model_metadata; +#else + return nullptr; +#endif +} + const ORTCHAR_T* EpGraph::GetModelPath() const { return graph_viewer_.ModelPath().c_str(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index be78d77360cb8..7f22e265129f7 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -298,6 +298,9 @@ struct EpGraph : public OrtGraph { // Returns the graph's name. const std::string& GetName() const override; + // Returns the graph's metadata + std::unique_ptr GetModelMetadata() const override; + // Returns the model path. 
const ORTCHAR_T* GetModelPath() const override; diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 436af7115eb1a..eb5e1e89e2f9c 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -361,6 +361,10 @@ const ModelMetaData& Model::MetaData() const noexcept { return model_metadata_; } +ModelMetaData& Model::MetaData() noexcept { + return model_metadata_; +} + Graph& Model::MainGraph() noexcept { return *graph_; } @@ -377,6 +381,15 @@ ModelProto Model::ToProto() const { // out dense duplicates of sparse initializers and leave the original // proto intact. ModelProto result(model_proto_); + + // Sync current model_metadata_ back to protobuf metadata_props + result.clear_metadata_props(); + for (const auto& metadata : model_metadata_) { + const gsl::not_null prop{result.add_metadata_props()}; + prop->set_key(metadata.first); + prop->set_value(metadata.second); + } + const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProto(); return result; @@ -386,6 +399,15 @@ ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::pa const std::filesystem::path& file_path, const ModelSavingOptions& model_saving_options) const { ModelProto result(model_proto_); + + // Sync current model_metadata_ back to protobuf metadata_props + result.clear_metadata_props(); + for (const auto& metadata : model_metadata_) { + const gsl::not_null prop{result.add_metadata_props()}; + prop->set_key(metadata.first); + prop->set_value(metadata.second); + } + const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name, file_path, diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 70f82bcfb160b..e8722f6f5c0b2 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -189,6 +189,8 @@ class Model { const ModelMetaData& MetaData() const noexcept; + ModelMetaData& MetaData() noexcept; + // Gets the path from which the model was loaded, if any. const std::filesystem::path& ModelPath() const noexcept { return model_path_; } diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index d3795d911b22f..e7ffcbc7e4c90 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -13,6 +13,7 @@ #include "core/framework/ort_value.h" #include "core/graph/abi_graph_types.h" #include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" namespace onnxruntime { @@ -184,6 +185,9 @@ struct ModelEditorGraph : public OrtGraph { const std::string& GetName() const override { return name; } + std::unique_ptr GetModelMetadata() const override { + return std::make_unique(model_metadata); + } const ORTCHAR_T* GetModelPath() const override { return model_path.c_str(); } int64_t GetOnnxIRVersion() const override { @@ -241,6 +245,7 @@ struct ModelEditorGraph : public OrtGraph { std::vector> nodes; std::string name = "ModelEditorGraph"; std::filesystem::path model_path; + ModelMetadata model_metadata; }; } // namespace onnxruntime diff --git a/onnxruntime/core/platform/apple/device_discovery.cc b/onnxruntime/core/platform/apple/device_discovery.cc new file mode 100644 index 0000000000000..767b834e38756 --- /dev/null +++ b/onnxruntime/core/platform/apple/device_discovery.cc @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
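For illustration only (not one of the diff hunks): the non-const Model::MetaData() accessor plus the metadata sync added to Model::ToProto() above is what lets the per-EP compatibility strings written in graph_partitioner.cc survive into the serialized EPContext model. A hedged usage sketch with a plain map standing in for ModelMetaData; the key prefix literal and the helper name are illustrative (the real prefix constant is kOrtModelMetadata_EpCompatibilityInfoPrefix):

#include <string>
#include <unordered_map>

// Stand-in for onnxruntime::ModelMetaData (a string -> string map).
using ModelMetaData = std::unordered_map<std::string, std::string>;

void RecordEpCompatibilityInfo(ModelMetaData& metadata, const std::string& ep_type,
                               const std::string& compatibility_string) {
  // Key convention from graph_partitioner.cc: "<prefix><EP type>"; EP types are unique per session.
  const std::string key = "ep_compatibility_info." + ep_type;
  metadata.insert_or_assign(key, compatibility_string);
  // Model::ToProto() (as modified above) clears metadata_props and re-adds every entry
  // of this map, so the string is persisted when the EPContext model is saved.
}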
+ +#include "core/platform/device_discovery.h" + +#include +#include + +#include "core/common/logging/logging.h" + +namespace onnxruntime { + +namespace { + +constexpr auto kApplePciVendorId = 0x106B; +constexpr auto kAppleVendorName = "Apple"; + +std::vector GetGpuDevices() { + std::vector result{}; + + // For now, we assume the existence of one GPU if it is a Mac with Apple Silicon. + // TODO support iOS + // TODO support Intel Macs which may have more than one GPU +#if TARGET_OS_OSX && TARGET_CPU_ARM64 + { + OrtHardwareDevice gpu_device{}; + gpu_device.type = OrtHardwareDeviceType_GPU; + gpu_device.vendor_id = kApplePciVendorId; + gpu_device.vendor = kAppleVendorName; + + result.emplace_back(std::move(gpu_device)); + } +#endif // TARGET_OS_OSX && TARGET_CPU_ARM64 + + return result; +} + +bool HasAppleNeuralEngine() { + // Copied from onnxruntime/core/providers/coreml/builders/helper.cc:HasNeuralEngine(). + bool has_apple_neural_engine = false; + + struct utsname system_info; + uname(&system_info); + LOGS_DEFAULT(VERBOSE) << "Current Apple hardware info: " << system_info.machine; + +#if TARGET_OS_IPHONE + // utsname.machine has device identifier. For example, identifier for iPhone Xs is "iPhone11,2". + // Since Neural Engine is only available for use on A12 and later, major device version in the + // identifier is checked for these models: + // A12: iPhone XS (11,2), iPad Mini - 5th Gen (11,1) + // A12X: iPad Pro - 3rd Gen (8,1) + // For more information, see https://www.theiphonewiki.com/wiki/Models + size_t str_len = strnlen(system_info.machine, onnxruntime::kMaxStrLen); + if (str_len > 4 && strncmp("iPad", system_info.machine, 4) == 0) { + const int major_version = atoi(system_info.machine + 4); + has_apple_neural_engine = major_version >= 8; // There are no device between iPad 8 and 11. + } else if (str_len > 6 && strncmp("iPhone", system_info.machine, 6) == 0) { + const int major_version = atoi(system_info.machine + 6); + has_apple_neural_engine = major_version >= 11; + } +#elif TARGET_OS_OSX && TARGET_CPU_ARM64 + // Only Mac with arm64 CPU (Apple Silicon) has ANE. + has_apple_neural_engine = true; +#endif // #if TARGET_OS_IPHONE + + return has_apple_neural_engine; +} + +std::vector GetNpuDevices() { + std::vector result{}; + + if (HasAppleNeuralEngine()) { + OrtHardwareDevice npu_device{}; + npu_device.type = OrtHardwareDeviceType_NPU; + npu_device.vendor_id = kApplePciVendorId; + npu_device.vendor = kAppleVendorName; + + result.emplace_back(std::move(npu_device)); + } + + return result; +} + +} // namespace + +std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatform() { + std::unordered_set devices; + + // get CPU devices + devices.insert(GetCpuDeviceFromCPUIDInfo()); + + // get GPU devices + { + auto gpu_devices = GetGpuDevices(); + devices.insert(gpu_devices.begin(), gpu_devices.end()); + } + + // get NPU devices + { + auto npu_devices = GetNpuDevices(); + devices.insert(npu_devices.begin(), npu_devices.end()); + } + + return devices; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/device_discovery.h b/onnxruntime/core/platform/device_discovery.h index 70be10bf09e4e..b49e63b90236a 100644 --- a/onnxruntime/core/platform/device_discovery.h +++ b/onnxruntime/core/platform/device_discovery.h @@ -3,25 +3,24 @@ #pragma once -#include #include #include "core/session/abi_devices.h" + namespace onnxruntime { class DeviceDiscovery { public: - static std::unordered_set& GetDevices() { - // assumption: devices don't change. 
we assume the machine must be shutdown to change cpu/gpu/npu devices. - // technically someone could disable/enable a device in a running OS. we choose not to add complexity to support - // that scenario. - static std::unordered_set devices(DiscoverDevicesForPlatform()); - return devices; - } + static const std::unordered_set& GetDevices(); private: DeviceDiscovery() = default; + // platform specific code implements this method static std::unordered_set DiscoverDevicesForPlatform(); + + // Gets a CPU device by querying `CPUIDInfo`. + static OrtHardwareDevice GetCpuDeviceFromCPUIDInfo(); }; + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/device_discovery_common.cc b/onnxruntime/core/platform/device_discovery_common.cc new file mode 100644 index 0000000000000..dcba31aed6fec --- /dev/null +++ b/onnxruntime/core/platform/device_discovery_common.cc @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file contains platform-agnostic device discovery implementation. + +#include "core/platform/device_discovery.h" + +#include + +#include "core/common/cpuid_info.h" +#include "core/common/logging/logging.h" + +namespace onnxruntime { + +const std::unordered_set& DeviceDiscovery::GetDevices() { + // assumption: devices don't change. we assume the machine must be shutdown to change cpu/gpu/npu devices. + // technically someone could disable/enable a device in a running OS. we choose not to add complexity to support + // that scenario. + static std::unordered_set devices = []() { + auto discovered_devices = DiscoverDevicesForPlatform(); + + // log discovered devices + for (const auto& ortdevice : discovered_devices) { + std::ostringstream oss; + oss << "Discovered OrtHardwareDevice {vendor_id:0x" << std::hex << ortdevice.vendor_id + << ", device_id:0x" << ortdevice.device_id + << ", vendor:" << ortdevice.vendor + << ", type:" << std::dec << static_cast(ortdevice.type) + << ", metadata: ["; + for (auto& [key, value] : ortdevice.metadata.Entries()) { + oss << key << "=" << value << ", "; + } + oss << "]}"; + LOGS_DEFAULT(INFO) << oss.str(); + } + + return discovered_devices; + }(); + + return devices; +} + +OrtHardwareDevice DeviceDiscovery::GetCpuDeviceFromCPUIDInfo() { + const auto& cpuid_info = CPUIDInfo::GetCPUIDInfo(); + + OrtHardwareDevice cpu_device{}; + cpu_device.vendor = cpuid_info.GetCPUVendor(); + cpu_device.vendor_id = cpuid_info.GetCPUVendorId(); + cpu_device.device_id = 0; + cpu_device.type = OrtHardwareDeviceType_CPU; + + return cpu_device; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/posix/device_discovery.cc b/onnxruntime/core/platform/device_discovery_default.cc similarity index 57% rename from onnxruntime/core/platform/posix/device_discovery.cc rename to onnxruntime/core/platform/device_discovery_default.cc index 82564539ab5d4..73ddf516034ab 100644 --- a/onnxruntime/core/platform/posix/device_discovery.cc +++ b/onnxruntime/core/platform/device_discovery_default.cc @@ -4,14 +4,16 @@ #include "core/platform/device_discovery.h" namespace onnxruntime { + std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatform() { - std::unordered_set devices; - // get CPU devices + // This is a default implementation. + // We assume that there is a CPU device and do not attempt to discover anything else. 
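For illustration only (not one of the diff hunks): DeviceDiscovery::GetDevices() in the new device_discovery_common.cc above caches the discovery result in a function-local static initialized by an immediately-invoked lambda, which is why the devices-do-not-change assumption matters: discovery (and the one-time INFO logging) runs exactly once per process, even with concurrent callers. A stripped-down sketch of that pattern with a placeholder element type:

#include <string>
#include <unordered_set>

// Placeholder for the platform-specific discovery step.
std::unordered_set<std::string> DiscoverOnce() { return {"cpu", "gpu0"}; }

const std::unordered_set<std::string>& GetCachedDevices() {
  // C++11 guarantees thread-safe initialization of function-local statics,
  // so the lambda body (discovery plus any logging) executes only once.
  static const std::unordered_set<std::string> cached = []() {
    auto devices = DiscoverOnce();
    // ... log each discovered entry here ...
    return devices;
  }();
  return cached;
}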
- // get GPU devices + std::unordered_set devices{}; - // get NPU devices + devices.emplace(GetCpuDeviceFromCPUIDInfo()); return devices; } + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/linux/device_discovery.cc b/onnxruntime/core/platform/linux/device_discovery.cc new file mode 100644 index 0000000000000..6a02a1b46028f --- /dev/null +++ b/onnxruntime/core/platform/linux/device_discovery.cc @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/platform/device_discovery.h" + +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/logging/logging.h" +#include "core/common/parse_string.h" +#include "core/common/string_utils.h" + +namespace fs = std::filesystem; + +namespace onnxruntime { + +namespace { + +Status ErrorCodeToStatus(const std::error_code& ec) { + if (!ec) { + return Status::OK(); + } + + return Status{common::StatusCategory::ONNXRUNTIME, common::StatusCode::FAIL, + MakeString("Error: std::error_code with category name: ", ec.category().name(), + ", value: ", ec.value(), ", message: ", ec.message())}; +} + +struct GpuSysfsPathInfo { + size_t card_idx; + fs::path path; +}; + +Status DetectGpuSysfsPaths(std::vector& gpu_sysfs_paths_out) { + std::error_code error_code{}; + const fs::path sysfs_class_drm_path = "/sys/class/drm"; + const bool sysfs_class_drm_path_exists = fs::exists(sysfs_class_drm_path, error_code); + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code)); + + if (!sysfs_class_drm_path_exists) { + gpu_sysfs_paths_out = std::vector{}; + return Status::OK(); + } + + const auto detect_card_path = [](const fs::path& sysfs_path, size_t& card_idx) -> bool { + const auto filename = sysfs_path.filename(); + const auto filename_str = std::string_view{filename.native()}; + + // Look for a filename matching "cardN". N is a number. 
+ constexpr std::string_view prefix = "card"; + if (filename_str.find(prefix) != 0) { + return false; + } + + size_t parsed_card_idx{}; + if (!TryParseStringWithClassicLocale(filename_str.substr(prefix.size()), parsed_card_idx)) { + return false; + } + + card_idx = parsed_card_idx; + return true; + }; + + std::vector gpu_sysfs_paths{}; + + auto dir_iterator = fs::directory_iterator{sysfs_class_drm_path, error_code}; + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code)); + + for (const auto& dir_item : dir_iterator) { + const auto& dir_item_path = dir_item.path(); + + if (size_t card_idx{}; detect_card_path(dir_item_path, card_idx)) { + GpuSysfsPathInfo path_info{}; + path_info.card_idx = card_idx; + path_info.path = dir_item_path; + gpu_sysfs_paths.emplace_back(std::move(path_info)); + } + } + + gpu_sysfs_paths_out = std::move(gpu_sysfs_paths); + return Status::OK(); +} + +Status ReadFileContents(const fs::path& file_path, std::string& contents) { + std::ifstream file{file_path}; + ORT_RETURN_IF_NOT(file, "Failed to open file: ", file_path); + std::istreambuf_iterator file_begin{file}, file_end{}; + contents.assign(file_begin, file_end); + return Status::OK(); +} + +template +Status ReadValueFromFile(const fs::path& file_path, ValueType& value) { + std::string file_text{}; + ORT_RETURN_IF_ERROR(ReadFileContents(file_path, file_text)); + file_text = utils::TrimString(file_text); + return ParseStringWithClassicLocale(file_text, value); +} + +Status GetGpuDeviceFromSysfs(const GpuSysfsPathInfo& path_info, OrtHardwareDevice& gpu_device_out) { + OrtHardwareDevice gpu_device{}; + const auto& sysfs_path = path_info.path; + + // vendor id + { + const auto vendor_id_path = sysfs_path / "device" / "vendor"; + ORT_RETURN_IF_ERROR(ReadValueFromFile(vendor_id_path, gpu_device.vendor_id)); + } + + // TODO vendor name + + // device id + { + const auto device_id_path = sysfs_path / "device" / "device"; + ORT_RETURN_IF_ERROR(ReadValueFromFile(device_id_path, gpu_device.device_id)); + } + + // metadata + gpu_device.metadata.Add("card_idx", MakeString(path_info.card_idx)); + // TODO is card discrete? 
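For illustration only (not one of the diff hunks): DetectGpuSysfsPaths() and the file helpers above scan /sys/class/drm for cardN entries and then read per-device attribute files such as device/vendor and device/device. A minimal standalone sketch of the directory scan; it uses a simpler filter (skip names containing '-') than the numeric parse in the patch:

#include <filesystem>
#include <iostream>
#include <string>
#include <string_view>

int main() {
  namespace fs = std::filesystem;
  std::error_code ec;
  for (const auto& entry : fs::directory_iterator{"/sys/class/drm", ec}) {
    const std::string name = entry.path().filename().string();
    constexpr std::string_view prefix = "card";
    // Keep "card0", "card1", ... and skip connector entries like "card0-HDMI-A-1".
    if (name.rfind(prefix, 0) == 0 && name.find('-') == std::string::npos) {
      std::cout << "GPU sysfs node: " << entry.path()
                << " (vendor file: " << entry.path() / "device" / "vendor" << ")\n";
    }
  }
  return 0;
}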
+ + gpu_device.type = OrtHardwareDeviceType_GPU; + + gpu_device_out = std::move(gpu_device); + return Status::OK(); +} + +Status GetGpuDevices(std::vector& gpu_devices_out) { + std::vector gpu_sysfs_path_infos{}; + ORT_RETURN_IF_ERROR(DetectGpuSysfsPaths(gpu_sysfs_path_infos)); + + std::vector gpu_devices{}; + gpu_devices.reserve(gpu_sysfs_path_infos.size()); + + for (const auto& gpu_sysfs_path_info : gpu_sysfs_path_infos) { + OrtHardwareDevice gpu_device{}; + ORT_RETURN_IF_ERROR(GetGpuDeviceFromSysfs(gpu_sysfs_path_info, gpu_device)); + gpu_devices.emplace_back(std::move(gpu_device)); + } + + gpu_devices_out = std::move(gpu_devices); + return Status::OK(); +} + +} // namespace + +std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatform() { + std::unordered_set devices; + + // get CPU devices + devices.emplace(GetCpuDeviceFromCPUIDInfo()); + + // get GPU devices + { + std::vector gpu_devices{}; + Status gpu_device_discovery_status = GetGpuDevices(gpu_devices); + if (gpu_device_discovery_status.IsOK()) { + devices.insert(std::make_move_iterator(gpu_devices.begin()), + std::make_move_iterator(gpu_devices.end())); + } else { + LOGS_DEFAULT(WARNING) << "GPU device discovery failed: " << gpu_device_discovery_status.ErrorMessage(); + } + } + + // get NPU devices + // TODO figure out how to discover these + + return devices; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index ff904ddb3e7e0..cf761f587ad0b 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -635,19 +635,6 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } } - std::ostringstream oss; - oss << "Adding OrtHardwareDevice {vendor_id:0x" << std::hex << ortdevice.vendor_id - << ", device_id:0x" << ortdevice.device_id - << ", vendor:" << ortdevice.vendor - << ", type:" << std::dec << static_cast(ortdevice.type) - << ", metadata: ["; - for (auto& [key, value] : ortdevice.metadata.Entries()) { - oss << key << "=" << value << ", "; - } - - oss << "]}" << std::endl; - LOGS_DEFAULT(INFO) << oss.str(); - return ortdevice; }; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index cc9d9f3da1d81..451be69c81cfb 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -11,6 +11,7 @@ #include "core/common/common.h" #include "core/common/narrow.h" #include "core/common/safeint.h" +#include "core/framework/ort_value.h" #include "nv_execution_provider.h" #include "nv_execution_provider_utils.h" #include "nv_execution_provider_custom_ops.h" @@ -487,7 +488,7 @@ Status BindContextInput(Ort::KernelContext& ctx, if (!trt_context->setTensorAddress(input_name, &shape_tensor_values[input_name][0])) { std::string error_input_name = input_name; std::string error_msg = - "Nv EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + "NvTensorRTRTX EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"; ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, error_msg)); } @@ -510,7 +511,7 @@ Status BindContextInput(Ort::KernelContext& ctx, if (!trt_context->setTensorAddress(input_name, &shape_tensor_values_int64[input_name][0])) { std::string error_input_name = input_name; 
std::string error_msg = - "Nv EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + "NvTensorRTRTX EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + error_input_name + "'"; ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, error_msg)); } @@ -532,7 +533,7 @@ Status BindContextInput(Ort::KernelContext& ctx, if (!trt_context->setInputShape(input_name, dims)) { std::string error_input_name = input_name; ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'")); + "NvTensorRTRTX EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'")); } // Bind "execution tensor" input buffer @@ -553,7 +554,7 @@ Status BindContextInput(Ort::KernelContext& ctx, CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); + "NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); } } trt_context->setTensorAddress(input_name, data); @@ -644,7 +645,7 @@ Status BindContextOutput(Ort::KernelContext& ctx, CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP output tensor data type: " + std::to_string(output_type) + " not supported."); + "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } trt_context->setTensorAddress(output_name, buffers[output_name]); @@ -707,7 +708,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx, CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP output tensor data type: " + std::to_string(output_type) + " not supported."); + "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } return Status::OK(); @@ -836,7 +837,12 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) cudaDeviceProp prop; CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_)); - compute_capability_ = GetComputeCapacity(prop); + auto cc = prop.major * 10 + prop.minor; + if (!(cc == 86 || cc == 89 || cc >= 120)) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NvTensorRTRTX EP] The execution provider only supports RTX devices with compute capabilities 86, 89, 120 and above")); + } + compute_capability_ = GetComputeCapability(prop); if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); @@ -866,6 +872,15 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) "When providing either 'trt_onnx_bytestream_size' or " "'trt_onnx_bytestream' both have to be provided")); } + use_external_data_initializer_ = info.use_external_data_initializer; + onnx_external_data_bytestream_ = info.external_data_bytestream; + onnx_external_data_bytestream_size_ = info.external_data_bytestream_size; + if ((onnx_external_data_bytestream_ != nullptr && onnx_external_data_bytestream_size_ == 0) || + (onnx_external_data_bytestream_ == nullptr && onnx_external_data_bytestream_size_ != 0)) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "When providing either 'onnx_external_data_bytestream_size' or " + 
"'onnx_external_data_bytestream' both have to be provided")); + } detailed_build_log_ = info.detailed_build_log; dump_ep_context_model_ = info.dump_ep_context_model; ep_context_file_path_ = info.ep_context_file_path; @@ -979,13 +994,13 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) LIBTYPE handle = OPENLIB(engine_decryption_lib_path_.c_str()); if (handle == nullptr) { ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not open shared library from " + engine_decryption_lib_path_)); + "NvTensorRTRTX EP could not open shared library from " + engine_decryption_lib_path_)); } engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt"); if (engine_decryption_ == nullptr) { ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not find decryption function in shared library from " + engine_decryption_lib_path_)); + "NvTensorRTRTX EP could not find decryption function in shared library from " + engine_decryption_lib_path_)); } } @@ -1029,6 +1044,8 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) << ", nv_ep_context_embed_mode: " << ep_context_embed_mode_ << ", nv_cache_prefix: " << cache_prefix_ << ", nv_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_ + << ", nv_onnx_external_bytestream_size_: " << onnx_external_data_bytestream_size_ + << ", nv_use_external_data_initializer_: " << use_external_data_initializer_ << ", nv_op_types_to_exclude: " << op_types_to_exclude_; } @@ -1093,7 +1110,7 @@ void NvExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { } std::vector NvExecutionProvider::CreatePreferredAllocators() { - OrtArenaCfg arena_cfg(0, static_cast(ArenaExtendStrategy::kSameAsRequested), + OrtArenaCfg arena_cfg(0, static_cast(ArenaExtendStrategy::kNextPowerOfTwo), -1, -1, -1, -1); AllocatorCreationInfo default_memory_info( [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, CUDA); }, @@ -1140,6 +1157,9 @@ nvinfer1::IBuilder* NvExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) { auto lock = GetApiLock(); builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + unsigned int num_threads = std::thread::hardware_concurrency(); + builder_->setMaxThreads(num_threads / 2); + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Set threads that the builder can use to:" << builder_->getMaxThreads(); } } return builder_.get(); @@ -1450,8 +1470,11 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t SetAllGraphInputs(graph_build); } - ORT_ENFORCE(graph_build.Resolve().IsOK()); - + auto status = graph_build.Resolve(); + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << status.ErrorMessage(); + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX graph resolve failed: " + status.ErrorMessage())); + } // Add parent graph output to the subgraph int i = 0; std::vector subgraph_outputs; @@ -1502,7 +1525,37 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating // the model proto that has different node ordering compared to original onnx model. 
- graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); + + // save user provided external data in memory instead of writing to ModelProto + // needed for models > 2GB + std::vector userWeights; + if (use_external_data_initializer_) { + auto c_api = Ort::GetApi(); + const InitializedTensorSet& allInitializers = graph_viewer->GetAllInitializedTensors(); + userWeights.reserve(allInitializers.size()); + for (auto& entry : allInitializers) { + OrtValue initializer_value; + auto* tp = entry.second; + if (utils::HasRawData(*tp)) { + userWeights.emplace_back(TensorrtUserWeights(tp->name(), tp->raw_data().data(), tp->raw_data().size())); + } else if (graph_viewer->GetOrtValueInitializer(tp->name(), initializer_value)) { + // the initializer was marked as external data by the ORT graph at load time since it was provided in memory + size_t size = 0; + const void* ptr = nullptr; + c_api.GetTensorSizeInBytes(&initializer_value, &size); + c_api.GetTensorData(&initializer_value, &ptr); + userWeights.emplace_back(tp->name(), ptr, size); + } else if (utils::HasExternalDataInMemory(*tp)) { + // only copy and take ownership of the data if none of the above conditions are met + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights.emplace_back(std::move(full_init->name()), std::move(full_init->raw_data())); + } + } + } + + graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !use_external_data_initializer_ /*include raw initializers*/); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; @@ -1521,11 +1574,25 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t auto network_flags = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + bool is_model_supported = false; + // limit the scope of trt_parser so that model gets unloaded from memory asap { auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); - auto is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + if (use_external_data_initializer_) { +#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 1 + trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); + for (auto const& userWeight : userWeights) { + trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size()); + } + is_model_supported = trt_parser->parseModelProto(); +#else + ORT_THROW("'nv_use_external_data_initializer' is only supported on TensorRT RTX 1.1.x.x and above."); +#endif + } else { + is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + } // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined behavior. auto num_subgraphs = trt_parser->getNbSubgraphs(); @@ -1708,21 +1775,33 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, #endif model_path_[sizeof(model_path_) - 1] = '\0'; - // If the model consists of only a single "EPContext" contrib op, it means TRT EP can fetch the precompiled engine info from the node and - // load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation. - // So, simply return the ComputeCapability here. 
- if (graph.NumberOfNodes() == 1 && GraphHasCtxNode(graph)) { - SubGraph_t supported_node_vector = {{0}, true}; - std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph, std::to_string(trt_version_), std::to_string(cuda_version_)), 0); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); - return result; - } + const int number_of_ort_nodes = graph.NumberOfNodes(); + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); // Generate unique kernel name for TRT graph HashValue model_hash = TRTGenerateId(graph, std::to_string(trt_version_), std::to_string(cuda_version_)); - // Get supported node list from TensorRT parser - const int number_of_ort_nodes = graph.NumberOfNodes(); + // If there are "EPContext" contrib op nodes, it means TRT EP can fetch the precompiled engine info from the node and + // load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT + // parser and engine compilation. So, simply return subgraphs consisting of single EP context nodes here. + int subgraph_idx = 0; + for (size_t node_idx : node_index) { + const auto& node = graph.GetNode(node_idx); + const bool is_context_node = node && !node->OpType().empty() && node->OpType() == EPCONTEXT_OP; + if (is_context_node) { + SubGraph_t supported_node_vector(std::make_pair(std::vector{node_idx}, true)); + std::unique_ptr sub_graph = GetSubGraph(supported_node_vector, graph, model_hash, subgraph_idx++); + + result.push_back(ComputeCapability::Create(std::move(sub_graph))); + } + } + // return early if context nodes were found + if (!result.empty()) { + return result; + } + + // For regular ONNX nodes, get supported node list from TensorRT parser + std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); @@ -1741,7 +1820,6 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, auto exclude_ops_set = get_exclude_ops_set(op_types_to_exclude_); SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; - const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); bool new_subgraph = true; /* Iterate all the nodes and exclude the node if: @@ -1932,14 +2010,16 @@ */ common::Status NvExecutionProvider::RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, - std::string& weight_stripped_engine_cath_path, bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, nvinfer1::ICudaEngine* trt_engine, - bool serialize_refitted_engine, bool detailed_build_log) { bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0; + bool refit_with_external_data = onnx_external_data_bytestream != nullptr && onnx_external_data_bytestream_size != 0; + bool refit_complete = false; std::filesystem::path onnx_model_path{onnx_model_folder_path}; if (refit_from_file) { if (!onnx_model_filename.empty()) { @@ -1976,34 +2056,145 @@ common::Status NvExecutionProvider::RefitEngine(std::string onnx_model_filename, auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); auto parser_refitter = std::unique_ptr( nvonnxparser::createParserRefitter(*refitter, trt_logger)); - if (refit_from_file) { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX
EP] Refitting from file on disk: " << onnx_model_path.string(); - if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + + // New refit APIs + if (refit_with_external_data) { +#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 1 + // A valid model bytestream must be passed. + if (refit_from_file) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + "NvTensorRTRTX EP's refit with external data must be called with a valid ONNX model bytestream"); } - } else { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from byte array"; - if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + + if (!parser_refitter->loadModelProto(onnx_model_bytestream, onnx_model_bytestream_size, nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not load model from provided onnx_model_bytestream"); + } + + // Extract weight information from the Refitter. + int required_weights = refitter->getAllWeights(0, nullptr); + std::vector refit_names_prealocated(required_weights); + refitter->getAllWeights(required_weights, refit_names_prealocated.data()); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitter requires " << required_weights << " weights"; + std::unordered_set refit_names(std::make_move_iterator(refit_names_prealocated.begin()), + std::make_move_iterator(refit_names_prealocated.end())); + + // Vectors to keep track of data pointers. + std::vector names; + names.reserve(required_weights); + std::vector bytes; + bytes.reserve(required_weights); + std::vector sizes; + sizes.reserve(required_weights); + + auto onnx_model = ModelProto::Create(); + TensorProtos* allInitializers_byte_stream; + + // Reconstruct onnx model view. + const auto onnx_model_view = std::string((const char*)onnx_model_bytestream, + onnx_model_bytestream_size); + if (!onnx_model->ParseFromString(onnx_model_view)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestraem"); + "The provided ONNX bytestream to refit could not be parsed."); + } + + // Extract graph and initializer information. + auto const& graph = onnx_model->mutable_graph(); + allInitializers_byte_stream = graph->mutable_initializer(); + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Initializers that were found " << allInitializers_byte_stream->size(); + + // Loop through all initializers + int missing_initializer_data = 0; + for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) { + auto& proto = allInitializers_byte_stream->at(initializer_idx); + auto& proto_name = proto.name(); + if (refit_names.find(proto_name) != refit_names.end()) { + if (proto.has_data_location()) { + if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) { + // Default values for reading into external_data blob. 
+ int64_t offset = 0; + size_t length = 0; + auto external_data = proto.mutable_external_data(); + const std::string kOffset = "offset", kLength = "length"; + for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) { + auto current_key = external_data->at(entry_idx).mutable_key(); + auto current_value = external_data->at(entry_idx).mutable_value(); + if (*current_key == kOffset && !current_value->empty()) { + offset = std::stoll(*current_value); + } else if (*current_key == kLength && !current_value->empty()) { + length = std::stoul(*current_value); + } + } + names.push_back(proto.name()); + bytes.push_back(static_cast(onnx_external_data_bytestream) + offset); + sizes.push_back(length); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NvTensorRTRTX EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead."); + } + } else if (proto.has_raw_data()) { + auto& raw_data = proto.raw_data(); + names.push_back(proto.name()); + bytes.push_back(raw_data.c_str()); + sizes.push_back(raw_data.size()); + } else { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Proto: " + proto_name + " has no raw nor external data."; + ++missing_initializer_data; + } + } else { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Initializer with name: " << proto_name << " was not marked as refittable"; + } + } + if (missing_initializer_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[NvTensorRTRTX EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers."); + } + + // Load extracted initializers into the parser + if (!names.empty()) { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Number of initializers submitted to refitter " << names.size(); + for (size_t i = 0; i < names.size(); i++) { + bool refloadInit = parser_refitter->loadInitializer(names[i].c_str(), bytes[i], sizes[i]); + if (!refloadInit) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"); + } + } + } + // Perform refit. + if (!parser_refitter->refitModelProto()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter refitModelProto() failed with the provided external data bytestream."); + } + refit_complete = true; +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Refit with external data is only supported on TensorRT RTX 1.1.x.x and above."); +#endif + } + + // If new refit flow was not completed, then fallback to refit_from_file. 
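For illustration only (not one of the diff hunks): the refit-with-external-data path above resolves each refittable initializer either to its embedded raw_data or, for externally stored tensors, to an (offset, length) pair inside the caller-supplied bytestream before handing the pointer to loadInitializer(). A reduced sketch of that key/value extraction, with a plain map standing in for the TensorProto external_data entries:

#include <cstdint>
#include <map>
#include <string>

struct ExternalDataView {
  std::int64_t offset = 0;  // byte offset into the external-data bytestream
  std::size_t length = 0;   // size in bytes of this initializer
};

// `entries` stands in for the TensorProto.external_data key/value list.
ExternalDataView ResolveExternalData(const std::map<std::string, std::string>& entries) {
  ExternalDataView view;
  if (auto it = entries.find("offset"); it != entries.end() && !it->second.empty()) {
    view.offset = std::stoll(it->second);
  }
  if (auto it = entries.find("length"); it != entries.end() && !it->second.empty()) {
    view.length = static_cast<std::size_t>(std::stoull(it->second));
  }
  // Data pointer handed to the parser refitter: bytestream_base + view.offset, view.length bytes.
  return view;
}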
+ if (!refit_complete) { + if (refit_from_file) { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from file on disk: " << onnx_model_path.string(); + if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + } + } else { + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refitting from byte array"; + if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"); + } } } if (refitter->refitCudaEngine()) { LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Successfully refitted the weight-stripped engine."; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + "NvTensorRTRTX EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); } - // serialize the refitted engine to disk - if (serialize_refitted_engine) { - std::string refitted_engine_cache = GetWeightRefittedEnginePath(weight_stripped_engine_cath_path); - nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize(); - std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out); - engine_file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Serialize the refitted engine to " << refitted_engine_cache; - } return Status::OK(); } @@ -2029,8 +2220,10 @@ common::Status NvExecutionProvider::Compile(const std::vector } Status status; - if (GraphHasCtxNode(graph_body_viewer)) { + size_t node_idx = 0; + if (GraphHasCtxNode(graph_body_viewer, node_idx)) { status = CreateNodeComputeInfoFromPrecompiledEngine(graph_body_viewer, + node_idx, fused_node, input_map, output_map, @@ -2135,6 +2328,16 @@ static bool IsIOBindingRequired(TRTState* const trt_state, const Ort::KernelCont return require_io_binding; } +const InlinedVector NvExecutionProvider::GetEpContextNodes() const { + InlinedVector ep_context_nodes; + if (ep_context_model_) { + for (auto* node : ep_context_model_->MainGraph().Nodes()) { + ep_context_nodes.push_back(node); + } + } + return ep_context_nodes; +} + Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer, const Node& fused_node, std::unordered_map& input_map, @@ -2144,11 +2347,38 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr auto model = graph_body_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); + // exclude weights if external + std::vector userWeights; + if (use_external_data_initializer_) { + auto c_api = Ort::GetApi(); + const InitializedTensorSet& allInitializers = graph_body_viewer.GetAllInitializedTensors(); + userWeights.reserve(allInitializers.size()); + for (auto& entry : allInitializers) { + OrtValue initializer_value; + auto* tp = entry.second; + if (utils::HasRawData(*tp)) { + userWeights.emplace_back(TensorrtUserWeights(tp->name(), tp->raw_data().data(), tp->raw_data().size())); + } else if (graph_body_viewer.GetOrtValueInitializer(tp->name(), initializer_value)) { + // the initializer was marked 
as external data by the ORT graph at load time since it was provided in memory + size_t size = 0; + const void* ptr = nullptr; + c_api.GetTensorSizeInBytes(&initializer_value, &size); + c_api.GetTensorData(&initializer_value, &ptr); + userWeights.emplace_back(tp->name(), ptr, size); + } else if (utils::HasExternalDataInMemory(*tp)) { + // only copy and take ownership of the data if none of the above conditions are met + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights.emplace_back(TensorrtUserWeights(std::move(full_init->name()), std::move(full_init->raw_data()))); + } + } + } + // ORT's default topological sort is using reversed DFS. // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating // the model proto that has different node ordering compared to original onnx model. - graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !use_external_data_initializer_ /*include raw initializers*/); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; model_proto->SerializeToString(string_buf); @@ -2165,7 +2395,21 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); - trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); + + if (use_external_data_initializer_) { +#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 1 + trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); + for (auto const& userWeight : userWeights) { + trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size()); + } + trt_parser->parseModelProto(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "'nv_use_external_data_initializer' is only supported on TensorRT RTX 1.1.x.x and above."); +#endif + } else { + trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); + } + if (max_workspace_size_ > 0) { trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); } @@ -2329,7 +2573,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr ; } } - std::string trt_node_name_with_precision = fused_node.Name() + "_strong_typed"; // enable sparse weights if (sparsity_enable_) { @@ -2358,32 +2601,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr std::unique_ptr trt_engine; std::unique_ptr trt_context; - std::string cache_path = ""; - std::string cache_suffix = ""; - // Customize cache prefix if assigned - if (!cache_prefix_.empty()) { - // Generate cache suffix in case user would like to customize cache prefix - cache_suffix = "_" + GetCacheSuffix(fused_node.Name(), trt_node_name_with_precision); - cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix; - } else { - cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); - } - - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible 
cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - const std::string cache_path_prefix = cache_path; - std::string engine_cache_path = cache_path_prefix + ".engine"; - const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; - const std::string profile_cache_path = cache_path_prefix + ".profile"; - - // If weight-stripped engine is enabled and refitted engine cache is not present, - // TRT EP will use the engine cache with ".stripped.engine" appended to the end. - const std::filesystem::path engine_cache_fs_path = engine_cache_path; - if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { - engine_cache_path = cache_path_prefix + ".stripped.engine"; - weight_stripped_engine_refit_ = true; - } - // Generate file name for dumping ep context model if (dump_ep_context_model_ && ctx_model_path_.empty()) { ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); @@ -2398,49 +2615,63 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr std::unique_ptr serialized_engine{trt_builder->buildSerializedNetwork(*trt_network, *trt_config)}; if (serialized_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP failed to create engine from network for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP failed to create engine from network for fused node: " + fused_node.Name()); } trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); if (trt_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP failed to deserialize engine for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP failed to deserialize engine for fused node: " + fused_node.Name()); } if (detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); - LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; + LOGS_DEFAULT(INFO) << "TensorRT engine build for " << fused_node.Name() << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; } // dump EP context node model if (dump_ep_context_model_) { // "ep_cache_context" node attribute should be a relative path to context model directory - if (ep_cache_context_attr_.empty()) { - auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); - ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); + + std::string cache_path = ""; + // Customize cache prefix if assigned + if (!cache_prefix_.empty()) { + // Generate cache suffix in case user would like to customize cache prefix + cache_path = GetCachePath(cache_path_, cache_prefix_) + fused_node.Name() + ".engine"; + ; + } else { + cache_path = GetCachePath(cache_path_, fused_node.Name()) + ".engine"; + ; + } + // NV TRT EP per default generates hardware compatible engines for any RTX device with compute capability > 80 + std::string compute_capability_hw_compat = "80+"; + if (!ep_context_model_) { + ep_context_model_ = Model::Create("nv_trt_rtx_ep_context_model", false, *GetLogger()); + } + + auto status = CreateCtxNode(graph_body_viewer, + ep_context_model_->MainGraph(), + cache_path, + 
reinterpret_cast(serialized_engine->data()), + serialized_engine->size(), + ep_context_embed_mode_, + compute_capability_hw_compat, + model_path_, + fused_node.Name(), + trt_version_); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } - std::string compute_capability_hw_compat = compute_capability_ + "+"; - std::unique_ptr model_proto{CreateCtxModel(graph_body_viewer, - ep_cache_context_attr_, - reinterpret_cast(serialized_engine->data()), - serialized_engine->size(), - ep_context_embed_mode_, - compute_capability_hw_compat, - model_path_, - GetLogger())}; - DumpCtxModel(model_proto.get(), ctx_model_path_); } } if (weight_stripped_engine_refit_) { LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Refit engine from main ONNX file after engine build"; - char* onnx = string_buf.data(); - size_t onnx_size = string_buf.size(); auto status = RefitEngine(model_path_, onnx_model_folder_path_, - engine_cache_path, false /* path check for security */, - onnx, - onnx_size, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, trt_engine.get(), - false /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -2453,7 +2684,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not build execution context for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); } bool is_dynamic_shape_context = false; @@ -2499,12 +2730,12 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, trt_node_name_with_precision, + input_shape_ranges_[context->node_name], &tensorrt_mu_, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], engine_decryption_enable_, engine_decryption_, engine_encryption_, detailed_build_log_, sparsity_enable_, - auxiliary_streams_, cuda_graph_enable_, is_dynamic_shape_context, cache_prefix_, cache_suffix}; + auxiliary_streams_, cuda_graph_enable_, is_dynamic_shape_context, cache_prefix_}; *state = p.release(); return 0; }; @@ -2552,7 +2783,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr if (multi_profile_enable_ == true) { if (!trt_context->setOptimizationProfileAsync(nv_profile_index_, stream)) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP select an optimization profile for the current context failed"); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP select an optimization profile for the current context failed"); } // Check before using trt_engine @@ -2650,7 +2881,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } if (trt_state->context_memory_size != mem_size) { LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] A new context memory was allocated with 
size " << mem_size; - trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, false /*use_reserve*/); + trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, true /*use_reserve*/); trt_state->context_memory_size = mem_size; trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size); } @@ -2666,7 +2897,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // Run TRT inference if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP execution context enqueue failed."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); } /* @@ -2743,6 +2974,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer, + size_t node_idx, const Node& fused_node, std::unordered_map& input_map, std::unordered_map& output_map, @@ -2762,8 +2994,10 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra onnx_model_folder_path_, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, detailed_build_log_); - auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); + auto status = trt_cache_model_handler.GetEpContextFromGraph(*graph_body_viewer.GetNode(node_idx)); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } @@ -2775,7 +3009,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "Nv EP could not build execution context for fused node: " + fused_node.Name()); + "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); } bool is_dynamic_shape_context = false; @@ -2963,7 +3197,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra } if (trt_state->context_memory_size != mem_size) { LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] A new context memory was allocated with size " << mem_size; - trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, false /*use_reserve*/); + trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, true /*use_reserve*/); // trt_state->context_memory = IAllocator::MakeUniquePtr(alloc, mem_size, false /*use_reserve*/, stream); trt_state->context_memory_size = mem_size; trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size); @@ -2980,7 +3214,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra // Run TRT inference if (!trt_context->enqueueV3(stream)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Nv EP execution context enqueue failed."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "NvTensorRTRTX EP execution context enqueue failed."); } /* diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 83b89a2e9d1fb..e3dd38eb837ff 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -153,6 +153,41 @@ struct TensorParams 
{ } }; +// Data structure to hold user weights when ModelProtos are serialized with external data. +// It owns a copy of the weight bytes when constructed from a string; otherwise it only references caller-owned memory. +class TensorrtUserWeights { + public: + TensorrtUserWeights(const std::string& name, const std::string& data) : name_(name), + data_cpy_(data) { + } + + TensorrtUserWeights(const std::string& name, const void* data, size_t size) : name_(name), data_(data), size_(size) { + } + + const char* Name() const { + return name_.c_str(); + } + + const void* Data() const { + if (!data_cpy_.empty()) { + return data_cpy_.data(); + } + return data_; + } + + int64_t Size() const { + if (!data_cpy_.empty()) { + return static_cast(data_cpy_.size()); + } + return static_cast(size_); + } + + private: + std::string name_{}; + std::string data_cpy_{}; + void const* data_ = nullptr; + size_t size_ = 0; +}; + // Information to construct kernel function state. struct TensorrtFuncState { AllocateFunc test_allocate_func = nullptr; @@ -168,7 +203,6 @@ struct TensorrtFuncState { std::vector> output_info; std::unordered_map>>> input_shape_ranges; std::mutex* tensorrt_mu_ptr = nullptr; - std::string trt_node_name_with_precision; bool engine_cache_enable = false; std::string engine_cache_path; nvinfer1::IRuntime* runtime = nullptr; @@ -183,6 +217,7 @@ struct TensorrtFuncState { bool is_dynamic_shape = false; std::string cache_prefix; std::string cache_suffix; + // runtime parameters std::vector> scratch_buffers; std::vector input_tensors; std::vector output_tensors; @@ -204,6 +239,7 @@ struct TensorrtShortFuncState { std::vector> output_info; std::mutex* tensorrt_mu_ptr = nullptr; bool is_dynamic_shape = false; + // runtime parameters std::vector> scratch_buffers; std::vector input_tensors; std::vector output_tensors; @@ -275,14 +311,16 @@ class NvExecutionProvider : public IExecutionProvider { static common::Status RefitEngine(std::string onnx_model_filename, std::string& onnx_model_folder_path, - std::string& weight_stripped_engine_cath_path, bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, nvinfer1::ICudaEngine* trt_engine, - bool serialize_refitted_engine, bool detailed_build_log); + const InlinedVector GetEpContextNodes() const override; + private: mutable NvExecutionProviderInfo info_; bool external_stream_ = false; @@ -299,6 +337,9 @@ class NvExecutionProvider : public IExecutionProvider { std::string onnx_model_folder_path_; const void* onnx_model_bytestream_; size_t onnx_model_bytestream_size_; + bool use_external_data_initializer_ = false; + const void* onnx_external_data_bytestream_ = nullptr; + size_t onnx_external_data_bytestream_size_ = 0; bool sparsity_enable_ = false; int auxiliary_streams_ = -1; std::string cache_path_, engine_decryption_lib_path_; @@ -317,6 +358,7 @@ class NvExecutionProvider : public IExecutionProvider { std::string cache_prefix_; std::string op_types_to_exclude_; int nv_profile_index_ = 0; + std::unique_ptr ep_context_model_; // The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH int32_t trt_version_; @@ -331,7 +373,6 @@ class NvExecutionProvider : public IExecutionProvider { std::string ep_context_file_path_; int ep_context_embed_mode_ = 0; std::string ctx_model_path_; - std::string ep_cache_context_attr_; std::string engine_cache_relative_path_to_context_model_dir; std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; @@ -550,6 +591,7 @@ class NvExecutionProvider : public IExecutionProvider { * going through the time-consuming
processes of model parsing and engine building. */ Status CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer, + size_t node_idx, const Node& fused_node, std::unordered_map& input_map, std::unordered_map& output_map, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index f90bf24ef4975..527a37f6c2b57 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -17,6 +17,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi NvExecutionProviderInfo info{}; void* user_compute_stream = nullptr; void* onnx_bytestream = nullptr; + void* external_data_bytestream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -48,21 +49,14 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi .AddAssignmentToReference(nv::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes) .AddAssignmentToReference(nv::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) .AddAssignmentToReference(nv::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) + .AddAssignmentToReference(nv::provider_option_names::kUseExternalDataInitializer, info.use_external_data_initializer) .AddAssignmentToReference(nv::provider_option_names::kMultiProfileEnable, info.multi_profile_enable) - .AddValueParser( - nv::provider_option_names::kONNXBytestream, - [&onnx_bytestream](const std::string& value_str) -> Status { - size_t address; - ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); - onnx_bytestream = reinterpret_cast(address); - return Status::OK(); - }) - .AddAssignmentToReference(nv::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size) .Parse(options)); // add new provider option here. 
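+ // For illustration only: a minimal sketch of how an application might opt in to the new
+ // external-data-initializer path from the API side. The EP registration name ("NvTensorRtRtx")
+ // and the option key string ("nv_use_external_data_initializer") are assumptions for this
+ // sketch and are not defined by this change:
+ //
+ //   Ort::SessionOptions so;
+ //   so.AppendExecutionProvider("NvTensorRtRtx",
+ //                              std::unordered_map<std::string, std::string>{
+ //                                  {"nv_use_external_data_initializer", "1"}});
+ //
+ // Whatever string the user passes is parsed above via kUseExternalDataInitializer and lands in
+ // info.use_external_data_initializer.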
info.user_compute_stream = user_compute_stream; info.has_user_compute_stream = (user_compute_stream != nullptr); info.onnx_bytestream = onnx_bytestream; + info.external_data_bytestream = external_data_bytestream; // EP context settings // when EP context is enabled, default is to embed the engine in the context model @@ -73,7 +67,8 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi info.dump_ep_context_model = false; } else if (ep_context_enable == "1") { info.dump_ep_context_model = true; - info.weight_stripped_engine_enable = true; + // We want to re-enable weightless engines as soon as constant initializers are supported as inputs + info.weight_stripped_engine_enable = false; } else { ORT_THROW("Invalid ", kOrtSessionOptionEpContextEnable, " must 0 or 1"); } @@ -110,9 +105,7 @@ ProviderOptions NvExecutionProviderInfo::ToProviderOptions(const NvExecutionProv {nv::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, {nv::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, {nv::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, - {nv::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)}, - {nv::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)}, - }; + {nv::provider_option_names::kUseExternalDataInitializer, MakeStringWithClassicLocale(info.use_external_data_initializer)}}; return options; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h index 4d6c6fe116076..b826925361b05 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h @@ -31,6 +31,9 @@ struct NvExecutionProviderInfo { std::string onnx_model_folder_path{""}; const void* onnx_bytestream{nullptr}; size_t onnx_bytestream_size{0}; + bool use_external_data_initializer{false}; + const void* external_data_bytestream{nullptr}; + size_t external_data_bytestream_size{0}; bool engine_decryption_enable{false}; std::string engine_decryption_lib_path{""}; bool force_sequential_engine_build{false}; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h index ea586ba445ba2..c564fe65c3d5c 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h @@ -386,22 +386,11 @@ std::string GetCachePath(const std::string& root, const std::string& name) { * Get compute capability * */ -std::string GetComputeCapacity(const cudaDeviceProp& prop) { +std::string GetComputeCapability(const cudaDeviceProp& prop) { const std::string compute_capability = std::to_string(prop.major * 10 + prop.minor); return compute_capability; } -/* - * Get Timing by compute capability - * - */ -std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) { - // append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache - const std::string timing_cache_name = "NvExecutionProvider_cache_sm" + - compute_cap + ".timing"; - return GetCachePath(root, timing_cache_name); -} - /* * Get cache by type * diff --git
a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 21d964b0c341f..1f34a0f25877d 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -20,10 +20,11 @@ extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); * * Note: Please see more details about "EPContext" contrib op in contrib_defs.cc */ -bool GraphHasCtxNode(const GraphViewer& graph_viewer) { +bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx) { for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) { auto node = graph_viewer.GetNode(i); if (node != nullptr && node->OpType() == EPCONTEXT_OP) { + node_idx = i; return true; } } @@ -63,19 +64,18 @@ void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, } /* - * Create "EP context node" model where engine information is embedded + * Create EP context node where engine information is embedded */ -ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, - const std::string engine_cache_path, - char* engine_data, - size_t size, - const int64_t embed_mode, - const std::string compute_capability, - const std::string onnx_model_path, - const logging::Logger* logger) { - auto model_build = graph_viewer.CreateModel(*logger); - auto& graph_build = model_build->MainGraph(); - +Status CreateCtxNode(const GraphViewer& graph_viewer, + Graph& graph_build, + const std::string engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + const std::string compute_capability, + const std::string onnx_model_path, + const std::string& ep_context_node_name, + int32_t trt_version) { // Get graph inputs and outputs std::vector inputs, outputs; for (auto input : graph_viewer.GetInputs()) { @@ -89,55 +89,71 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, } // Create EP context node attributes - auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode - auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); // ep_cache_context - auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); // hardware_architecture - auto attr_3 = ONNX_NAMESPACE::AttributeProto::Create(); // onnx_model_filename + auto attr_embed_mode = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_main_context = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_ep_cache_context = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_sdk_version = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_hw_architecture = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_onnx_filename = ONNX_NAMESPACE::AttributeProto::Create(); + auto attr_partition_name = ONNX_NAMESPACE::AttributeProto::Create(); std::string engine_data_str = ""; - attr_0->set_name(EMBED_MODE); - attr_0->set_type(onnx::AttributeProto_AttributeType_INT); - attr_0->set_i(embed_mode); - attr_1->set_name(EP_CACHE_CONTEXT); - attr_1->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_main_context->set_name(MAIN_CONTEXT); + attr_main_context->set_type(onnx::AttributeProto_AttributeType_INT); + attr_main_context->set_i(0); // we do not support a main context node but each has it's own engine payload + attr_embed_mode->set_name(EMBED_MODE); + attr_embed_mode->set_type(onnx::AttributeProto_AttributeType_INT); + attr_embed_mode->set_i(embed_mode); + attr_ep_cache_context->set_name(EP_CACHE_CONTEXT); + 
attr_ep_cache_context->set_type(onnx::AttributeProto_AttributeType_STRING); if (embed_mode) { if (size > 0) { engine_data_str.assign(engine_data, size); } - attr_1->set_s(engine_data_str); - // TODO(maximilianm) we might want to disable this warning as we only support weightless engines that are really small - // the reason we had this was that the field will be hashed and storing a large bytestream has significant overhead - LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; + attr_ep_cache_context->set_s(engine_data_str); } else { - attr_1->set_s(engine_cache_path); + std::string engine_cache_filename = std::filesystem::path(engine_cache_path).filename().string(); + attr_ep_cache_context->set_s(engine_cache_filename); + std::fstream engine_cache_file(engine_cache_path, std::ios::binary | std::ios::out); + if (engine_cache_file.is_open()) { + engine_cache_file.write(engine_data, size); + engine_cache_file.close(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "NvTensorRTRTX EP could not write cache to ", engine_cache_path); + } } - attr_2->set_name(COMPUTE_CAPABILITY); - attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_2->set_s(compute_capability); - attr_3->set_name(ONNX_MODEL_FILENAME); - attr_3->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_3->set_s(std::filesystem::path(onnx_model_path).filename().string()); + + attr_hw_architecture->set_name(COMPUTE_CAPABILITY); + attr_hw_architecture->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_hw_architecture->set_s(compute_capability); + + attr_partition_name->set_name(PARTITION_NAME); + attr_partition_name->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_partition_name->set_s(ep_context_node_name); // includes hash of the subgraph that was built + + attr_onnx_filename->set_name(ONNX_MODEL_FILENAME); + attr_onnx_filename->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_onnx_filename->set_s(std::filesystem::path(onnx_model_path).filename().string()); + + attr_sdk_version->set_name(SDK_VERSION); + attr_sdk_version->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_sdk_version->set_s(std::to_string(trt_version)); auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); constexpr int num_attributes = 4; node_attributes->reserve(num_attributes); - node_attributes->emplace(EMBED_MODE, *attr_0); - node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1); - node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2); - node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_3); + node_attributes->emplace(MAIN_CONTEXT, *attr_main_context); + node_attributes->emplace(EMBED_MODE, *attr_embed_mode); + node_attributes->emplace(EP_CACHE_CONTEXT, *attr_ep_cache_context); + node_attributes->emplace(COMPUTE_CAPABILITY, *attr_hw_architecture); + node_attributes->emplace(PARTITION_NAME, *attr_partition_name); + node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_onnx_filename); + node_attributes->emplace(SDK_VERSION, *attr_sdk_version); // Create EP context node - graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); + graph_build.AddNode(ep_context_node_name, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); ORT_ENFORCE(graph_build.Resolve().IsOK()); - - // Serialize modelproto to string - auto new_graph_viewer = graph_build.CreateGraphViewer(); - auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); - auto model = new_graph_viewer->CreateModel(*logger, metadata); - auto model_proto = 
model->ToProto(); - new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - return model_proto.release(); + return Status::OK(); } /* @@ -206,17 +222,6 @@ std::string GetCtxModelPath(const std::string& ep_context_file_path, return ctx_model_path; } -/* - * Dump "EP context" model - * - */ -void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, - const std::string& ctx_model_path) { - std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary); - model_proto->SerializeToOstream(dump); - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] Dumped " + ctx_model_path; -} - bool IsAbsolutePath(const std::string& path_string) { #ifdef _WIN32 onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); @@ -248,38 +253,12 @@ bool IsRelativePathToParentPath(const std::string& path_string) { #endif } -/* - * Get the weight-refitted engine cache path from a weight-stripped engine cache path - * - * Weight-stipped engine: - * An engine with weights stripped and its size is smaller than a regualr engine. - * The cache name of weight-stripped engine is NvExecutionProvider_TRTKernel_XXXXX.stripped.engine - * - * Weight-refitted engine: - * An engine that its weights have been refitted and it's simply a regular engine. - * The cache name of weight-refitted engine is NvExecutionProvider_TRTKernel_XXXXX.engine - */ -std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) { - std::filesystem::path stripped_engine_cache_path(stripped_engine_cache); - std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine"; - return refitted_engine_cache_path; -} - -bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) { - // The weight-stripped engine cache has the naming of xxx.stripped.engine - return engine_cache_path.stem().extension().string() == ".stripped"; -} - -Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) { - if (!ValidateEPCtxNode(graph_viewer)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node"); - } - auto node = graph_viewer.GetNode(0); - auto& attrs = node->GetAttributes(); +Status TensorRTCacheModelHandler::GetEpContextFromGraph(const Node& node) { + auto& attrs = node.GetAttributes(); const int64_t embed_mode = attrs.at(EMBED_MODE).i(); // Only make path checks if model not provided as byte buffer - bool make_secure_path_checks = !GetModelPath(graph_viewer).empty(); + bool make_secure_path_checks = ep_context_model_path_.empty(); if (embed_mode) { // Get engine from byte stream. 
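For orientation, here is a minimal sketch of the round trip these helpers support: compile once while dumping an EP context model, then reload the dumped model so GetEpContextFromGraph() can pull the engine either straight from the EPContext node (embed mode 1) or from the engine cache file referenced by "ep_cache_context" relative to the context model directory (embed mode 0). The session option keys below are the standard EP context keys; how the NvTensorRTRTX EP itself is appended to the session options is elided and left as an assumption here.

#include <onnxruntime_cxx_api.h>

Ort::Env env;

// 1) Compile once and dump the EP context model.
Ort::SessionOptions compile_opts;
compile_opts.AddConfigEntry("ep.context_enable", "1");      // turns on dump_ep_context_model_
compile_opts.AddConfigEntry("ep.context_embed_mode", "0");  // write the engine as a separate cache file
compile_opts.AddConfigEntry("ep.context_file_path", "model_ctx.onnx");
// ... append the NvTensorRTRTX EP to compile_opts here ...
Ort::Session compile_session(env, ORT_TSTR("model.onnx"), compile_opts);

// 2) Later runs load the dumped context model directly; no engine build is needed.
Ort::SessionOptions run_opts;
// ... append the NvTensorRTRTX EP to run_opts here ...
Ort::Session run_session(env, ORT_TSTR("model_ctx.onnx"), run_opts);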
@@ -294,15 +273,14 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph if (weight_stripped_engine_refit_) { const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s(); - std::string placeholder; auto status = NvExecutionProvider::RefitEngine(onnx_model_filename, onnx_model_folder_path_, - placeholder, make_secure_path_checks, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, (*trt_engine_).get(), - false /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -327,21 +305,6 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph auto engine_cache_path = ctx_model_dir.append(cache_path); LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); - // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled - if (!weight_stripped_engine_refit_) { - weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path); - } - - // If the serialized refitted engine is present, use it directly without refitting the engine again - if (weight_stripped_engine_refit_) { - const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); - if (std::filesystem::exists(refitted_engine_cache_path)) { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] " + refitted_engine_cache_path.string() + " exists."; - engine_cache_path = refitted_engine_cache_path.string(); - weight_stripped_engine_refit_ = false; - } - } - if (!std::filesystem::exists(engine_cache_path)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP can't find engine cache: " + engine_cache_path.string() + @@ -366,12 +329,12 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph std::string weight_stripped_engine_cache = engine_cache_path.string(); auto status = NvExecutionProvider::RefitEngine(onnx_model_filename, onnx_model_folder_path_, - weight_stripped_engine_cache, make_secure_path_checks, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, (*trt_engine_).get(), - true /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -384,11 +347,8 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph /* * The sanity check for EP context contrib op. 
*/ -bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) { - assert(graph_viewer.NumberOfNodes() == 1); - assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP); - auto node = graph_viewer.GetNode(0); - auto& attrs = node->GetAttributes(); +bool TensorRTCacheModelHandler::ValidateEPCtxNode(const Node& node) { + auto& attrs = node.GetAttributes(); // Show the warning if compute capability is not matched if (attrs.count(COMPUTE_CAPABILITY) > 0) { @@ -413,7 +373,7 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe const int64_t embed_mode = attrs.at(EMBED_MODE).i(); if (embed_mode == 1) { // engine binary data - LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; + // LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; } return true; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h index f0a05c42414e5..7c52f26cc9177 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "core/providers/nv_tensorrt_rtx/nv_includes.h" #include "core/providers/shared_library/provider_api.h" @@ -14,33 +15,32 @@ namespace onnxruntime { static const std::string EPCONTEXT_OP = "EPContext"; +static const std::string MAIN_CONTEXT = "main_context"; static const std::string EMBED_MODE = "embed_mode"; static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; static const std::string COMPUTE_CAPABILITY = "hardware_architecture"; static const std::string ONNX_MODEL_FILENAME = "onnx_model_filename"; +static const std::string PARTITION_NAME = "partition_name"; +static const std::string SDK_VERSION = "ep_sdk_version"; static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; -static const std::string EPCONTEXT_WARNING = - "It's suggested to set the ORT graph optimization level to 0 and \ - make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\ - for the best model loading time"; -bool GraphHasCtxNode(const GraphViewer& graph_viewer); +bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx); const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer); std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); -ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, - const std::string engine_cache_path, - char* engine_data, - size_t size, - const int64_t embed_mode, - const std::string compute_capability, - const std::string onnx_model_path, - const logging::Logger* logger); +Status CreateCtxNode(const GraphViewer& graph_viewer, + Graph& graph_build, + const std::string engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + const std::string compute_capability, + const std::string onnx_model_path, + const std::string& ep_context_node_name, + int trt_version); std::string GetCtxModelPath(const std::string& ep_context_file_path, const std::string& original_model_path); bool IsAbsolutePath(const std::string& path_string); bool IsRelativePathToParentPath(const std::string& path_string); -void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, - const std::string& ctx_model_path); void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, char* engine_data, size_t size); @@ -55,6 +55,8 @@ class TensorRTCacheModelHandler { std::string onnx_model_folder_path, const void* 
onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, bool detailed_build_log) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), @@ -64,13 +66,15 @@ class TensorRTCacheModelHandler { onnx_model_folder_path_(onnx_model_folder_path), onnx_model_bytestream_(onnx_model_bytestream), onnx_model_bytestream_size_(onnx_model_bytestream_size), + onnx_external_data_bytestream_(onnx_external_data_bytestream), + onnx_external_data_bytestream_size_(onnx_external_data_bytestream_size), detailed_build_log_(detailed_build_log) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler); - bool ValidateEPCtxNode(const GraphViewer& graph_viewer); + bool ValidateEPCtxNode(const Node& node); - Status GetEpContextFromGraph(const GraphViewer& graph_viewer); + Status GetEpContextFromGraph(const Node& node); private: std::unique_ptr* trt_engine_; @@ -81,6 +85,8 @@ class TensorRTCacheModelHandler { std::string onnx_model_folder_path_; const void* onnx_model_bytestream_; size_t onnx_model_bytestream_size_; + const void* onnx_external_data_bytestream_; + size_t onnx_external_data_bytestream_size_; bool detailed_build_log_; }; // TRTCacheModelHandler } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 0152ad27c0ba2..d99e322641199 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -273,11 +273,10 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, // Check if we need to add a cast node for int64 bool needs_int64_cast = false; if (is_graph_output) { - for (const auto& input_name : input_names) { - if (input_name.find("_cast_int32") != std::string::npos) { - needs_int64_cast = true; - break; - } + if (supported_qnn_data_type == output_info.qnn_data_type && + (output_info.qnn_data_type == QNN_DATATYPE_INT_64 || output_info.qnn_data_type == QNN_DATATYPE_UINT_64)) { + supported_qnn_data_type = supported_qnn_data_type == QNN_DATATYPE_INT_64 ? 
QNN_DATATYPE_INT_32 : QNN_DATATYPE_UINT_32; + needs_int64_cast = true; } } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 21947a22e2b92..78b16ed784049 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -93,15 +93,16 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } -static std::vector AmendOutputShapeForRank3Pool( +Status AmendOutputShapeForRank3Pool( gsl::span input_shape, // {N, H, W, C} gsl::span kernel_shape, // {k_h, k_w} gsl::span strides, // {s_h, s_w} - gsl::span pads) { - assert(input_shape.size() == 4 && - kernel_shape.size() == 2 && - strides.size() == 2 && - pads.size() == 4); + gsl::span pads, + std::vector& output_shape) { + ORT_RETURN_IF_NOT(input_shape.size() == 4, "Expecting input rank 4 for amending 1D Pool output shape."); + ORT_RETURN_IF_NOT(kernel_shape.size() == 2, "Expecting kernel size 2 for amending 1D Pool output shape."); + ORT_RETURN_IF_NOT(strides.size() == 2, "Expecting strides size 2 for amending 1D Pool output shape."); + ORT_RETURN_IF_NOT(pads.size() == 4, "Expecting pad size 4 for amending 1D Pool output shape."); const uint32_t N = input_shape[0]; const uint32_t H = input_shape[1]; @@ -120,7 +121,13 @@ static std::vector AmendOutputShapeForRank3Pool( ? 0 : (padded_W - kernel_shape[1]) / strides[1] + 1; - return {N, out_H, out_W, C}; + output_shape.resize(4); + output_shape[0] = N; + output_shape[1] = out_H; + output_shape[2] = out_W; + output_shape[3] = C; + + return Status::OK(); } Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, @@ -177,10 +184,7 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, if (auto_pad.compare("NOTSET") != 0) { if (output_shape.size() == 3) { // Calculate rank-4 output shape for rank-3 input. - output_shape = AmendOutputShapeForRank3Pool(input_shape, - filter_size, - stride, - pad_amount); + ORT_RETURN_IF_ERROR(AmendOutputShapeForRank3Pool(input_shape, filter_size, stride, pad_amount, output_shape)); } for (size_t axis = 0; axis < rank - 2; ++axis) { @@ -365,14 +369,6 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::move(output_shape))); } - // Calculate rank-4 output shape for rank-3 input. - std::vector onnx_in_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, onnx_in_shape), "Cannot get shape"); - if (onnx_in_shape.size() == 3) { - onnx_in_shape = {onnx_in_shape[0], 1, onnx_in_shape[1], onnx_in_shape[2]}; - } - auto pooled_shape = AmendOutputShapeForRank3Pool(onnx_in_shape, filter_size, stride, pad_amount); - // Construct param wrappers. ORT_RETURN_IF_NOT(SetPoolParam(node_unit, param_filter_size, @@ -443,6 +439,16 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra return Status::OK(); } + + // Calculate rank-4 output shape for rank-3 input. 
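+  // Worked example of the shape amendment above (assuming the usual ONNX pad layout
+  // {pad_H_begin, pad_W_begin, pad_H_end, pad_W_end}; the exact layout is not shown in this hunk):
+  // a rank-3 input {1, 128, 3} is treated as NHWC {1, 1, 128, 3}. With kernel {1, 4},
+  // strides {1, 2} and pads {0, 1, 0, 1}:
+  //   padded_W = 128 + 1 + 1 = 130
+  //   out_W    = (130 - 4) / 2 + 1 = 64
+  //   out_H    = (1 - 1) / 1 + 1 = 1
+  // giving an amended output shape of {1, 1, 64, 3}.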
+ std::vector onnx_in_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, onnx_in_shape), "Cannot get shape"); + if (onnx_in_shape.size() == 3) { + onnx_in_shape = {onnx_in_shape[0], 1, onnx_in_shape[1], onnx_in_shape[2]}; + } + std::vector pooled_shape; + ORT_RETURN_IF_ERROR(AmendOutputShapeForRank3Pool(onnx_in_shape, filter_size, stride, pad_amount, pooled_shape)); + const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); const std::string pool_out = real_out + "_reshape_after"; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 3dc103046424e..5bcb8ca394346 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -787,10 +787,12 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord std::vector context_params_list; std::vector context_paramsv1_list; - std::vector context_params_ptr_list(context_bin_map.size() + 1); + std::vector context_params_ptr_list; std::vector> buffer_list; - size_t idx = 0; + context_params_list.reserve(context_bin_map.size()); + context_params_ptr_list.reserve(context_bin_map.size() + 1); + for (auto& it : context_bin_map) { auto context_bin_filepath = it.first; @@ -821,9 +823,9 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord buffer_list.push_back(std::move(buffer)); context_params_list.push_back(std::move(context_params)); context_paramsv1_list.push_back(std::move(context_params_v1)); - context_params_ptr_list[idx++] = &context_params_list.back(); + context_params_ptr_list.push_back(&context_params_list.back()); } - context_params_ptr_list[idx] = nullptr; + context_params_ptr_list.push_back(nullptr); auto result = qnn_interface_.contextCreateFromBinaryListAsync(backend_handle_, device_handle_, context_params_ptr_list.data(), @@ -1178,6 +1180,14 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, #if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) if (vtcm_backup_buffer_sharing_enabled_) { + // If a context bin filepath has not been processed yet, + // then a new context must be created for the set of context bins + auto first_mapping_it = ep_context_handle_map_.find(context_bin_map.begin()->first); + if (first_mapping_it == ep_context_handle_map_.end()) { + LOGS(logger, VERBOSE) << "Creating context for new set of context binaries"; + return CreateContextVtcmBackupBufferSharingEnabled(context_bin_map); + } + LOGS(logger, VERBOSE) << "Mapping contexts to new EP main context nodes"; for (auto& it : context_bin_map) { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 5c9c1a0ae163f..9a0bcb53c9ad7 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -1011,6 +1011,8 @@ struct ProviderHost { virtual void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) = 0; // We pass OrtValue by reference here (as opposed to the original Graph function) to avoid header inclusion virtual Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& value) = 0; + virtual bool Graph__GetOrtValueInitializer(const Graph* p, const std::string& tensor_name, OrtValue& value, + bool 
check_outer_scope) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, NodeAttributes&& attributes, const std::string& domain) = 0; virtual Node& Graph__AddNode(Graph* p, const Node& other) = 0; @@ -1074,6 +1076,8 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::TensorProto* GraphViewer__GetConstantInitializer(const GraphViewer* p, const std::string& name, bool check_outer_scope) const = 0; + virtual bool GraphViewer__GetOrtValueInitializer(const GraphViewer* p, const std::string& tensor_name, + OrtValue& value) = 0; virtual const Node* GraphViewer__ParentNode(const GraphViewer* p) = 0; virtual int GraphViewer__NumberOfNodes(const GraphViewer* p) noexcept = 0; virtual int GraphViewer__MaxNodeIndex(const GraphViewer* p) noexcept = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 23fbead1e9707..19b4636c3766d 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1041,6 +1041,10 @@ struct Graph final { Status AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& ort_value) { return g_host->Graph__AddInitializedOrtValue(this, tensor, ort_value); } + bool GetOrtValueInitializer(const std::string& tensor_name, OrtValue& ort_value, + bool check_outer_scope = false) const { + return g_host->Graph__GetOrtValueInitializer(this, tensor_name, ort_value, check_outer_scope); + } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, const NodeAttributes* attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, attributes, domain); } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, NodeAttributes&& attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, std::move(attributes), domain); } Node& AddNode(const Node& other) { return g_host->Graph__AddNode(this, other); } @@ -1124,6 +1128,9 @@ class GraphViewer final { bool check_outer_scope = true) const { return g_host->GraphViewer__GetConstantInitializer(this, name, check_outer_scope); } + bool GetOrtValueInitializer(const std::string& tensor_name, OrtValue& ort_value) const { + return g_host->GraphViewer__GetOrtValueInitializer(this, tensor_name, ort_value); + } const Node* ParentNode() const { return g_host->GraphViewer__ParentNode(this); } int NumberOfNodes() const noexcept { return g_host->GraphViewer__NumberOfNodes(this); } diff --git a/onnxruntime/core/session/abi_key_value_pairs.h b/onnxruntime/core/session/abi_key_value_pairs.h index 7d739439b7a27..72d5007b84e6f 100644 --- a/onnxruntime/core/session/abi_key_value_pairs.h +++ b/onnxruntime/core/session/abi_key_value_pairs.h @@ -18,11 +18,11 @@ struct OrtKeyValuePairs { CopyFromMap(other.entries_); } - OrtKeyValuePairs(OrtKeyValuePairs&& other) : 
OrtKeyValuePairs{} { + OrtKeyValuePairs(OrtKeyValuePairs&& other) noexcept : OrtKeyValuePairs{} { swap(*this, other); } - OrtKeyValuePairs& operator=(OrtKeyValuePairs other) { // handles copy and move assignment + OrtKeyValuePairs& operator=(OrtKeyValuePairs other) noexcept { // handles copy and move assignment swap(*this, other); return *this; } diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index dfb2e33f8cb32..39b785c327d56 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -182,11 +182,6 @@ Status Environment::UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool shared_ort_allocators_.erase(it2); } - // also remove an arena wrapped allocator from an EP if the user called CreateSharedAllocator to create one - if (auto it3 = arena_ort_allocators_.find(&mem_info); it3 != arena_ort_allocators_.end()) { - arena_ort_allocators_.erase(it3); - } - if (found_shared_allocator) { shared_allocators_.erase(it); } @@ -436,6 +431,10 @@ Environment::~Environment() { // instance and will call Release on it. If the plugin EP has been freed the Release will fail. shared_allocators_.clear(); + // and as any OrtAllocator instances in shared_ort_allocators_ were owned by values in shared_allocators_ and have + // now been released we need to clear that too before calling UnregisterExecutionProviderLibrary(). + shared_ort_allocators_.clear(); + #if !defined(ORT_MINIMAL_BUILD) // unregister any remaining EP libraries so they're cleaned up in a determistic way. while (!ep_libraries_.empty()) { @@ -673,11 +672,6 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, shared_ort_allocators_.erase(it); } - // if a previous call created an arena wrapped allocator for the EP's memory_info we also need to remove that - if (auto it = arena_ort_allocators_.find(&memory_info); it != arena_ort_allocators_.end()) { - arena_ort_allocators_.erase(it); - } - // we only want one shared allocator for an OrtDevice in the shared_allocators_ so that it's deterministic which // one will be used for an inference session. ignore the name so that is the case. 
if (auto it = FindExistingAllocator(shared_allocators_, memory_info, /*match_name*/ false); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 88d84e95b406c..1f491bc788870 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1094,7 +1094,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorMutableData, _Inout_ OrtValue* value, _Out API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetTensorData, _Inout_ const OrtValue* value, _Outptr_ const void** output) { +ORT_API_STATUS_IMPL(OrtApis::GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** output) { TENSOR_READ_API_BEGIN *output = tensor.DataRaw(); return nullptr; @@ -2626,6 +2626,16 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out) { + API_IMPL_BEGIN + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); + } + *out = reinterpret_cast(graph->GetModelMetadata().release()); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path) { API_IMPL_BEGIN if (model_path == nullptr) { @@ -2761,7 +2771,8 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetNodes, _In_ const OrtGraph* graph, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_ const OrtNode** node) { +ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, + _Outptr_result_maybenull_ const OrtNode** node) { API_IMPL_BEGIN if (node == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'node' argument is NULL"); @@ -3209,7 +3220,7 @@ ORT_API(void, OrtApis::GetKeyValuePairs, _In_ const OrtKeyValuePairs* kvps, *num_entries = kvps->Entries().size(); } -ORT_API(void, OrtApis::RemoveKeyValuePair, _Frees_ptr_opt_ OrtKeyValuePairs* kvps, _In_ const char* key) { +ORT_API(void, OrtApis::RemoveKeyValuePair, _In_ OrtKeyValuePairs* kvps, _In_ const char* key) { kvps->Remove(key); } @@ -3218,7 +3229,7 @@ ORT_API(void, OrtApis::ReleaseKeyValuePairs, _Frees_ptr_opt_ OrtKeyValuePairs* k } #if !defined(ORT_MINIMAL_BUILD) -ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name, +ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name, const ORTCHAR_T* path) { API_IMPL_BEGIN ORT_API_RETURN_IF_STATUS_NOT_OK(env->GetEnvironment().RegisterExecutionProviderLibrary(registration_name, path)); @@ -3226,7 +3237,7 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* env, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name) { +ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name) { API_IMPL_BEGIN ORT_API_RETURN_IF_STATUS_NOT_OK(env->GetEnvironment().UnregisterExecutionProviderLibrary(registration_name)); return nullptr; @@ -3413,7 +3424,7 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env, } #else // defined(ORT_MINIMAL_BUILD) -ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, const char* /*registration_name*/, +ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, 
_In_ const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { API_IMPL_BEGIN return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); @@ -3421,7 +3432,7 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*en } ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, - const char* /*registration_name*/) { + _In_ const char* /*registration_name*/) { API_IMPL_BEGIN return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); API_IMPL_END @@ -4095,6 +4106,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ReleaseSyncStream, &OrtApis::CopyTensors, + + &OrtApis::Graph_GetModelMetadata, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 3eee174ff81f4..b3b0036c68247 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -562,7 +562,7 @@ ORT_API(void, GetKeyValuePairs, _In_ const OrtKeyValuePairs* kvps, ORT_API(void, RemoveKeyValuePair, _In_ OrtKeyValuePairs* kvps, _In_ const char* key); ORT_API(void, ReleaseKeyValuePairs, _Frees_ptr_opt_ OrtKeyValuePairs*); -ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name, +ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name, const ORTCHAR_T* path); ORT_API_STATUS_IMPL(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name); @@ -635,6 +635,7 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); +ORT_API_STATUS_IMPL(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out); ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); @@ -652,7 +653,7 @@ ORT_API_STATUS_IMPL(Graph_GetInitializers, _In_ const OrtGraph* graph, _Out_writes_(num_initializers) const OrtValueInfo** initializers, _In_ size_t num_initializers); ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* num_nodes); -ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, +ORT_API_STATUS_IMPL(Graph_GetNodes, _In_ const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 3610b0f797a46..f3e30caf07e81 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -23,6 +23,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; OrtEpFactory::ReleaseEp = Forward::ReleaseEp; + 
OrtEpFactory::ValidateCompiledModelCompatibilityInfo = Forward::ValidateCompiledModelCompatibilityInfo; OrtEpFactory::CreateAllocator = Forward::CreateAllocator; OrtEpFactory::ReleaseAllocator = Forward::ReleaseAllocator; OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer; diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 0e34fef0ff74c..23e5e95af2903 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -80,6 +80,11 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); } + OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + return impl_->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index bd0b76b21511f..6c55730d83979 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -62,6 +62,14 @@ class EpFactoryInternalImpl { return false; } + virtual OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + ORT_UNUSED_PARAMETER(compatibility_info); + // Default implementation: mark as not applicable + *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return nullptr; + } + virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, _In_opt_ const OrtKeyValuePairs* /*stream_options*/, _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 2aac1e1c21cc7..3bfca62a4d011 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -644,4 +644,35 @@ void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistr registry.RegisterWaitFn(device_type, OrtDevice::CPU, plugin_ep::Notification::WaitNotificationOnHost); } } + +std::string PluginExecutionProvider::GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const { + if (ort_ep_->GetCompiledModelCompatibilityInfo == nullptr) { + // Plugin EP did not provide an implementation of this function, so we call a default implementation. + return Base::GetCompiledModelCompatibilityInfo(graph_viewer); + } + std::unique_ptr ep_graph = nullptr; + auto ort_status = EpGraph::Create(graph_viewer, ep_graph); + if (!ort_status.IsOK()) { + LOGS(*GetLogger(), ERROR) << "Failed to create EpGraph: " << ort_status.ToString(); + return {}; + } + // Call EP plugin's OrtEp::GenerateCompiledModelCompatibilityInfo() function. 
+ std::string compatibility_info_string; + compatibility_info_string = ort_ep_->GetCompiledModelCompatibilityInfo(ort_ep_.get(), ep_graph.get()); + return compatibility_info_string; +} + +Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const { + if (ep_factory_.ValidateCompiledModelCompatibilityInfo == nullptr) { + // Plugin EP did not provide an implementation of this function, so we call a default implementation. + return Base::ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + } + // Delegate to the EP factory's validation method + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory_.ValidateCompiledModelCompatibilityInfo(&ep_factory_, + compatibility_info.c_str(), + &model_compatibility))); + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 728f959ad67cb..622bbb3f97b24 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -101,6 +101,11 @@ class PluginExecutionProvider : public IExecutionProvider { // needed based on matching against allocator_mem_infos_. std::vector CreatePreferredAllocators() override; + std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override; + + Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const override; + private: struct FusedNodeState { FusedNodeState() = default; diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 67b22779395ec..29793b503c9d1 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -45,6 +45,12 @@ struct ForwardToFactoryImpl { session_options, logger, ep); } + static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfo(OrtEpFactory* this_ptr, + const char* compatibility_info, + OrtCompiledModelCompatibility* model_compatibility) noexcept { + return static_cast(this_ptr)->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + } + static OrtStatus* ORT_API_CALL CreateAllocator(_In_ OrtEpFactory* this_ptr, _In_ const OrtMemoryInfo* memory_info, _In_opt_ const OrtKeyValuePairs* allocator_options, diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index ee59ff2ab4932..41cf8be1d1412 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1258,6 +1258,10 @@ struct ProviderHostImpl : ProviderHost { void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) override { p->AddInitializedTensor(tensor); } Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& value) override { return p->AddInitializedOrtValue(tensor, value); } + bool Graph__GetOrtValueInitializer(const Graph* p, const std::string& tensor_name, OrtValue& value, + bool check_outer_scope) override { + return p->GetOrtValueInitializer(tensor_name, value, check_outer_scope); + } Node& Graph__AddNode(Graph* p, const std::string& name, const 
std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) override { return p->AddNode(name, op_type, description, input_args, output_args, attributes, domain); } @@ -1356,6 +1360,10 @@ struct ProviderHostImpl : ProviderHost { bool check_outer_scope) const override { return p->GetConstantInitializer(name, check_outer_scope); } + bool GraphViewer__GetOrtValueInitializer(const GraphViewer* p, const std::string& tensor_name, + OrtValue& value) override { + return p->GetOrtValueInitializer(tensor_name, value); + } const Node* GraphViewer__ParentNode(const GraphViewer* p) override { return p->ParentNode(); } int GraphViewer__NumberOfNodes(const GraphViewer* p) noexcept override { return p->NumberOfNodes(); } int GraphViewer__MaxNodeIndex(const GraphViewer* p) noexcept override { return p->MaxNodeIndex(); } diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index f90ace95d6e58..d4041dfce5a7a 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -17,6 +17,7 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/session/plugin_ep/ep_factory_internal.h" @@ -206,6 +207,117 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, return CreateSessionAndLoadModelImpl(options, env->GetEnvironment(), model_path, model_data, model_data_length, sess); } +#if !defined(ORT_MINIMAL_BUILD) +static const char* GetCompatibilityStatusString(OrtCompiledModelCompatibility status) { + switch (status) { + case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL: + return "SUPPORTED_OPTIMAL"; + case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION: + return "SUPPORTED_PREFER_RECOMPILATION"; + case OrtCompiledModelCompatibility_EP_UNSUPPORTED: + return "UNSUPPORTED"; + case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE: + return "NOT_APPLICABLE"; + default: + return "UNKNOWN"; + } +} + +static Status ValidateCompiledModelCompatibility(InferenceSession& sess) { + // Get model metadata + auto [status, model_metadata] = sess.GetModelMetadata(); + if (!status.IsOK() || !model_metadata) { + // No metadata available, skip validation + return Status::OK(); + } + + const auto& custom_metadata = model_metadata->custom_metadata_map; + if (custom_metadata.empty()) { + // No custom metadata available, skip validation + return Status::OK(); + } + + // Check if user wants to fail on suboptimal models + bool fail_on_suboptimal = sess.GetSessionOptions().config_options.GetConfigEntry( + kOrtSessionOptionsFailOnSuboptimalCompiledModel) == "1"; + + const auto& registered_provider_types = sess.GetRegisteredProviderTypes(); + + // Access the execution providers through the session state (available after Initialize) + const auto& execution_providers = sess.GetSessionState().GetExecutionProviders(); + + for (const auto& ep_type : registered_provider_types) { + // Construct the full metadata key using the prefix + EP type + const std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + + auto metadata_it = custom_metadata.find(metadata_key); + if (metadata_it != custom_metadata.end()) { + const std::string& compatibility_info = metadata_it->second; + + // Get the actual EP instance to call 
validation + const IExecutionProvider* ep = execution_providers.Get(ep_type); + + if (ep != nullptr) { + // Call the EP's validation method (virtual method with default implementation) + OrtCompiledModelCompatibility compatibility_status; + Status validation_result = ep->ValidateCompiledModelCompatibilityInfo( + compatibility_info, compatibility_status); + + if (validation_result.IsOK()) { + // Log the compatibility status + const char* status_str = GetCompatibilityStatusString(compatibility_status); + LOGS(*sess.GetLogger(), INFO) + << "EP " << ep_type << " compiled model compatibility: " << status_str; + + // Enforce compatibility based on status + switch (compatibility_status) { + case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE: + case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL: + // Continue execution + break; + + case OrtCompiledModelCompatibility_EP_UNSUPPORTED: + // Always fail for unsupported models + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Compiled model is not supported by execution provider: " + ep_type); + + case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION: + // Behavior depends on user setting + if (fail_on_suboptimal) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Compiled model is suboptimal for execution provider: " + ep_type + + ". Recompilation recommended for better performance."); + } + // Otherwise continue with warning + LOGS(*sess.GetLogger(), WARNING) + << "EP " << ep_type << " reports compiled model is supported but suboptimal. " + << "Consider recompiling for better performance."; + break; + + default: + // Handle any unknown status values + LOGS(*sess.GetLogger(), WARNING) + << "EP " << ep_type << " returned unknown compatibility status: " << compatibility_status; + break; + } + } else { + // Validation failed - this should cause session initialization to fail + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to validate compiled model compatibility for EP " + ep_type + + ": " + validation_result.ErrorMessage()); + } + } + } else { + // No compatibility info found for this EP - normal for non-compiled models + LOGS(*sess.GetLogger(), VERBOSE) + << "No compiled model compatibility info found for EP " << ep_type; + } + } + + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, _In_ onnxruntime::InferenceSession& sess, _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) { @@ -253,6 +365,12 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize()); +#if !defined(ORT_MINIMAL_BUILD) + // Validate compiled model compatibility for all registered execution providers + // This must be done after Initialize() so the session state is available + ORT_API_RETURN_IF_STATUS_NOT_OK(ValidateCompiledModelCompatibility(sess)); +#endif // !defined(ORT_MINIMAL_BUILD) + return nullptr; } diff --git a/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp b/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp index ec30b8ba0985d..2550fde338cd5 100644 --- a/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp +++ b/onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp @@ -10,21 +10,15 @@ void run_ort_trt2() { Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); - const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* tensorrt_options; Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(1); - 
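The enforcement above keys off the OrtCompiledModelCompatibility value each EP reports and off the kOrtSessionOptionsFailOnSuboptimalCompiledModel config entry. A rough application-side sketch, assuming the public C++ API and the config-key header used by the tests in this change (the model path is hypothetical):

#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"  // declares kOrtSessionOptionsFailOnSuboptimalCompiledModel

Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "compat-check");
Ort::SessionOptions session_options;
// Opt in to failing when an EP reports SUPPORTED_PREFER_RECOMPILATION; UNSUPPORTED always fails.
session_options.AddConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, "1");
// ... append the execution provider(s) the model was compiled for ...
Ort::Session session(env, ORT_TSTR("compiled_model.onnx"), session_options);  // throws if validation fails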
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); const char* model_path = "squeezenet.onnx"; - Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); - std::unique_ptr rel_trt_options( - tensorrt_options, api.ReleaseTensorRTProviderOptions); - Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), - rel_trt_options.get())); + Ort::TensorRTProviderOptions tensorrt_options; + session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); std::cout << "Running ORT TRT EP with default provider options" << std::endl; @@ -127,7 +121,7 @@ void run_ort_trt2() { void run_ort_trt() { Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* tensorrt_options; + Ort::TensorRTProviderOptions tensorrt_options; Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(1); @@ -136,11 +130,7 @@ void run_ort_trt() { const char* model_path = "/data/ep-perf-models/onnx-zoo-models/squeezenet1.0-7/squeezenet/model.onnx"; - Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); - std::unique_ptr rel_trt_options( - tensorrt_options, api.ReleaseTensorRTProviderOptions); - Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), - rel_trt_options.get())); + session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); std::cout << "Running ORT TRT EP with default provider options" << std::endl; diff --git a/onnxruntime/test/autoep/test_allocators.cc b/onnxruntime/test/autoep/test_allocators.cc index 77d2bb24b7d35..88b522eb10dca 100644 --- a/onnxruntime/test/autoep/test_allocators.cc +++ b/onnxruntime/test/autoep/test_allocators.cc @@ -60,66 +60,58 @@ struct DummyAllocator : OrtAllocator { // validate CreateSharedAllocator allows adding an arena to the shared allocator TEST(SharedAllocators, AddArenaToSharedAllocator) { - const OrtApi& c_api = Ort::GetApi(); RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const auto* ep_memory_info = c_api.EpDevice_MemoryInfo(example_ep.get(), OrtDeviceMemoryType_DEFAULT); + Ort::ConstEpDevice example_ep_device{example_ep.get()}; + + auto ep_memory_info = example_ep_device.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); // validate there is a shared allocator - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, ep_memory_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(ep_memory_info); ASSERT_NE(allocator, nullptr); // call CreateSharedAllocator to replace with arena based allocator. arena is configured with kvps - OrtKeyValuePairs allocator_options; + Ort::KeyValuePairs allocator_options; auto initial_chunk_size = "25600"; // arena allocates in 256 byte amounts allocator_options.Add(OrtArenaCfg::ConfigKeyNames::InitialChunkSizeBytes, initial_chunk_size); - ASSERT_ORTSTATUS_OK(c_api.CreateSharedAllocator(*ort_env, example_ep.get(), OrtDeviceMemoryType_DEFAULT, - // allocator is internally added by EP. - // OrtArenaAllocator can only be used for the internal BFCArena - OrtDeviceAllocator, - &allocator_options, &allocator)); + allocator = ort_env->CreateSharedAllocator(example_ep.get(), OrtDeviceMemoryType_DEFAULT, + // allocator is internally added by EP. 
+ // OrtArenaAllocator can only be used for the internal BFCArena + OrtDeviceAllocator, + allocator_options); // first allocation should init the arena to the initial chunk size - void* mem = allocator->Alloc(allocator, 16); - allocator->Free(allocator, mem); + void* mem = allocator.Alloc(16); + allocator.Free(mem); // stats should prove the arena was used - OrtKeyValuePairs* allocator_stats = nullptr; - ASSERT_ORTSTATUS_OK(allocator->GetStats(allocator, &allocator_stats)); + auto allocator_stats = allocator.GetStats(); using ::testing::Contains; using ::testing::Pair; - const auto& stats = allocator_stats->Entries(); + const auto& stats = static_cast(allocator_stats)->Entries(); EXPECT_THAT(stats, Contains(Pair("NumAllocs", "1"))); EXPECT_THAT(stats, Contains(Pair("NumArenaExtensions", "1"))); EXPECT_THAT(stats, Contains(Pair("TotalAllocated", initial_chunk_size))); // optional. ORT owns the allocator but we want to test the release implementation - ASSERT_ORTSTATUS_OK(c_api.ReleaseSharedAllocator(*ort_env, example_ep.get(), OrtDeviceMemoryType_DEFAULT)); + ort_env->ReleaseSharedAllocator(example_ep.get(), OrtDeviceMemoryType_DEFAULT); } TEST(SharedAllocators, GetSharedAllocator) { - const OrtApi& c_api = Ort::GetApi(); - // default CPU allocator should be available. // create a memory info with a different name to validate the shared allocator lookup ignores the name - OrtMemoryInfo* test_cpu_memory_info = nullptr; - ASSERT_ORTSTATUS_OK(c_api.CreateMemoryInfo_V2("dummy", OrtMemoryInfoDeviceType_CPU, 0, 0, - OrtDeviceMemoryType_DEFAULT, 0, OrtDeviceAllocator, - &test_cpu_memory_info)); + auto test_cpu_memory_info = Ort::MemoryInfo("dummy", OrtMemoryInfoDeviceType_CPU, 0, 0, + OrtDeviceMemoryType_DEFAULT, 0, OrtDeviceAllocator); const auto get_allocator_and_check_name = [&](const std::string& expected_name) { - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, test_cpu_memory_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(test_cpu_memory_info); ASSERT_NE(allocator, nullptr); - const OrtMemoryInfo* ort_cpu_memory_info = nullptr; - ASSERT_ORTSTATUS_OK(c_api.AllocatorGetInfo(allocator, &ort_cpu_memory_info)); - const char* allocator_name; - ASSERT_ORTSTATUS_OK(c_api.MemoryInfoGetName(ort_cpu_memory_info, &allocator_name)); + auto ort_cpu_memory_info = allocator.GetInfo(); + auto allocator_name = ort_cpu_memory_info.GetAllocatorName(); ASSERT_EQ(expected_name, allocator_name); // Default ORT CPU allocator }; @@ -128,18 +120,16 @@ TEST(SharedAllocators, GetSharedAllocator) { // register custom allocator and make sure that is accessible by exact match DummyAllocator dummy_alloc{test_cpu_memory_info}; - c_api.RegisterAllocator(*ort_env, &dummy_alloc); + ort_env->RegisterAllocator(&dummy_alloc); // GetSharedAllocator should now match the custom allocator get_allocator_and_check_name("dummy"); // unregister custom allocator - ASSERT_ORTSTATUS_OK(c_api.UnregisterAllocator(*ort_env, test_cpu_memory_info)); + ort_env->UnregisterAllocator(test_cpu_memory_info); // there should always be a CPU allocator available get_allocator_and_check_name(onnxruntime::CPU); - - c_api.ReleaseMemoryInfo(test_cpu_memory_info); } } // namespace test diff --git a/onnxruntime/test/autoep/test_data_transfer.cc b/onnxruntime/test/autoep/test_data_transfer.cc index cc09699b754b6..71c69698ed386 100644 --- a/onnxruntime/test/autoep/test_data_transfer.cc +++ b/onnxruntime/test/autoep/test_data_transfer.cc @@ -23,16 +23,15 @@ namespace onnxruntime { namespace test { 
TEST(OrtEpLibrary, DataTransfer) { - const OrtApi& c_api = Ort::GetApi(); RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const OrtEpDevice* ep_device = example_ep.get(); + Ort::ConstEpDevice ep_device(example_ep.get()); - const OrtMemoryInfo* device_memory_info = c_api.EpDevice_MemoryInfo(ep_device, OrtDeviceMemoryType_DEFAULT); + auto device_memory_info = ep_device.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); // create a tensor using the default CPU allocator Ort::AllocatorWithDefaultOptions cpu_allocator; - std::vector shape{2, 3, 4}; // shape doesn't matter + constexpr const std::array shape{2, 3, 4}; // shape doesn't matter const size_t num_elements = 2 * 3 * 4; RandomValueGenerator random{}; @@ -44,24 +43,21 @@ TEST(OrtEpLibrary, DataTransfer) { // create an on-device Tensor using the example EPs alternative CPU allocator. // it has a different vendor to the default ORT CPU allocator so we can copy between them even though both are // really CPU based. - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, device_memory_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(device_memory_info); ASSERT_NE(allocator, nullptr); Ort::Value device_tensor = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); - std::vector src_tensor_ptrs{cpu_tensor}; - std::vector dst_tensor_ptrs{device_tensor}; + std::vector src_tensor; + src_tensor.push_back(std::move(cpu_tensor)); + std::vector dst_tensor; + dst_tensor.push_back(std::move(device_tensor)); - ASSERT_ORTSTATUS_OK(c_api.CopyTensors(*ort_env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), nullptr, - src_tensor_ptrs.size())); + ASSERT_CXX_ORTSTATUS_OK(ort_env->CopyTensors(src_tensor, dst_tensor, nullptr)); - const float* src_data = nullptr; - const float* dst_data = nullptr; - ASSERT_ORTSTATUS_OK(c_api.GetTensorData(cpu_tensor, reinterpret_cast(&src_data))); - ASSERT_ORTSTATUS_OK(c_api.GetTensorData(device_tensor, reinterpret_cast(&dst_data))); + const float* src_data = src_tensor[0].GetTensorData(); + const float* dst_data = dst_tensor[0].GetTensorData(); - size_t bytes; - ASSERT_ORTSTATUS_OK(c_api.GetTensorSizeInBytes(cpu_tensor, &bytes)); + size_t bytes = src_tensor[0].GetTensorSizeInBytes(); ASSERT_EQ(bytes, num_elements * sizeof(float)); ASSERT_NE(src_data, dst_data) << "Should have copied between two different memory locations"; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index f1ef67e1f6ba4..0f4a654f116c4 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -54,12 +54,12 @@ void RunModelWithPluginEp(Ort::SessionOptions& session_options) { TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { RegisteredEpDeviceUniquePtr example_ep; Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const OrtEpDevice* plugin_ep_device = example_ep.get(); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); // Create session with example plugin EP Ort::SessionOptions session_options; std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options); + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); RunModelWithPluginEp(session_options); } @@ -83,7 +83,7 @@ TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { RegisteredEpDeviceUniquePtr example_ep; 
Utils::RegisterAndGetExampleEp(*ort_env, example_ep); - const OrtEpDevice* plugin_ep_device = example_ep.get(); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); { const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); @@ -94,7 +94,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { Ort::SessionOptions session_options; std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options); + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, session_options); @@ -102,9 +102,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { compile_options.SetOutputModelPath(output_model_file); // Compile the model. - Ort::Status status = Ort::CompileModel(*ort_env, compile_options); - ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); - + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); // Make sure the compiled model was generated. ASSERT_TRUE(std::filesystem::exists(output_model_file)); } diff --git a/onnxruntime/test/common/string_utils_test.cc b/onnxruntime/test/common/string_utils_test.cc index 79f8ddff7b52a..983f7fa7a87f9 100644 --- a/onnxruntime/test/common/string_utils_test.cc +++ b/onnxruntime/test/common/string_utils_test.cc @@ -15,6 +15,8 @@ namespace test { namespace { template void TestSuccessfulParse(const std::string& input, const T& expected_value) { + SCOPED_TRACE(MakeString("Input: \"", input, "\", expected_value: ", expected_value)); + T value; ASSERT_TRUE(TryParseStringWithClassicLocale(input, value)); EXPECT_EQ(value, expected_value); @@ -22,6 +24,8 @@ void TestSuccessfulParse(const std::string& input, const T& expected_value) { template void TestFailedParse(const std::string& input) { + SCOPED_TRACE(MakeString("Input: \"", input, "\"")); + T value; EXPECT_FALSE(TryParseStringWithClassicLocale(input, value)); } @@ -31,6 +35,7 @@ TEST(StringUtilsTest, TryParseStringWithClassicLocale) { TestSuccessfulParse("-1", -1); TestSuccessfulParse("42", 42u); TestSuccessfulParse("2.5", 2.5f); + TestSuccessfulParse("0x100", uint32_t{0x100}); // out of range TestFailedParse("32768"); diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index 4e8d1b9f016f0..df83815cc29ea 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -146,15 +146,10 @@ static void RunOneTest( execution_providers.push_back(DefaultRocmExecutionProvider()); } else { if (strict) { - const auto& api = Ort::GetApi(); - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_skip_layer_norm_strict_mode"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); - execution_providers.push_back(CudaExecutionProviderWithOptions(std::move(rel_cuda_options.get()))); + Ort::CUDAProviderOptions cuda_options; + std::unordered_map options = {{"enable_skip_layer_norm_strict_mode", "1"}}; + cuda_options.Update(options); + execution_providers.push_back(CudaExecutionProviderWithOptions(std::move(cuda_options))); } else { execution_providers.push_back(DefaultCudaExecutionProvider()); 
} diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 188edad572182..513097aaf7ade 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -914,7 +914,22 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ const ORTCHAR_T* api_model_path = nullptr; ASSERT_ORTSTATUS_OK(ort_api.Graph_GetModelPath(&api_graph, &api_model_path)); ASSERT_EQ(PathString(api_model_path), PathString(model_path.c_str())); - + // Check the model metadata + Ort::AllocatorWithDefaultOptions default_allocator; + auto ort_cxx_graph = Ort::ConstGraph(&api_graph); + auto ort_cxx_model_metadat = ort_cxx_graph.GetModelMetadata(); + auto& model = graph_viewer.GetGraph().GetModel(); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetProducerNameAllocated(default_allocator).get(), model.ProducerName().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetGraphNameAllocated(default_allocator).get(), model.MainGraph().Name().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetDomainAllocated(default_allocator).get(), model.Domain().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetDescriptionAllocated(default_allocator).get(), model.DocString().c_str()), 0); + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetGraphDescriptionAllocated(default_allocator).get(), model.GraphDocString().c_str()), 0); + ASSERT_EQ(ort_cxx_model_metadat.GetVersion(), model.ModelVersion()); + auto model_meta_data = model.MetaData(); + for (auto& [k, v] : model_meta_data) { + ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.LookupCustomMetadataMapAllocated(k.c_str(), default_allocator).get(), v.c_str()), 0) + << " key=" << k << "; value=" << v; + } // Check graph inputs. const auto& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers(); diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc new file mode 100644 index 0000000000000..be97cf2620881 --- /dev/null +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -0,0 +1,410 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
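The assertions above exercise the new Graph_GetModelMetadata API through the C++ wrappers. A compiling EP could read the same metadata when producing or checking its compatibility info; a minimal sketch, where api_graph stands in for a hypothetical const OrtGraph* handed to the EP, the custom key is made up, and the wrapper return types are left to auto:

Ort::ConstGraph graph{api_graph};
Ort::AllocatorWithDefaultOptions allocator;
auto metadata = graph.GetModelMetadata();
auto producer = metadata.GetProducerNameAllocated(allocator);                   // e.g. the exporting tool
auto custom = metadata.LookupCustomMetadataMapAllocated("my_key", allocator);   // null if the key is absent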
+ +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "core/framework/execution_provider.h" +#include "core/framework/compute_capability.h" +#include "core/framework/kernel_registry.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" +#include "core/session/utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/abi_session_options_impl.h" +#include "core/framework/error_code_helper.h" +#include "dummy_provider.h" +#include "test_utils.h" +#include "test/test_environment.h" +#include "test/providers/provider_test_utils.h" + +using namespace onnxruntime; +using namespace onnxruntime::test; + +namespace { + +// Test execution provider that extends IExecutionProvider with compatibility string functionality +class TestCompatibilityExecutionProvider : public IExecutionProvider { + public: + static constexpr const char* kTestCompatibilityExecutionProviderType = "TestCompatibilityExecutionProvider"; + + TestCompatibilityExecutionProvider() : IExecutionProvider(kTestCompatibilityExecutionProviderType) { + } + + std::shared_ptr GetKernelRegistry() const override { + return std::make_shared(); + } + + std::vector CreatePreferredAllocators() override { + return {}; + } + + // Configurable mock behavior + void SetMockCompatibilityString(const std::string& str) { + mock_compatibility_string_ = str; + } + + void SetMockCompatibilityStatus(OrtCompiledModelCompatibility status) { + mock_compatibility_status_ = status; + } + + void SetShouldFailValidation(bool should_fail) { + should_fail_validation_ = should_fail; + } + + // Override compatibility methods + std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override { + ORT_UNUSED_PARAMETER(graph_viewer); + return mock_compatibility_string_; + } + + common::Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const override { + if (should_fail_validation_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mock validation failure"); + } + + // Simple validation logic for testing + // If the mock status is explicitly set to NOT_APPLICABLE, always return that + if (mock_compatibility_status_ == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE) { + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + } else if (compatibility_info.empty()) { + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + } else if (compatibility_info == mock_compatibility_string_) { + model_compatibility = mock_compatibility_status_; + } else { + model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + } + + return Status::OK(); + } + + private: + std::string mock_compatibility_string_ = "default_test_compatibility_v1.0"; + OrtCompiledModelCompatibility mock_compatibility_status_ = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; + bool should_fail_validation_ = false; +}; + +// Helper class to create test models +class ModelBuilderWithCompatibility { + public: + static std::unique_ptr CreateSimpleTestModel() { + // Create a simple model with a single Add operation + std::unordered_map domain_to_version; + domain_to_version[onnxruntime::kOnnxDomain] = 7; + + auto p_model = std::make_unique("test_model", true, ModelMetaData(), 
PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), + DefaultLoggingManager().DefaultLogger()); + + onnxruntime::Graph& graph = p_model->MainGraph(); + + // Define tensor type + ONNX_NAMESPACE::TypeProto tensor_float; + tensor_float.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + tensor_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + // Create input and output node args + auto& input_arg_a = graph.GetOrCreateNodeArg("A", &tensor_float); + auto& input_arg_b = graph.GetOrCreateNodeArg("B", &tensor_float); + auto& output_arg = graph.GetOrCreateNodeArg("C", &tensor_float); + + // Create Add node + std::vector input_defs = {&input_arg_a, &input_arg_b}; + std::vector output_defs = {&output_arg}; + graph.AddNode("add_node", "Add", "Add two tensors", input_defs, output_defs, nullptr, onnxruntime::kOnnxDomain); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + return p_model; + } + + static std::unique_ptr CreateModelWithCompatibilityMetadata( + const std::map& ep_compatibility_info) { + auto model = CreateSimpleTestModel(); + + // Add compatibility metadata + auto& metadata = model->MetaData(); + for (const auto& [ep_type, compatibility_string] : ep_compatibility_info) { + std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + metadata[metadata_key] = compatibility_string; + } + + return model; + } +}; + +// Helper class to create test sessions +class SessionBuilderWithCompatibility { + public: + static std::unique_ptr CreateTestSession(std::unique_ptr model, bool fail_on_suboptimal = false) { + SessionOptions so; + so.session_logid = "EpCompatibilityTest"; + so.session_log_verbosity_level = 1; + + if (fail_on_suboptimal) { + EXPECT_TRUE(so.config_options.AddConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, "1").IsOK()); + } + + // Convert Model to ModelProto and serialize + auto model_proto = model->ToProto(); + std::string model_data; + EXPECT_TRUE(model_proto.SerializeToString(&model_data)); + std::stringstream model_stream(model_data); + + // Create session with basic constructor + auto session = std::make_unique(so, GetEnvironment()); + + // Load the model from the stream and validate the status + auto load_status = session->Load(model_stream); + EXPECT_TRUE(load_status.IsOK()) << "Failed to load model: " << load_status.ErrorMessage(); + + return session; + } +}; + +// Helper function to initialize session using the proper validation pathway +Status InitializeSessionWithValidation(InferenceSession& session) { + // Create OrtSessionOptions from the session's SessionOptions to use the proper initialization path + OrtSessionOptions ort_session_options; + ort_session_options.value = session.GetSessionOptions(); + + // Call the InitializeSession function from utils.cc which includes validation + OrtStatus* ort_status = InitializeSession(&ort_session_options, session, nullptr); + + // Convert OrtStatus to Status using the proper helper function + return ToStatusAndRelease(ort_status); +} + +} // anonymous namespace + +class EpCompatibilityTest : public ::testing::Test { + protected: + void SetUp() override { + test_model_ = ModelBuilderWithCompatibility::CreateSimpleTestModel(); + } + + protected: + std::unique_ptr test_model_; +}; + +// Test basic compatibility string generation during compilation 
+TEST_F(EpCompatibilityTest, TestCompatibilityStringGeneration) { + const std::string expected_compatibility_string = "test_ep_v1.0_compatibility_data"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(expected_compatibility_string); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(test_model_)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); + + // Note: In the actual implementation, we would need to trigger EP context model creation + // to see the compatibility strings stored. For now, this tests that the methods are called + // without error during session initialization. +} + +// Test compatibility string storage in model metadata +TEST_F(EpCompatibilityTest, TestCompatibilityStringStorage) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string expected_compatibility_string = "stored_compatibility_v2.0"; + + // Create model with pre-populated compatibility metadata + std::map compatibility_info = { + {ep_type, expected_compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + // Verify metadata was stored correctly + const auto& metadata = model_with_metadata->MetaData(); + std::string expected_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + + auto it = metadata.find(expected_key); + ASSERT_NE(it, metadata.end()) << "Expected compatibility metadata key not found: " << expected_key; + EXPECT_EQ(it->second, expected_compatibility_string); +} + +// Test multiple EPs generating different compatibility strings +TEST_F(EpCompatibilityTest, TestMultipleEpCompatibilityStrings) { + std::map compatibility_info = { + {"EP_A", "ep_a_compatibility_v1.0"}, + {"EP_B", "ep_b_compatibility_v2.1"}, + {"EP_C", "ep_c_compatibility_v1.5"}}; + + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + // Verify all compatibility strings are stored + const auto& metadata = model_with_metadata->MetaData(); + for (const auto& [ep_type, expected_string] : compatibility_info) { + std::string expected_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + auto it = metadata.find(expected_key); + ASSERT_NE(it, metadata.end()) << "Expected compatibility metadata key not found: " << expected_key; + EXPECT_EQ(it->second, expected_string); + } +} + +// Test empty compatibility string handling +TEST_F(EpCompatibilityTest, TestEmptyCompatibilityString) { + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(""); // Empty string + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(test_model_)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); // Should succeed even with empty compatibility string +} + +// Test compatibility validation with optimal status +TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Optimal) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string compatibility_string = "optimal_compatibility_v1.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(compatibility_string); + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL); + + // Create model with matching compatibility metadata + 
std::map compatibility_info = {{ep_type, compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); // Should succeed with optimal compatibility +} + +// Test compatibility validation with suboptimal status (default session settings) +TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Suboptimal_DefaultSettings) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string compatibility_string = "suboptimal_compatibility_v1.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(compatibility_string); + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION); + + std::map compatibility_info = {{ep_type, compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata), false); // Don't fail on suboptimal + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); // Should succeed by default with suboptimal compatibility +} + +// Test compatibility validation with suboptimal status (fail on suboptimal enabled) +TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Suboptimal_FailEnabled) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string compatibility_string = "suboptimal_compatibility_v1.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(compatibility_string); + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION); + + std::map compatibility_info = {{ep_type, compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata), true); // Fail on suboptimal + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + + // Should fail during initialization due to suboptimal compatibility + auto status = InitializeSessionWithValidation(*session); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("suboptimal")); +} + +// Test compatibility validation with unsupported status +TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Unsupported) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string stored_compatibility_string = "old_compatibility_v1.0"; + const std::string current_compatibility_string = "new_compatibility_v2.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(current_compatibility_string); + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_UNSUPPORTED); + + // Model has old compatibility string, EP has new one -> unsupported + std::map compatibility_info = {{ep_type, stored_compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = 
SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata), false); // Even with fail_on_suboptimal=false + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + + // Should fail during initialization due to unsupported compatibility + auto status = InitializeSessionWithValidation(*session); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("not supported")); +} + +// Test compatibility validation with not applicable status +TEST_F(EpCompatibilityTest, TestCompatibilityValidation_NotApplicable) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(""); // Empty compatibility string + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); + + // Model has some compatibility string, but EP returns not applicable + std::map compatibility_info = {{ep_type, "some_compatibility_string"}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); // Should succeed with not applicable status +} + +// Test missing compatibility info in model metadata +TEST_F(EpCompatibilityTest, TestMissingCompatibilityInfo) { + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString("some_compatibility_string"); + + // Use model without any compatibility metadata + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(test_model_)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); // Should succeed when no compatibility info is present +} + +// Test EP validation failure +TEST_F(EpCompatibilityTest, TestEpValidationFailure) { + const std::string ep_type = "TestCompatibilityExecutionProvider"; + const std::string compatibility_string = "test_compatibility_v1.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(compatibility_string); + test_ep->SetShouldFailValidation(true); // Force validation failure + + std::map compatibility_info = {{ep_type, compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + + // Should handle EP validation failure gracefully + auto status = InitializeSessionWithValidation(*session); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Mock validation failure")); +} + +// Test session option configuration for fail on suboptimal +TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) { + SessionOptions so; + + // Test default value + std::string config_value; + bool has_config = so.config_options.TryGetConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, config_value); + EXPECT_FALSE(has_config); // Should not be set by default + + // Test setting the option + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, "1")); + has_config = 
so.config_options.TryGetConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, config_value); + EXPECT_TRUE(has_config); + EXPECT_EQ(config_value, "1"); + + // Test setting to disabled + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, "0")); + has_config = so.config_options.TryGetConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, config_value); + EXPECT_TRUE(has_config); + EXPECT_EQ(config_value, "0"); +} diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index c6d958536f488..324394798863c 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -74,8 +74,7 @@ static Ort::Session GetSessionObj(Ort::Env& env, T model_uri, int provider_type) if (provider_type == 1) { #ifdef USE_CUDA - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); std::cout << "Running simple inference with cuda provider" << std::endl; #else diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 5c81696d5c57e..a22375320edae 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -30,8 +30,14 @@ static const onnxruntime::perftest::PerformanceTestConfig& DefaultPerformanceTes return default_config; } -ABSL_FLAG(std::string, f, "", "Specifies a free dimension by name to override to a specific value for performance optimization."); -ABSL_FLAG(std::string, F, "", "Specifies a free dimension by denotation to override to a specific value for performance optimization."); +ABSL_FLAG(std::string, f, "", + "Specifies a free dimension by name to override to a specific value for performance optimization.\n" + "[Usage]: -f \"dimension_name1:override_value1\" -f \"dimension_name2:override_value2\" ... or" + " -f \"dimension_name1:override_value1 dimension_name2:override_value2 ... \". Override value must be > 0."); +ABSL_FLAG(std::string, F, "", + "Specifies a free dimension by denotation to override to a specific value for performance optimization.\n" + "[Usage]: -F \"dimension_denotation1:override_value1\" -F \"dimension_denotation2:override_value2\" ... or" + " -F \"dimension_denotation1:override_value1 dimension_denotation2:override_value2 ... \". Override value must be > 0."); ABSL_FLAG(std::string, m, "duration", "Specifies the test mode. 
Value could be 'duration' or 'times'."); ABSL_FLAG(std::string, e, "cpu", "Specifies the provider 'cpu','cuda','dnnl','tensorrt', 'nvtensorrtrtx', 'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'."); ABSL_FLAG(size_t, r, DefaultPerformanceTestConfig().run_config.repeated_times, "Specifies the repeated times if running in 'times' test mode."); @@ -168,26 +174,6 @@ ABSL_FLAG(bool, h, false, "Print program usage."); namespace onnxruntime { namespace perftest { -static bool ParseDimensionOverride(std::string& dim_identifier, int64_t& override_val, const char* option) { - std::basic_string free_dim_str(option); - size_t delimiter_location = free_dim_str.find(":"); - if (delimiter_location >= free_dim_str.size() - 1) { - return false; - } - dim_identifier = free_dim_str.substr(0, delimiter_location); - std::string override_val_str = free_dim_str.substr(delimiter_location + 1, std::string::npos); - ORT_TRY { - override_val = std::stoll(override_val_str.c_str()); - if (override_val <= 0) { - return false; - } - } - ORT_CATCH(...) { - return false; - } - return true; -} - std::string CustomUsageMessage() { std::ostringstream oss; oss << "onnxruntime_perf_test [options...] model_path [result_file]\n\n"; @@ -212,20 +198,21 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a absl::SetFlagsUsageConfig(config); absl::SetProgramUsageMessage(CustomUsageMessage()); - auto utf8_strings = utils::ConvertArgvToUtf8Strings(argc, argv); - auto utf8_argv = utils::CStringsFromStrings(utf8_strings); + auto utf8_argv_strings = utils::ConvertArgvToUtf8Strings(argc, argv); + auto utf8_argv = utils::CStringsFromStrings(utf8_argv_strings); auto positional = absl::ParseCommandLine(static_cast(utf8_argv.size()), utf8_argv.data()); // -f { const auto& dim_override_str = absl::GetFlag(FLAGS_f); if (!dim_override_str.empty()) { - std::string dim_name; - int64_t override_val; - if (!ParseDimensionOverride(dim_name, override_val, dim_override_str.c_str())) { + // Abseil doesn't support the same option being provided multiple times - only the last occurrence is applied. + // To preserve the previous usage of '-f', where users may specify it multiple times to override different dimension names, + // we need to manually parse argv. + std::string option = "f"; + if (!ParseDimensionOverrideFromArgv(argc, utf8_argv_strings, option, test_config.run_config.free_dim_name_overrides)) { return false; } - test_config.run_config.free_dim_name_overrides[dim_name] = override_val; } } @@ -233,12 +220,11 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a { const auto& dim_override_str = absl::GetFlag(FLAGS_F); if (!dim_override_str.empty()) { - std::string dim_denotation; - int64_t override_val; - if (!ParseDimensionOverride(dim_denotation, override_val, dim_override_str.c_str())) { + // Same reason as '-f' above to manually parse argv. 
+ std::string option = "F"; + if (!ParseDimensionOverrideFromArgv(argc, utf8_argv_strings, option, test_config.run_config.free_dim_denotation_overrides)) { return false; } - test_config.run_config.free_dim_denotation_overrides[dim_denotation] = override_val; } } diff --git a/onnxruntime/test/perftest/main.cc b/onnxruntime/test/perftest/main.cc index 973baf774b024..513122609bb01 100644 --- a/onnxruntime/test/perftest/main.cc +++ b/onnxruntime/test/perftest/main.cc @@ -35,7 +35,7 @@ int real_main(int argc, char* argv[]) { } ORT_CATCH(const Ort::Exception& e) { ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "Error creating environment: %s \n", e.what()); + std::cerr << "Error creating environment: " << e.what() << std::endl; failed = true; }); } @@ -98,7 +98,7 @@ int main(int argc, char* argv[]) { } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "%s\n", ex.what()); + std::cerr << ex.what() << std::endl; retval = -1; }); } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 7156a1eb5c347..f1a40b1da8651 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -179,53 +179,29 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif } else if (provider_name_ == onnxruntime::kCudaExecutionProvider) { #ifdef USE_CUDA - const auto& api = Ort::GetApi(); - OrtCUDAProviderOptionsV2* cuda_options; - Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options)); - std::vector option_keys, option_values; - // used to keep all option keys and value strings alive - std::list buffer; - buffer.emplace_back("cudnn_conv_algo_search"); - option_keys.push_back(buffer.back().c_str()); + Ort::CUDAProviderOptions cuda_options; + + const char* config_val = nullptr; switch (performance_test_config.run_config.cudnn_conv_algo) { case 0: - buffer.emplace_back("EXHAUSTIVE"); + config_val = "EXHAUSTIVE"; break; case 1: - buffer.emplace_back("HEURISTIC"); + config_val = "HEURISTIC"; break; default: - buffer.emplace_back("DEFAULT"); + config_val = "DEFAULT"; break; } - option_values.push_back(buffer.back().c_str()); - - buffer.emplace_back("do_copy_in_default_stream"); - option_keys.push_back(buffer.back().c_str()); - buffer.emplace_back(!performance_test_config.run_config.do_cuda_copy_in_separate_stream ? "1" : "0"); - option_values.push_back(buffer.back().c_str()); + provider_options.emplace("cudnn_conv_algo_search", config_val); + provider_options.emplace("do_copy_in_default_stream", + (!performance_test_config.run_config.do_cuda_copy_in_separate_stream ? 
"1" : "0")); -#ifdef _MSC_VER std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); -#else - std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; -#endif + ParseSessionConfigs(ov_string, provider_options); - for (const auto& provider_option : provider_options) { - option_keys.push_back(provider_option.first.c_str()); - option_values.push_back(provider_option.second.c_str()); - } + cuda_options.Update(provider_options); - Ort::Status status(api.UpdateCUDAProviderOptions(cuda_options, - option_keys.data(), option_values.data(), option_keys.size())); - if (!status.IsOK()) { - OrtAllocator* allocator; - char* options; - Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); - Ort::ThrowOnError(api.GetCUDAProviderOptionsAsString(cuda_options, allocator, &options)); - ORT_THROW("[ERROR] [CUDA] Configuring the CUDA options failed with message: ", status.GetErrorMessage(), - "\nSupported options are:\n", options); - } session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); if (performance_test_config.run_config.enable_cuda_io_binding) { device_memory_name_ = CUDA; @@ -235,12 +211,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif } else if (provider_name_ == onnxruntime::kTensorrtExecutionProvider) { #ifdef USE_TENSORRT - const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* tensorrt_options; - Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&tensorrt_options)); - std::unique_ptr rel_trt_options( - tensorrt_options, api.ReleaseTensorRTProviderOptions); - std::vector option_keys, option_values; + Ort::TensorRTProviderOptions tensorrt_options; // used to keep all option keys and value strings alive std::list buffer; @@ -250,25 +221,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; #endif ParseSessionConfigs(ov_string, provider_options); - for (const auto& provider_option : provider_options) { - option_keys.push_back(provider_option.first.c_str()); - option_values.push_back(provider_option.second.c_str()); - } - Ort::Status status(api.UpdateTensorRTProviderOptions(tensorrt_options, - option_keys.data(), option_values.data(), option_keys.size())); - if (!status.IsOK()) { - OrtAllocator* allocator; - char* options; - Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); - Ort::ThrowOnError(api.GetTensorRTProviderOptionsAsString(tensorrt_options, allocator, &options)); - ORT_THROW("[ERROR] [TensorRT] Configuring the CUDA options failed with message: ", status.GetErrorMessage(), - "\nSupported options are:\n", options); - } - + tensorrt_options.Update(provider_options); session_options.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); OrtCUDAProviderOptions cuda_options; - cuda_options.device_id = tensorrt_options->device_id; + cuda_options.device_id = static_cast(tensorrt_options)->device_id; cuda_options.cudnn_conv_algo_search = static_cast(performance_test_config.run_config.cudnn_conv_algo); cuda_options.do_copy_in_default_stream = !performance_test_config.run_config.do_cuda_copy_in_separate_stream; // TODO: Support arena configuration for users of perf test diff --git a/onnxruntime/test/perftest/strings_helper.cc b/onnxruntime/test/perftest/strings_helper.cc index f4860b35c79da..5743346f8edf1 100644 --- a/onnxruntime/test/perftest/strings_helper.cc +++ b/onnxruntime/test/perftest/strings_helper.cc @@ -56,6 
+56,53 @@ void ParseSessionConfigs(const std::string& configs_string, } } +bool ParseDimensionOverride(const std::string& input, std::map& free_dim_override_map) { + std::stringstream ss(input); + std::string free_dim_str; + + while (std::getline(ss, free_dim_str, ' ')) { + if (!free_dim_str.empty()) { + size_t delimiter_location = free_dim_str.find(":"); + if (delimiter_location >= free_dim_str.size() - 1) { + return false; + } + std::string dim_identifier = free_dim_str.substr(0, delimiter_location); + std::string override_val_str = free_dim_str.substr(delimiter_location + 1, std::string::npos); + ORT_TRY { + int64_t override_val = std::stoll(override_val_str.c_str()); + if (override_val <= 0) { + return false; + } + free_dim_override_map[dim_identifier] = override_val; + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + std::cerr << "Error parsing free dimension override value: " << override_val_str.c_str() << ", " << ex.what() << std::endl; + }); + return false; + } + } + } + + return true; +} + +bool ParseDimensionOverrideFromArgv(int argc, std::vector& argv, std::string& option, std::map& free_dim_override_map) { + for (int i = 1; i < argc; ++i) { + auto utf8_arg = argv[i]; + if (utf8_arg == ("-" + option) || utf8_arg == ("--" + option)) { + auto value_idx = i + 1; + if (value_idx >= argc || argv[value_idx][0] == '-') { + std::cerr << utf8_arg << " should be followed by a key-value pair." << std::endl; + return false; + } + + if (!ParseDimensionOverride(argv[value_idx], free_dim_override_map)) return false; + } + } + return true; +} + void ParseEpOptions(const std::string& input, std::vector>& result) { auto tokens = utils::SplitString(input, ";", true); diff --git a/onnxruntime/test/perftest/strings_helper.h b/onnxruntime/test/perftest/strings_helper.h index 621ab746273bd..a33b3d5089c9b 100644 --- a/onnxruntime/test/perftest/strings_helper.h +++ b/onnxruntime/test/perftest/strings_helper.h @@ -3,6 +3,7 @@ // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include +#include #include #include #include @@ -14,6 +15,10 @@ void ParseSessionConfigs(const std::string& configs_string, std::unordered_map& session_configs, const std::unordered_set& available_keys = {}); +bool ParseDimensionOverride(const std::string& input, std::map& free_dim_override_map); + +bool ParseDimensionOverrideFromArgv(int argc, std::vector& argv, std::string& option, std::map& free_dim_override_map); + void ParseEpList(const std::string& input, std::vector& result); void ParseEpOptions(const std::string& input, std::vector>& result); diff --git a/onnxruntime/test/platform/device_discovery_test.cc b/onnxruntime/test/platform/device_discovery_test.cc new file mode 100644 index 0000000000000..21ddf9a5b1cd7 --- /dev/null +++ b/onnxruntime/test/platform/device_discovery_test.cc @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
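(Reviewer aside, not part of the patch.) Two quick sketches of the behavior introduced in this area. First, the free-dimension override helpers added to strings_helper.cc/.h above: the option name, the argument values, and the std::map<std::string, int64_t> value type are assumptions inferred from the int64_t override values in the implementation.

  std::map<std::string, int64_t> overrides;
  // Space-separated "name:value" pairs; each value must parse as a positive integer,
  // otherwise the helper returns false and the caller aborts argument parsing.
  bool ok = ParseDimensionOverride("batch_size:1 seq_len:128", overrides);
  // ok == true; overrides["batch_size"] == 1 and overrides["seq_len"] == 128.
  // ParseDimensionOverrideFromArgv scans argv for "-F <pairs>" / "--F <pairs>"
  // (the option name is supplied by the caller) and runs the parser on each value.

Second, the provider-options migration in ort_test_session.cc above (repeated in model_tests.cc below): a plain string map replaces the manual key/value arrays and the UpdateCUDAProviderOptions error handling. A condensed, illustrative sketch, assuming an existing Ort::SessionOptions named session_options:

  std::unordered_map<std::string, std::string> provider_options;
  provider_options.emplace("cudnn_conv_algo_search", "HEURISTIC");
  provider_options.emplace("do_copy_in_default_stream", "1");

  Ort::CUDAProviderOptions cuda_options;
  cuda_options.Update(provider_options);  // C++ wrapper; replaces UpdateCUDAProviderOptions plus manual arrays
  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);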
+ +#include "core/platform/device_discovery.h" + +#include "gtest/gtest.h" + +namespace onnxruntime::test { + +namespace { + +std::vector GetDevicesByType(OrtHardwareDeviceType device_type) { + std::vector result{}; + const auto& devices = DeviceDiscovery::GetDevices(); + std::copy_if(devices.begin(), devices.end(), std::back_inserter(result), + [device_type](const OrtHardwareDevice& device) { + return device.type == device_type; + }); + return result; +} + +} // namespace + +TEST(DeviceDiscoveryTest, HasCpuDevice) { + const auto cpu_devices = GetDevicesByType(OrtHardwareDeviceType_CPU); + ASSERT_GT(cpu_devices.size(), 0); + +#if !defined(__wasm__) + ASSERT_NE(cpu_devices[0].vendor_id, 0); +#endif // !defined(__WASM__) +} + +} // namespace onnxruntime::test diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index d5f6f1ddf700e..a17982ecb5eab 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -22,6 +22,7 @@ #include #include "default_providers.h" #include "test/onnx/TestCase.h" +#include "test/util/include/api_asserts.h" #ifdef USE_DNNL #include "core/providers/dnnl/dnnl_provider_factory.h" @@ -59,17 +60,6 @@ extern std::unique_ptr ort_env; -// asserts that the OrtStatus* result of `status_expr` does not indicate an error -// note: this takes ownership of the OrtStatus* result -#define ASSERT_ORT_STATUS_OK(status_expr) \ - do { \ - if (OrtStatus* _status = (status_expr); _status != nullptr) { \ - std::unique_ptr _rel_status{ \ - _status, &OrtApis::ReleaseStatus}; \ - FAIL() << "OrtStatus error: " << OrtApis::GetErrorMessage(_rel_status.get()); \ - } \ - } while (false) - using namespace onnxruntime::common; namespace onnxruntime { @@ -179,17 +169,14 @@ TEST_P(ModelTest, Run) { ortso.SetLogId(ToUTF8String(test_case_name).c_str()); ortso.SetLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR); if (provider_name == "cuda") { - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); - std::unique_ptr rel_cuda_options( - cuda_options, &OrtApis::ReleaseCUDAProviderOptions); + Ort::CUDAProviderOptions cuda_options; - std::vector keys{"device_id", "use_tf32"}; - std::vector values; std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); - values.push_back(device_id.empty() ? "0" : device_id.c_str()); - values.push_back("0"); - ASSERT_ORT_STATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2)); + + std::unordered_map options; + options["device_id"] = (device_id.empty() ? 
"0" : device_id.c_str()); + options["use_tf32"] = "0"; // Disable TF32 for CUDA provider + cuda_options.Update(options); ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); } else if (provider_name == "rocm") { @@ -199,11 +186,11 @@ TEST_P(ModelTest, Run) { #ifdef USE_DNNL else if (provider_name == "dnnl") { OrtDnnlProviderOptions* ep_option; - ASSERT_ORT_STATUS_OK(OrtApis::CreateDnnlProviderOptions(&ep_option)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::CreateDnnlProviderOptions(&ep_option)); std::unique_ptr rel_dnnl_options(ep_option, &OrtApis::ReleaseDnnlProviderOptions); ep_option->use_arena = 0; - ASSERT_ORT_STATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_Dnnl(ortso, ep_option)); + ASSERT_CXX_ORTSTATUS_OK(OrtApis::SessionOptionsAppendExecutionProvider_Dnnl(ortso, ep_option)); } #endif else if (provider_name == "tensorrt") { @@ -211,24 +198,17 @@ TEST_P(ModelTest, Run) { OrtTensorRTProviderOptionsV2 params; ortso.AppendExecutionProvider_TensorRT_V2(params); } else { - OrtTensorRTProviderOptionsV2* ep_option = nullptr; - ASSERT_ORT_STATUS_OK(OrtApis::CreateTensorRTProviderOptions(&ep_option)); - std::unique_ptr - rel_cuda_options(ep_option, &OrtApis::ReleaseTensorRTProviderOptions); + Ort::TensorRTProviderOptions ep_option; ortso.AppendExecutionProvider_TensorRT_V2(*ep_option); } // Enable CUDA fallback - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); - std::unique_ptr rel_cuda_options( - cuda_options, &OrtApis::ReleaseCUDAProviderOptions); + Ort::CUDAProviderOptions cuda_options; - std::vector keys{"device_id", "use_tf32"}; - std::vector values; std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); - values.push_back(device_id.empty() ? "0" : device_id.c_str()); - values.push_back("0"); - ASSERT_ORT_STATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 2)); + std::unordered_map options; + options["device_id"] = (device_id.empty() ? 
"0" : device_id.c_str()); + options["use_tf32"] = "0"; // Disable TF32 for CUDA provider + cuda_options.Update(options); ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); } else if (provider_name == "migraphx") { @@ -240,27 +220,27 @@ TEST_P(ModelTest, Run) { } #ifdef USE_NNAPI else if (provider_name == "nnapi") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Nnapi(ortso, 0)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_Nnapi(ortso, 0)); } #endif #ifdef USE_VSINPU else if (provider_name == "vsinpu") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_VSINPU(ortso)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_VSINPU(ortso)); } #endif #ifdef USE_RKNPU else if (provider_name == "rknpu") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_Rknpu(ortso)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_Rknpu(ortso)); } #endif #ifdef USE_ACL else if (provider_name == "acl") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, false)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, false)); } #endif #ifdef USE_ARMNN else if (provider_name == "armnn") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ArmNN(ortso)); + ASSERT_CXX_ORTSTATUS_OK(OrtSessionOptionsAppendExecutionProvider_ArmNN(ortso)); } #endif #ifdef USE_XNNPACK @@ -300,11 +280,11 @@ TEST_P(ModelTest, Run) { std::unordered_map feeds; l->LoadTestData(task_id, holder, feeds, true); size_t output_count; - ASSERT_ORT_STATUS_OK(OrtApis::SessionGetOutputCount(ort_session, &output_count)); + ASSERT_ORTSTATUS_OK(OrtApis::SessionGetOutputCount(ort_session, &output_count)); // Create output feed std::vector output_names(output_count); for (size_t i = 0; i != output_count; ++i) { - ASSERT_ORT_STATUS_OK( + ASSERT_ORTSTATUS_OK( OrtApis::SessionGetOutputName(ort_session, i, default_allocator.get(), &output_names[i])); } diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index 19505da1bbe56..2327bc2094d1a 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -1,25 +1,16 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. 
#include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" #include "test/framework/test_utils.h" -#include "gtest/gtest.h" + #include "test/util/include/scoped_env_vars.h" #include "test/common/trt_op_test_utils.h" #include "test/common/random_generator.h" #include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" -#include "test/util/include/api_asserts.h" -#include "test/util/include/asserts.h" -#include -#include -#include -#include -#include #include -#include #include using namespace std; @@ -30,200 +21,6 @@ namespace onnxruntime { namespace test { -template -class NvExecutionProviderTest : public ::testing::Test { - protected: - std::string getTypeAsName() { - std::string dtype_name = ""; - if constexpr (std::is_same::value) { - dtype_name = "fp64"; - } else if constexpr (std::is_same::value) { - dtype_name = "fp32"; - } else if constexpr (std::is_same::value) { - dtype_name = "bf16"; - } else if constexpr (std::is_same::value) { - dtype_name = "fp16"; - } else if constexpr (std::is_same::value) { - dtype_name = "int8"; - } else if constexpr (std::is_same::value) { - dtype_name = "uint8"; - } else if constexpr (std::is_same::value) { - dtype_name = "int32"; - } else if constexpr (std::is_same::value) { - dtype_name = "int64"; - } - return dtype_name; - } -}; - -using NvExecutionProviderTestTypes = ::testing::Types; // double, -TYPED_TEST_SUITE(NvExecutionProviderTest, NvExecutionProviderTestTypes); - -std::string PathToUTF8(const PathString& path) { -#ifdef WIN32 - std::wstring_convert> converter; - return converter.to_bytes(path); -#else - return path.c_str(); -#endif -} - -void clearFileIfExists(PathString path) { - if (std::filesystem::exists(path)) { - std::filesystem::remove(path); - } -} - -template -void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, - const std::vector& expected_values) { - ASSERT_EQ(1, fetches.size()); - auto& rtensor = fetches.front().Get(); - TensorShape expected_shape(expected_dims); - ASSERT_EQ(expected_shape, rtensor.Shape()); - const std::vector found(rtensor.Data(), rtensor.Data() + expected_values.size()); - ASSERT_EQ(expected_values, found); -} - -/** - * Create a simple model with dynamic or non-dynamic input shape. - * \param model_name - model name - * \param graph_name - graph name - * \param dims - input dimensions - * \param add_fast_gelu - add FastGelu node which makes the whole model partition into TRT EP and CUDA EP subgraphs. - * - * input: "X", "Y" and "Z" - * you can specify input dimensions, for example (1, 3, 2), (1, 2) or (1, -1, -1)). Note: -1 means the dimension is dynamic. - * All three inputs have the same dimensions. 
- * output: "M" - * - * "X" "Y" - * \ / - * "Z" Add - * \ / - * Add - * / - * Add (+ float scalar "S") - * / - * "O" - * - * or - * - * "X" "Y" - * \ / - * "Z" Add - * \ / - * Add - * / - * FastGelu (This node will be placed on CUDA EP) - * / - * * Add (+ float scalar "S") - * / - * "O" - */ -static void CreateBaseModel(const PathString& model_name, - std::string graph_name, - std::vector dims, - bool add_fast_gelu = false, - ONNX_NAMESPACE::TensorProto_DataType dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); - auto& graph = model.MainGraph(); - std::vector inputs; - std::vector outputs; - - // FLOAT tensor - ONNX_NAMESPACE::TypeProto float_tensor; - float_tensor.mutable_tensor_type()->set_elem_type(dtype); - - for (auto dim : dims) { - float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); - } - ONNX_NAMESPACE::TypeProto dyn_float_tensor; - dyn_float_tensor.mutable_tensor_type()->set_elem_type(dtype); - - auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); - auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); - inputs.push_back(&input_arg_1); - inputs.push_back(&input_arg_2); - auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor); - outputs.push_back(&output_arg); - graph.AddNode("node_1", "Add", "node 1.", inputs, outputs); - - auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor); - inputs.clear(); - inputs.push_back(&output_arg); - inputs.push_back(&input_arg_3); - - auto& output_arg_2 = graph.GetOrCreateNodeArg("node_2_out_1", &float_tensor); - outputs.clear(); - outputs.push_back(&output_arg_2); - graph.AddNode("node_2", "Add", "node 2.", inputs, outputs); - - inputs.clear(); - inputs.push_back(&output_arg_2); - - if (add_fast_gelu) { - auto& output_arg_3 = graph.GetOrCreateNodeArg("node_3_out_1", &dyn_float_tensor); - outputs.clear(); - outputs.push_back(&output_arg_3); - - graph.AddNode("node_3", "FastGelu", "node 3.", inputs, outputs, - /* attributes */ nullptr, kMSDomain); - - inputs.clear(); - inputs.push_back(&output_arg_3); - } - - ONNX_NAMESPACE::TypeProto float_scalar; - float_scalar.mutable_tensor_type()->set_elem_type(dtype); - float_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - auto& input_scalar = graph.GetOrCreateNodeArg("S", &float_scalar); - inputs.push_back(&input_scalar); - - auto& output_arg_4 = graph.GetOrCreateNodeArg("O", &dyn_float_tensor); - - outputs.clear(); - outputs.push_back(&output_arg_4); - graph.AddNode("node_5", "Add", "node 5.", inputs, outputs); - - auto status = graph.Resolve(); - ASSERT_TRUE(status.IsOK()); - status = onnxruntime::Model::Save(model, model_name); - ASSERT_TRUE(status.IsOK()); -} - -static Ort::IoBinding generate_io_binding(Ort::Session& session, std::map> shape_overwrites = {}) { - Ort::IoBinding binding(session); - auto allocator = Ort::AllocatorWithDefaultOptions(); - for (int input_idx = 0; input_idx < int(session.GetInputCount()); ++input_idx) { - auto input_name = session.GetInputNameAllocated(input_idx, Ort::AllocatorWithDefaultOptions()); - auto full_tensor_info = session.GetInputTypeInfo(input_idx); - auto tensor_info = full_tensor_info.GetTensorTypeAndShapeInfo(); - auto shape = tensor_info.GetShape(); - auto type = tensor_info.GetElementType(); - if (shape_overwrites.find(input_name.get()) == shape_overwrites.end()) { - for (auto& v : shape) { - if (v == -1) { - v = 1; - } - } - } else { - shape = 
shape_overwrites[input_name.get()]; - } - auto input_value = Ort::Value::CreateTensor(allocator, - shape.data(), - shape.size(), - type); - binding.BindInput(input_name.get(), input_value); - } - - for (int output_idx = 0; output_idx < int(session.GetOutputCount()); ++output_idx) { - auto output_name = session.GetOutputNameAllocated(output_idx, Ort::AllocatorWithDefaultOptions()); - binding.BindOutput(output_name.get(), allocator.GetInfo()); - } - return binding; -} - TEST(NvExecutionProviderTest, ContextEmbedAndReload) { PathString model_name = ORT_TSTR("nv_execution_provider_test.onnx"); PathString model_name_ctx = ORT_TSTR("nv_execution_provider_test_ctx.onnx"); @@ -233,11 +30,6 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { std::vector dims = {1, 3, 2}; CreateBaseModel(model_name, graph_name, dims); - - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - // AOT time { auto start = std::chrono::high_resolution_clock::now(); @@ -246,7 +38,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -261,7 +53,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReload) { Ort::RunOptions run_options; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name_ctx.c_str(), so); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -280,10 +72,6 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { CreateBaseModel(model_name, graph_name, dims); - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - // AOT time { auto start = std::chrono::high_resolution_clock::now(); @@ -292,7 +80,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -307,7 +95,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDynamic) { Ort::RunOptions run_options; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name_ctx.c_str(), so); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << 
std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -329,10 +117,6 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { CreateBaseModel(model_name, graph_name, dims); - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - // AOT time { auto start = std::chrono::high_resolution_clock::now(); @@ -341,7 +125,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation AOT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -356,7 +140,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { Ort::RunOptions run_options; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name_ctx.c_str(), so); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); auto stop = std::chrono::high_resolution_clock::now(); std::cout << "Session creation JIT: " << std::chrono::duration_cast((stop - start)).count() << " ms" << std::endl; @@ -368,33 +152,71 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { } } -TYPED_TEST(NvExecutionProviderTest, IOTypeTests) { - std::string dtype_name = this->getTypeAsName(); +std::string getTypeAsName(ONNX_NAMESPACE::TensorProto_DataType dtype) { + switch (dtype) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return "fp64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return "fp32"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return "fp16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return "bf16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return "int64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return "int32"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return "int8"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return "uint8"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + return "int4"; + default: + return "Unkwon type"; + } +} + +class TypeTests : public ::testing::TestWithParam { + public: +}; + +TEST_P(TypeTests, IOTypes) { + const std::string dtype_name = getTypeAsName(GetParam()); ASSERT_FALSE(dtype_name.empty()); const std::string model_name_str = "nv_execution_provider_" + dtype_name + ".onnx"; const PathString model_name = ToPathString(model_name_str); - std::string graph_name = "test" + dtype_name; - std::vector dims = {1, -1, -1}; - - CreateBaseModel(model_name, graph_name, dims); + const std::string graph_name = "test" + dtype_name; + const std::vector dims = {1, 5, 10}; - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); + CreateBaseModel(model_name, graph_name, dims, false, GetParam()); // AOT time { Ort::SessionOptions so; Ort::RunOptions run_options; so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {}); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); auto io_binding = generate_io_binding(session_object); session_object.Run(run_options, 
io_binding); } } -#if defined(WIN32) +INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, + ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, + ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 + // disabled low precision integer types since a specific quantize/dequantize model is required + // ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, + // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, + // ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 + ), + [](const testing::TestParamInfo& info) { return getTypeAsName(info.param); }); + +#ifdef _WIN32 static bool SessionHasEp(Ort::Session& session, const char* ep_name) { // Access the underlying InferenceSession. const OrtSession* ort_session = session; @@ -420,20 +242,16 @@ TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { CreateBaseModel(model_name, graph_name, dims); - auto env = Ort::Env(); - auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; - env.UpdateEnvWithCustomLogLevel(logging_level); - { - env.RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll")); + ort_env->RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll")); Ort::SessionOptions so; so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU); - Ort::Session session_object(env, model_name.c_str(), so); + Ort::Session session_object(*ort_env, model_name.c_str(), so); EXPECT_TRUE(SessionHasEp(session_object, kNvTensorRTRTXExecutionProvider)); } - env.UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider); + ort_env->UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider); } TEST(NvExecutionProviderTest, GetSharedAllocator) { @@ -580,7 +398,7 @@ TEST(NvExecutionProviderTest, DataTransfer) { device_tensor = Ort::Value(); } -#endif // defined(WIN32) +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc new file mode 100644 index 0000000000000..ce49ae81c81c0 --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. +#include "core/common/path_utils.h" +#include "test/framework/test_utils.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" + +#include + +extern std::unique_ptr ort_env; + +namespace onnxruntime { + +namespace test { + +RegisteredEpDeviceUniquePtr AppendTrtEtxEP(Ort::SessionOptions& session_options, std::unordered_map& option_map) { + RegisteredEpDeviceUniquePtr nv_tensorrt_rtx_ep; +#ifdef _WIN32 + /// Since this test runs after other tests that use registration interface this test has to use it as well + /// windows as otherwise the kernel registry inside the EP will not be populated. The legacy APis ony call the initialize once. 
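+  // On Windows the plugin EP library is registered with the environment, the matching
+  // OrtEpDevice is looked up by EP name, and it is appended via AppendExecutionProvider_V2
+  // with option_map; on other platforms the provider is appended directly by name below.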
+ Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, nv_tensorrt_rtx_ep); + auto ep_devices = ort_env->GetEpDevices(); + Ort::ConstEpDevice selected_device; + for (auto& device : ep_devices) { + if (!std::strcmp(device.EpName(), kNvTensorRTRTXExecutionProvider)) { + selected_device = device; + } + } + session_options.AppendExecutionProvider_V2(*ort_env, {selected_device}, option_map); +#else + session_options.AppendExecutionProvider(onnxruntime::kNvTensorRTRTXExecutionProvider, option_map); +#endif + return nv_tensorrt_rtx_ep; +} + +std::vector readBinaryFile(const PathString& filename) { + std::ifstream file(filename, std::ios::binary); + if (!file.is_open()) { + throw std::runtime_error("Could not open file: " + PathToUTF8String(filename)); + } + + file.seekg(0, std::ios::end); + std::streamsize filesize = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector buffer(filesize); + if (!file.read(reinterpret_cast(buffer.data()), filesize)) { + throw std::runtime_error("Could not read file: " + PathToUTF8String(filename)); + } + + return buffer; +} + +struct CompileParam { + bool embed_mode; + bool bytestream_io; + bool external_initialzier_for_parser = false; + const std::string to_string() const { + return "embed_mode_" + std::to_string(embed_mode) + "_bytestream_io_" + std::to_string(bytestream_io) + "_ext_init_" + std::to_string(external_initialzier_for_parser); + ; + } +}; +class CompileApiTest + : public testing::TestWithParam { + public: + const CompileParam& GetCompileParam() const { + return GetParam(); + } +}; + +void SmallModelTest(CompileParam test_param, bool fully_supported_model) { + std::string test_name = test_param.to_string(); + if (!fully_supported_model) + test_name += "_fast_gelu"; + PathString model_name = path_utils::MakePathString("nv_execution_provider_compile_" + test_name + ".onnx"); + PathString model_name_ctx = path_utils::MakePathString("nv_execution_provider_compile_" + test_name + "_ctx.onnx"); + clearFileIfExists(model_name_ctx); + std::string graph_name = "test"; + std::vector dims = {1, 3, 2}; + + CreateBaseModel(model_name, graph_name, dims, !fully_supported_model); + + Ort::SessionOptions session_options; + std::unordered_map option_map{ + {onnxruntime::nv::provider_option_names::kUseExternalDataInitializer, std::to_string(test_param.external_initialzier_for_parser)}}; + auto ep = AppendTrtEtxEP(session_options, option_map); + + Ort::ModelCompilationOptions model_compile_options(*ort_env, session_options); + model_compile_options.SetEpContextEmbedMode(test_param.embed_mode); + + void* output_context = nullptr; + size_t output_context_size = 0; + std::vector input_onnx; + if (test_param.bytestream_io) { + input_onnx = readBinaryFile(model_name); + model_compile_options.SetInputModelFromBuffer(input_onnx.data(), input_onnx.size()); + model_compile_options.SetOutputModelBuffer(Ort::AllocatorWithDefaultOptions(), &output_context, &output_context_size); + } else { + model_compile_options.SetInputModelPath(model_name.c_str()); + model_compile_options.SetOutputModelPath(model_name_ctx.c_str()); + } + // AOT time + ASSERT_TRUE(Ort::CompileModel(*ort_env, model_compile_options).IsOK()); + + // JIT time + Ort::Session session_object{nullptr}; + if (test_param.bytestream_io) { + session_object = Ort::Session(*ort_env, output_context, output_context_size, session_options); + } else { + session_object = Ort::Session(*ort_env, model_name_ctx.c_str(), session_options); + } + auto io_binding = generate_io_binding(session_object); + Ort::RunOptions run_options; + 
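+  // Running the reloaded session exercises the compiled EP-context blob end to end
+  // (from the in-memory buffer or the _ctx.onnx file, depending on the test parameter).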
session_object.Run(run_options, io_binding); +} + +TEST_P(CompileApiTest, SmallModel) { + const auto& test_param = GetCompileParam(); + SmallModelTest(test_param, true); +} + +TEST_P(CompileApiTest, SmallSplitModel) { + const auto& test_param = GetCompileParam(); + SmallModelTest(test_param, false); +} + +TEST_P(CompileApiTest, LargeModel) { + const auto& test_param = GetCompileParam(); + // with embed mode == 1 the resulting file will be over the 2GB proto limit + if (test_param.embed_mode == 1) { + GTEST_SKIP(); + } + std::string test_name = test_param.to_string(); + PathString model_name = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + ".onnx"); + PathString external_data_name = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + ".onnx_data"); + PathString model_name_ctx = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + "_ctx.onnx"); + PathString model_name_ctx_data = path_utils::MakePathString("nv_execution_provider_compile_large_" + test_name + "_ctx.onnx_data"); + clearFileIfExists(model_name_ctx); + clearFileIfExists(model_name_ctx_data); + // This accelerates test iterations if the large model was already generated + if (!std::filesystem::exists(model_name) || !std::filesystem::exists(external_data_name)) { + CreateLargeLLMModel(model_name, external_data_name); + } + + Ort::SessionOptions session_options; + std::unordered_map option_map{ + {onnxruntime::nv::provider_option_names::kUseExternalDataInitializer, + std::to_string(test_param.bytestream_io || test_param.external_initialzier_for_parser)}}; + auto ep = AppendTrtEtxEP(session_options, option_map); + + Ort::ModelCompilationOptions model_compile_options(*ort_env, session_options); + model_compile_options.SetEpContextEmbedMode(test_param.embed_mode); + + void* output_context = nullptr; + size_t output_context_size = 0; + std::vector input_onnx, input_data; + std::vector file_names; + std::vector file_buffers; + std::vector lengths; + if (test_param.bytestream_io) { + input_onnx = readBinaryFile(model_name); + input_data = readBinaryFile(external_data_name); + file_names = {external_data_name}; + file_buffers = {input_data.data()}; + lengths = {input_data.size()}; + session_options.AddExternalInitializersFromFilesInMemory(file_names, file_buffers, lengths); + + model_compile_options.SetInputModelFromBuffer(input_onnx.data(), input_onnx.size()); + model_compile_options.SetOutputModelBuffer(Ort::AllocatorWithDefaultOptions(), &output_context, &output_context_size); + } else { + model_compile_options.SetInputModelPath(model_name.c_str()); + model_compile_options.SetOutputModelPath(model_name_ctx.c_str()); + model_compile_options.SetOutputModelExternalInitializersFile(model_name_ctx_data.c_str(), 1024); + } + + // AOT time + ASSERT_TRUE(Ort::CompileModel(*ort_env, model_compile_options).IsOK()); + + // JIT time + std::unique_ptr session; + if (test_param.bytestream_io) { + session = std::make_unique(*ort_env, output_context, output_context_size, session_options); + } else { + session = std::make_unique(*ort_env, model_name_ctx.c_str(), session_options); + } + + auto io_binding = generate_io_binding(*session); + Ort::RunOptions run_options; + session->Run(run_options, io_binding); +} + +INSTANTIATE_TEST_SUITE_P( + NvExecutionProviderTest, CompileApiTest, + ::testing::Values( + CompileParam{true, false}, + CompileParam{false, false}, + CompileParam{true, true}, + CompileParam{false, true}, + // test with external initializers for parser + 
CompileParam{true, true, true}, + CompileParam{true, false, true}), + [](const testing::TestParamInfo& info) { + return info.param.to_string(); + }); + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc index f0ce5c0b296ca..17182ab032f7a 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc @@ -3,18 +3,26 @@ // Licensed under the MIT License. // registration/selection is only supported on windows as there's no device discovery on other platforms -#ifdef _WIN32 #include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" #include +#include #include #include "core/session/onnxruntime_cxx_api.h" #include "test/util/include/api_asserts.h" +#include "core/graph/basic_types.h" +#include "core/graph/onnx_protobuf.h" +#include "core/graph/model_saving_options.h" +#include "test/util/include/scoped_env_vars.h" +#include "test/common/trt_op_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/framework/test_utils.h" namespace onnxruntime { namespace test { +#ifdef _WIN32 Utils::NvTensorRtRtxEpInfo Utils::nv_tensorrt_rtx_ep_info; @@ -51,8 +59,410 @@ void Utils::RegisterAndGetNvTensorRtRtxEp(Ort::Env& env, RegisteredEpDeviceUniqu c_api.UnregisterExecutionProviderLibrary(env, nv_tensorrt_rtx_ep_info.registration_name.c_str()); }); } +#endif // _WIN32 + +void CreateBaseModel(const PathString& model_name, + std::string graph_name, + std::vector dims, + bool add_fast_gelu, + ONNX_NAMESPACE::TensorProto_DataType dtype, + const PathString& external_initializer_file) { + onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + std::vector inputs; + std::vector outputs; + + // FLOAT tensor + ONNX_NAMESPACE::TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(dtype); + + for (auto dim : dims) { + float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } + ONNX_NAMESPACE::TypeProto dyn_float_tensor; + dyn_float_tensor.mutable_tensor_type()->set_elem_type(dtype); + + auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor); + auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor); + inputs.push_back(&input_arg_1); + inputs.push_back(&input_arg_2); + auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor); + outputs.push_back(&output_arg); + graph.AddNode("node_1", "Add", "node 1.", inputs, outputs); + + auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor); + inputs.clear(); + inputs.push_back(&output_arg); + inputs.push_back(&input_arg_3); + + auto& output_arg_2 = graph.GetOrCreateNodeArg("node_2_out_1", &float_tensor); + outputs.clear(); + outputs.push_back(&output_arg_2); + graph.AddNode("node_2", "Add", "node 2.", inputs, outputs); + + inputs.clear(); + inputs.push_back(&output_arg_2); + + if (add_fast_gelu) { + auto& output_arg_3 = graph.GetOrCreateNodeArg("node_3_out_1", &dyn_float_tensor); + outputs.clear(); + outputs.push_back(&output_arg_3); + + graph.AddNode("node_3", "FastGelu", "node 3.", inputs, outputs, + /* attributes */ nullptr, kMSDomain); + + inputs.clear(); + inputs.push_back(&output_arg_3); + } + + ONNX_NAMESPACE::TypeProto float_scalar; + float_scalar.mutable_tensor_type()->set_elem_type(dtype); + 
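+  // "S" is a scalar-like fourth graph input (shape [1], same dtype) feeding the final Add node.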
float_scalar.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + auto& input_scalar = graph.GetOrCreateNodeArg("S", &float_scalar); + inputs.push_back(&input_scalar); + + auto& output_arg_4 = graph.GetOrCreateNodeArg("O", &dyn_float_tensor); + + outputs.clear(); + outputs.push_back(&output_arg_4); + graph.AddNode("node_5", "Add", "node 5.", inputs, outputs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()); + if (!external_initializer_file.empty()) { + ModelSavingOptions save_options(128); + status = Model::SaveWithExternalInitializers(model, model_name, external_initializer_file, save_options); + } else { + status = Model::Save(model, model_name); + } + ASSERT_TRUE(status.IsOK()); +} + +// Helper to create large initializers +ONNX_NAMESPACE::TensorProto CreateLargeWeight( + const std::string& name, + ONNX_NAMESPACE::TensorProto_DataType dtype, + const std::vector& shape, + float scale = 0.02f) { + ONNX_NAMESPACE::TensorProto tensor; + tensor.set_name(name); + tensor.set_data_type(dtype); + for (auto d : shape) tensor.add_dims(d); + // Here we fill with random floats, but for real data, use your trained weights. + size_t total_size = 1; + for (int64_t d : shape) total_size *= d; + std::random_device rd; + std::default_random_engine rng(rd()); + if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + std::vector data(total_size); + std::normal_distribution dist(0.0f, scale); + for (auto& v : data) v = dist(rng); + tensor.set_raw_data(data.data(), total_size * sizeof(float)); + } else if (dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + std::vector data(total_size); + std::normal_distribution dist(0.0f, scale); + for (auto& v : data) v = MLFloat16(dist(rng)); + tensor.set_raw_data(data.data(), total_size * sizeof(MLFloat16)); + } else { + throw std::runtime_error("Unsupported data type for large weight"); + } + return tensor; +} + +// Helper to add a GroupQueryAttention node +onnxruntime::NodeArg& AddGroupQueryAttention( + onnxruntime::Graph& graph, + onnxruntime::NodeArg& query, + onnxruntime::NodeArg& key, + onnxruntime::NodeArg& value, + int batch_size, + int head_dim, + int seq_len, + int num_heads, + int kv_num_heads, + float scale, + ONNX_NAMESPACE::TensorProto_DataType dtype, + const std::string& node_name) { + // KV cache + ONNX_NAMESPACE::TypeProto key_type; + key_type.mutable_tensor_type()->set_elem_type(dtype); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(batch_size); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(kv_num_heads); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(seq_len); + key_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(head_dim); + auto& past_key = graph.GetOrCreateNodeArg(node_name + "_past_key", &key_type); + + ONNX_NAMESPACE::TypeProto value_type; + value_type.mutable_tensor_type()->set_elem_type(dtype); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(batch_size); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(kv_num_heads); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(seq_len); + value_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(head_dim); + auto& past_value = graph.GetOrCreateNodeArg(node_name + "_past_value", &value_type); + + // Output + auto& output = graph.GetOrCreateNodeArg(node_name + "_output", nullptr); + + // Create required initializers for GroupQueryAttention + 
ONNX_NAMESPACE::TensorProto seqlens_k_tensor; + seqlens_k_tensor.set_name(node_name + "_seqlens_k"); + seqlens_k_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + seqlens_k_tensor.add_dims(2); + seqlens_k_tensor.set_dims(0, batch_size); + seqlens_k_tensor.set_dims(0, 1); + seqlens_k_tensor.add_int32_data(seq_len - 1); // seqlens_k = total_sequence_length - 1 + graph.AddInitializedTensor(seqlens_k_tensor); + + ONNX_NAMESPACE::TensorProto total_seq_len_tensor; + total_seq_len_tensor.set_name(node_name + "_total_sequence_length"); + total_seq_len_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + total_seq_len_tensor.add_int32_data(seq_len); + graph.AddInitializedTensor(total_seq_len_tensor); + + // Get the initializers that were created for this node + auto* seqlens_k = graph.GetNodeArg(node_name + "_seqlens_k"); + auto* total_sequence_length = graph.GetNodeArg(node_name + "_total_sequence_length"); + + auto& present_value = graph.GetOrCreateNodeArg(node_name + "_present_value", nullptr); + auto& present_key = graph.GetOrCreateNodeArg(node_name + "_present_key", nullptr); + + // Inputs - GroupQueryAttention requires at least 7 inputs (query, key, value, past_key, past_value, seqlens_k, total_sequence_length) + std::vector inputs = { + &query, // 0: query + &key, // 1: key + &value, // 2: value + &past_key, // 3: past_key (optional) + &past_value, // 4: past_value (optional) + seqlens_k, // 5: seqlens_k (required) + total_sequence_length, // 6: total_sequence_length (required) + // nullptr, // 7: cos_cache (optional) + // nullptr, // 8: sin_cache (optional) + // nullptr, // 9: position_ids (optional) + // nullptr, // 10: attention_bias (optional) + // nullptr // 11: head_sink (optional) + }; + + // Attributes + NodeAttributes attrs; + ONNX_NAMESPACE::AttributeProto attr_heads; + attr_heads.set_name("num_heads"); + attr_heads.set_type(onnx::AttributeProto_AttributeType_INT); + attr_heads.set_i(num_heads); + attrs["num_heads"] = attr_heads; + ONNX_NAMESPACE::AttributeProto attr_kv_num_heads; + attr_kv_num_heads.set_name("kv_num_heads"); + attr_kv_num_heads.set_type(onnx::AttributeProto_AttributeType_INT); + attr_kv_num_heads.set_i(kv_num_heads); + attrs["kv_num_heads"] = attr_kv_num_heads; + ONNX_NAMESPACE::AttributeProto attr_scale; + attr_scale.set_name("scale"); + attr_scale.set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_scale.set_f(scale); + attrs["scale"] = attr_scale; + + // Register node + graph.AddNode( + node_name, + "GroupQueryAttention", + "GroupQueryAttention Node", + inputs, + {&output, &present_key, &present_value}, + &attrs, + "com.microsoft"); + + return output; +} + +void CreateLargeLLMModel(const PathString& model_path, const PathString& external_data_path) { + // Model parameters (example: 24 layers, 4096 hidden dim, 32 attention heads, 8 kv heads => GQA) + int batch_size = 1; + int num_layers = 32; + int hidden_dim = 2048; + int q_num_heads = 8; + int kv_num_heads = 1; // GQA: q_num_heads > kv_num_heads, and divisible. + int seq_length = 128; // Short, for demonstration. 
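+  // head_dim = hidden_dim / q_num_heads = 256; kv_num_heads = 1 makes this the multi-query
+  // (extreme grouped-query attention) configuration.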
+ int vocab_size = 32000; + auto dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + + // Set up model/graph + onnxruntime::Model model("LLM_With_GQA", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + // Input + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(dtype); + input_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(batch_size); + input_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(seq_length); + input_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(hidden_dim); + auto& input = graph.GetOrCreateNodeArg("input", &input_type); + + auto* current_arg = &input; + + // Repeated layers: [Attention + MLP] + for (int l = 0; l < num_layers; ++l) { + // KV cache - initialize with zeros for the first forward pass + int head_dim = hidden_dim / q_num_heads; + + // Split Q, K, V + auto& q_split = graph.GetOrCreateNodeArg("q_split_" + std::to_string(l), nullptr); + auto& k_split = graph.GetOrCreateNodeArg("k_split_" + std::to_string(l), nullptr); + auto& v_split = graph.GetOrCreateNodeArg("v_split_" + std::to_string(l), nullptr); + constexpr bool split = false; + if constexpr (split) { + // Attention weights (Q, K, V projections) + auto wqkv = CreateLargeWeight("wqkv_" + std::to_string(l), + dtype, {hidden_dim, hidden_dim * 3}); + graph.AddInitializedTensor(wqkv); + + // Q = input @ wq, K = input @ wk, V = input @ wv + auto& qkv_arg = graph.GetOrCreateNodeArg("qkv_" + std::to_string(l), nullptr); + graph.AddNode("QKV_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wqkv.name())}, {&qkv_arg}); + + NodeAttributes attrs_split; + ONNX_NAMESPACE::AttributeProto attr_split_axis; + attr_split_axis.set_name("axis"); + attr_split_axis.set_type(onnx::AttributeProto_AttributeType_INT); + attr_split_axis.set_i(-1); + attrs_split["axis"] = attr_split_axis; + ONNX_NAMESPACE::AttributeProto attr_split_num_outputs; + attr_split_num_outputs.set_name("num_outputs"); + attr_split_num_outputs.set_type(onnx::AttributeProto_AttributeType_INT); + attr_split_num_outputs.set_i(3); + attrs_split["num_outputs"] = attr_split_num_outputs; + graph.AddNode("Q_Split_" + std::to_string(l), "Split", "", {&qkv_arg}, {&q_split, &k_split, &v_split}, &attrs_split); + } else { + // Attention weights (Q, K, V projections) + auto wq = CreateLargeWeight("wq_" + std::to_string(l), + dtype, {hidden_dim, hidden_dim}); + graph.AddInitializedTensor(wq); + auto wk = CreateLargeWeight("wk_" + std::to_string(l), + dtype, {hidden_dim, head_dim * kv_num_heads}); + graph.AddInitializedTensor(wk); + auto wv = CreateLargeWeight("wv_" + std::to_string(l), + dtype, {hidden_dim, head_dim * kv_num_heads}); + graph.AddInitializedTensor(wv); + + // Q = input @ wq, K = input @ wk, V = input @ wv + graph.AddNode("Q_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wq.name())}, {&q_split}); + graph.AddNode("K_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wk.name())}, {&k_split}); + graph.AddNode("V_Linear_" + std::to_string(l), "MatMul", "", {current_arg, graph.GetNodeArg(wv.name())}, {&v_split}); + } + // Reshape Q, K, V + auto& q_reshaped = graph.GetOrCreateNodeArg("q_reshaped_" + std::to_string(l), nullptr); + auto& k_reshaped = graph.GetOrCreateNodeArg("k_reshaped_" + std::to_string(l), nullptr); + auto& v_reshaped = graph.GetOrCreateNodeArg("v_reshaped_" + std::to_string(l), nullptr); + + ONNX_NAMESPACE::TensorProto 
q_shape_tensor; + q_shape_tensor.set_name("q_shape_" + std::to_string(l)); + q_shape_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + q_shape_tensor.add_dims(3); + q_shape_tensor.add_int64_data(batch_size); + q_shape_tensor.add_int64_data(seq_length); + q_shape_tensor.add_int64_data(head_dim * q_num_heads); + graph.AddInitializedTensor(q_shape_tensor); + + ONNX_NAMESPACE::TensorProto k_shape_tensor; + k_shape_tensor.set_name("k_shape_" + std::to_string(l)); + k_shape_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + k_shape_tensor.add_dims(3); + k_shape_tensor.add_int64_data(batch_size); + k_shape_tensor.add_int64_data(seq_length); + k_shape_tensor.add_int64_data(head_dim * kv_num_heads); + graph.AddInitializedTensor(k_shape_tensor); + + ONNX_NAMESPACE::TensorProto v_shape_tensor; + v_shape_tensor.set_name("v_shape_" + std::to_string(l)); + v_shape_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + v_shape_tensor.add_dims(3); + v_shape_tensor.add_int64_data(batch_size); + v_shape_tensor.add_int64_data(seq_length); + v_shape_tensor.add_int64_data(head_dim * kv_num_heads); + graph.AddInitializedTensor(v_shape_tensor); + + graph.AddNode("Q_Reshape_" + std::to_string(l), "Reshape", "", {&q_split, graph.GetNodeArg(q_shape_tensor.name())}, {&q_reshaped}); + graph.AddNode("K_Reshape_" + std::to_string(l), "Reshape", "", {&k_split, graph.GetNodeArg(k_shape_tensor.name())}, {&k_reshaped}); + graph.AddNode("V_Reshape_" + std::to_string(l), "Reshape", "", {&v_split, graph.GetNodeArg(v_shape_tensor.name())}, {&v_reshaped}); + + // Replace standard attention with GQA + auto& attn_out = AddGroupQueryAttention( + graph, q_reshaped, k_reshaped, v_reshaped, + batch_size, head_dim, seq_length, q_num_heads, kv_num_heads, + 1.0f, dtype, + "GQA_" + std::to_string(l)); + + // Add an MLP block: (Linear + Activation + Linear) + auto w1 = CreateLargeWeight("mlp_w1_" + std::to_string(l), dtype, {hidden_dim, hidden_dim * 4}); + auto w2 = CreateLargeWeight("mlp_w2_" + std::to_string(l), dtype, {hidden_dim * 4, hidden_dim}); + graph.AddInitializedTensor(w1); + graph.AddInitializedTensor(w2); + + auto& mlp_hidden = graph.GetOrCreateNodeArg("mlp_hidden_" + std::to_string(l), nullptr); + graph.AddNode("MLP_1_" + std::to_string(l), "MatMul", "", {&attn_out, graph.GetNodeArg(w1.name())}, {&mlp_hidden}); + auto& relu_out = graph.GetOrCreateNodeArg("relu_" + std::to_string(l), nullptr); + graph.AddNode("Relu_" + std::to_string(l), "Relu", "", {&mlp_hidden}, {&relu_out}); + auto& mlp_out = graph.GetOrCreateNodeArg("mlp_out_" + std::to_string(l), nullptr); + graph.AddNode("MLP_2_" + std::to_string(l), "MatMul", "", {&relu_out, graph.GetNodeArg(w2.name())}, {&mlp_out}); + current_arg = &mlp_out; // For next layer. 
+ } + + // Final projection to vocab + auto w_logits = CreateLargeWeight("w_logits", + dtype, {hidden_dim, vocab_size}); + graph.AddInitializedTensor(w_logits); + auto& output = graph.GetOrCreateNodeArg("logits", nullptr); + graph.AddNode("Output_Linear", "MatMul", "", {current_arg, graph.GetNodeArg(w_logits.name())}, {&output}); + + // Validate, Write as large model with external data + auto status = graph.Resolve(); + if (!status.IsOK()) throw std::runtime_error(status.ErrorMessage()); + + onnxruntime::ModelSavingOptions save_options(128); + status = onnxruntime::Model::SaveWithExternalInitializers( + model, model_path, external_data_path, save_options); + if (!status.IsOK()) throw std::runtime_error(status.ErrorMessage()); +} + +Ort::IoBinding generate_io_binding( + Ort::Session& session, + std::map> shape_overwrites, + OrtAllocator* allocator) { + Ort::IoBinding binding(session); + auto default_allocator = Ort::AllocatorWithDefaultOptions(); + if (allocator == nullptr) { + allocator = default_allocator; + } + const OrtMemoryInfo* info; + Ort::ThrowOnError(Ort::GetApi().AllocatorGetInfo(allocator, &info)); + Ort::MemoryInfo mem_info(info->name, info->alloc_type, info->device.Id(), info->mem_type); + + for (int input_idx = 0; input_idx < int(session.GetInputCount()); ++input_idx) { + auto input_name = session.GetInputNameAllocated(input_idx, Ort::AllocatorWithDefaultOptions()); + auto full_tensor_info = session.GetInputTypeInfo(input_idx); + auto tensor_info = full_tensor_info.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + auto type = tensor_info.GetElementType(); + if (shape_overwrites.find(input_name.get()) == shape_overwrites.end()) { + for (auto& v : shape) { + if (v == -1) { + v = 1; + } + } + } else { + shape = shape_overwrites[input_name.get()]; + } + auto input_value = Ort::Value::CreateTensor(allocator, + shape.data(), + shape.size(), + type); + binding.BindInput(input_name.get(), input_value); + } + + for (int output_idx = 0; output_idx < int(session.GetOutputCount()); ++output_idx) { + auto output_name = session.GetOutputNameAllocated(output_idx, Ort::AllocatorWithDefaultOptions()); + binding.BindOutput(output_name.get(), mem_info); + } + return binding; +} } // namespace test } // namespace onnxruntime - -#endif // _WIN32 diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h index ef14d3cb382c0..0f011af8211ca 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h @@ -5,9 +5,21 @@ #include #include +#include +#include + +#include +#include +#include +#include +#include +#include -#include "core/session/onnxruntime_cxx_api.h" #include "core/graph/constants.h" +#include "core/common/path_string.h" +#include "core/framework/tensor.h" +#include "core/framework/ort_value.h" +#include "test/util/include/api_asserts.h" namespace onnxruntime { namespace test { @@ -17,7 +29,7 @@ using RegisteredEpDeviceUniquePtr = std::unique_ptr> converter; + return converter.to_bytes(path); +#else + return path.c_str(); +#endif +} + +[[maybe_unused]] static void clearFileIfExists(PathString path) { + if (std::filesystem::exists(path)) { + std::filesystem::remove(path); + } +} + +template +static void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, + const std::vector& expected_values) { + ASSERT_EQ(1, fetches.size()); + auto& rtensor = 
fetches.front().Get(); + TensorShape expected_shape(expected_dims); + ASSERT_EQ(expected_shape, rtensor.Shape()); + const std::vector found(rtensor.Data(), rtensor.Data() + expected_values.size()); + ASSERT_EQ(expected_values, found); +} + +/** + * Create a simple model with dynamic or non-dynamic input shape. + * \param model_name - model name + * \param graph_name - graph name + * \param dims - input dimensions + * \param add_fast_gelu - add FastGelu node which makes the whole model partition into TRT EP and CUDA EP subgraphs. + * \param external_initializer_file - file name to save external initializers to + * + * input: "X", "Y" and "Z" + * you can specify input dimensions, for example (1, 3, 2), (1, 2) or (1, -1, -1)). Note: -1 means the dimension is dynamic. + * All three inputs have the same dimensions. + * output: "M" + * + * "X" "Y" + * \ / + * "Z" Add + * \ / + * Add + * / + * Add (+ float scalar "S") + * / + * "O" + * + * or + * + * "X" "Y" + * \ / + * "Z" Add + * \ / + * Add + * / + * FastGelu (This node will be placed on CUDA EP) + * / + * * Add (+ float scalar "S") + * / + * "O" + */ +void CreateBaseModel(const PathString& model_name, + std::string graph_name, + std::vector dims, + bool add_fast_gelu = false, + ONNX_NAMESPACE::TensorProto_DataType dtype = ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + const PathString& external_initializer_file = {}); + +void CreateLargeLLMModel(const PathString& model_path, const PathString& external_data_path); + +Ort::IoBinding generate_io_binding( + Ort::Session& session, + std::map> shape_overwrites = {}, + OrtAllocator* allocator = nullptr); + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index a206644bc945e..74b37867b0060 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -232,7 +232,7 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) { so.AppendExecutionProvider("QNN", options); // Invalid! Adds CPU EP to session, but also disables CPU fallback. - Ort::Status status(OrtSessionOptionsAppendExecutionProvider_CPU(so, 1)); + so.AppendExecutionProvider_CPU(1); const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "constant_floats.onnx"; @@ -285,7 +285,7 @@ TEST_F(QnnHTPBackendTests, TestConvWithExternalData) { so.AppendExecutionProvider("QNN", options); - Ort::Status status(OrtSessionOptionsAppendExecutionProvider_CPU(so, 1)); + so.AppendExecutionProvider_CPU(1); const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "conv_qdq_external_ini.onnx"; diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc new file mode 100644 index 0000000000000..b349e0c40882f --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
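(Reviewer aside, not part of the patch.) With CreateBaseModel, CreateLargeLLMModel, and generate_io_binding now shared through test_nv_trt_rtx_ep_util.h above, a new NV TensorRT RTX test can be written along these lines. Illustrative only; it assumes the global ort_env used by the existing tests and the default helper arguments.

  TEST(NvExecutionProviderTest, ExampleUsingSharedHelpers) {
    PathString model_name = ORT_TSTR("nv_ep_shared_helper_example.onnx");
    CreateBaseModel(model_name, "example_graph", {1, 3, 2});

    Ort::SessionOptions so;
    so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
    Ort::Session session(*ort_env, model_name.c_str(), so);

    // generate_io_binding allocates each input (dynamic dims become 1) and binds the
    // outputs to the default allocator unless an explicit OrtAllocator* is passed.
    auto io_binding = generate_io_binding(session);
    Ort::RunOptions run_options;
    session.Run(run_options, io_binding);
  }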
+ +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) + +namespace { + +GetQDQTestCaseFn BuildLPBQGemmTestCase() { + return [](ModelTestBuilder& builder) -> void { + // Define the test case for LPBQGemm fusion here + const int64_t input_channels = 16; + const int64_t output_channels = 16; + const int64_t blocks_per_axis = 4; + const std::vector input_shape{1, input_channels}; + auto input_def = TestInputDef(input_shape, false, -0.5f, 0.5f); + NodeArg* input = MakeTestInput(builder, input_def); + + // QuantizeLinear for Activation + NodeArg* act_ql_output = builder.MakeIntermediate(); + NodeArg* act_ql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_ql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("QuantizeLinear", {input, act_ql_scale, act_ql_zero_point}, {act_ql_output}); + + // DequantizeLinear for Activation + NodeArg* act_dql_output = builder.MakeIntermediate(); + NodeArg* act_dql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_dql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("DequantizeLinear", {act_ql_output, act_dql_scale, act_dql_zero_point}, {act_dql_output}); + + // DequantizeLinear for Scale + NodeArg* scale_dql_input = builder.MakeInitializer({blocks_per_axis, output_channels}, 1, 15); + NodeArg* scale_dql_scale = builder.MakeInitializer({output_channels}, 0.01f, 0.02f); + std::vector dql_zero_points_data = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + NodeArg* scale_dql_zero_point = builder.Make1DInitializer(dql_zero_points_data); + NodeArg* scale_dql_output = builder.MakeIntermediate(); + Node& scale_dql = builder.AddNode("DequantizeLinear", {scale_dql_input, scale_dql_scale, scale_dql_zero_point}, {scale_dql_output}); + scale_dql.AddAttribute("axis", static_cast(1)); + + // QuantizeLinear for Weight + NodeArg* w_ql_input = builder.MakeInitializer({input_channels, output_channels}, -1.0f, 1.0f); + std::vector zero_points_data; + size_t num_storage_elems = blocks_per_axis * output_channels; + zero_points_data.resize(Int4x2::CalcNumInt4Pairs(num_storage_elems)); + for (size_t i = 0; i < num_storage_elems; ++i) { + size_t r = i >> 1; + size_t c = i & 0x1; + zero_points_data[r].SetElem(c, 0); + } + NodeArg* w_ql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_ql_output = builder.MakeIntermediate(); + Node& w_ql = builder.AddNode("QuantizeLinear", {w_ql_input, scale_dql_output, w_ql_zero_point}, {w_ql_output}); + w_ql.AddAttribute("axis", static_cast(0)); + w_ql.AddAttribute("block_size", static_cast(4)); + + // DequantizeLinear for Weight + NodeArg* w_dql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_dql_output = builder.MakeIntermediate(); + Node& w_dql = builder.AddNode("DequantizeLinear", {w_ql_output, scale_dql_output, w_dql_zero_point}, {w_dql_output}); + w_dql.AddAttribute("axis", static_cast(0)); + w_dql.AddAttribute("block_size", static_cast(4)); + + // Gemm + NodeArg* gemm_bias = builder.MakeInitializer({output_channels}, -1.0f, 1.0f); + NodeArg* gemm_output = builder.MakeIntermediate(); + 
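+    // Gemm consumes the QDQ activation and the block-quantized weight (whose per-block
+    // scales were dequantized per-channel above); this is the LPBQ pattern the QNN
+    // node-group fusion is expected to match.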
builder.AddNode("Gemm", {act_dql_output, w_dql_output, gemm_bias}, {gemm_output}); + + // QuantizeLinear for Output + NodeArg* output_ql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_ql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_ql_output = builder.MakeIntermediate(); + builder.AddNode("QuantizeLinear", {gemm_output, output_ql_scale, output_ql_zero_point}, {output_ql_output}); + + // DequantizeLinear for Output + NodeArg* output_dql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_dql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_dql_output = builder.MakeOutput(); + builder.AddNode("DequantizeLinear", {output_ql_output, output_dql_scale, output_dql_zero_point}, {output_dql_output}); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +#if defined(_WIN32) +// Graph fails to compose on ARM64 Windows since QNN 2.37.0 +TEST_F(QnnHTPBackendTests, DISABLED_LPBQGemmFusion) { +#else +TEST_F(QnnHTPBackendTests, LPBQGemmFusion) { +#endif + ProviderOptions provider_options = GetProviderOptions(); + RunQnnModelTest(BuildLPBQGemmTestCase(), + provider_options, + /*opset_version=*/21, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::Some, + /*fp32_abs_err=*/1e-2f, + /*log_severity =*/logging::Severity::kERROR, + /*verify_outputs=*/false); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc new file mode 100644 index 0000000000000..8f63ccd5f2cd1 --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
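+// Mirrors lpbqgemm_fusion_test.cc, but the block-quantized weight feeds a MatMul (no bias input),
+// covering the LPBQ MatMul fusion path on the QNN HTP backend.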
+ +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) + +namespace { + +GetQDQTestCaseFn BuildLPBQMatMulTestCase() { + return [](ModelTestBuilder& builder) -> void { + // Define the test case for LPBQGemm fusion here + const int64_t input_channels = 16; + const int64_t output_channels = 16; + const int64_t blocks_per_axis = 4; + const std::vector input_shape{1, input_channels}; + auto input_def = TestInputDef(input_shape, false, -0.5f, 0.5f); + NodeArg* input = MakeTestInput(builder, input_def); + + // QuantizeLinear for Activation + NodeArg* act_ql_output = builder.MakeIntermediate(); + NodeArg* act_ql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_ql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("QuantizeLinear", {input, act_ql_scale, act_ql_zero_point}, {act_ql_output}); + + // DequantizeLinear for Activation + NodeArg* act_dql_output = builder.MakeIntermediate(); + NodeArg* act_dql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_dql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("DequantizeLinear", {act_ql_output, act_dql_scale, act_dql_zero_point}, {act_dql_output}); + + // DequantizeLinear for Scale + NodeArg* scale_dql_input = builder.MakeInitializer({blocks_per_axis, output_channels}, 1, 16); + NodeArg* scale_dql_scale = builder.MakeInitializer({output_channels}, 0.01f, 0.02f); + std::vector dql_zero_points_data = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + NodeArg* scale_dql_zero_point = builder.Make1DInitializer(dql_zero_points_data); + NodeArg* scale_dql_output = builder.MakeIntermediate(); + Node& scale_dql = builder.AddNode("DequantizeLinear", {scale_dql_input, scale_dql_scale, scale_dql_zero_point}, {scale_dql_output}); + scale_dql.AddAttribute("axis", static_cast(1)); + + // QuantizeLinear for Weight + NodeArg* w_ql_input = builder.MakeInitializer({input_channels, output_channels}, -2.0f, 2.0f); + std::vector zero_points_data; + size_t num_storage_elems = blocks_per_axis * output_channels; + zero_points_data.resize(Int4x2::CalcNumInt4Pairs(num_storage_elems)); + for (size_t i = 0; i < num_storage_elems; ++i) { + size_t r = i >> 1; + size_t c = i & 0x1; + zero_points_data[r].SetElem(c, 0); + } + NodeArg* w_ql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_ql_output = builder.MakeIntermediate(); + Node& w_ql = builder.AddNode("QuantizeLinear", {w_ql_input, scale_dql_output, w_ql_zero_point}, {w_ql_output}); + w_ql.AddAttribute("axis", static_cast(0)); + w_ql.AddAttribute("block_size", static_cast(4)); + + // DequantizeLinear for Weight + NodeArg* w_dql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_dql_output = builder.MakeIntermediate(); + Node& w_dql = builder.AddNode("DequantizeLinear", {w_ql_output, scale_dql_output, w_dql_zero_point}, {w_dql_output}); + w_dql.AddAttribute("axis", static_cast(0)); + w_dql.AddAttribute("block_size", static_cast(4)); + + // MatMul + NodeArg* matmul_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", {act_dql_output, w_dql_output}, {matmul_output}); + + // 
QuantizeLinear for Output + NodeArg* output_ql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_ql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_ql_output = builder.MakeIntermediate(); + builder.AddNode("QuantizeLinear", {matmul_output, output_ql_scale, output_ql_zero_point}, {output_ql_output}); + + // DequantizeLinear for Output + NodeArg* output_dql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_dql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_dql_output = builder.MakeOutput(); + builder.AddNode("DequantizeLinear", {output_ql_output, output_dql_scale, output_dql_zero_point}, {output_dql_output}); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +#if defined(_WIN32) +// Graph fails to compose on ARM64 Windows since QNN 2.37.0 +TEST_F(QnnHTPBackendTests, DISABLED_LPBQMatMulFusion) { +#else +TEST_F(QnnHTPBackendTests, LPBQMatMulFusion) { +#endif + ProviderOptions provider_options = GetProviderOptions(); + RunQnnModelTest(BuildLPBQMatMulTestCase(), + provider_options, + /*opset_version=*/21, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::Some, + /*fp32_abs_err=*/1e-2f, + /*log_severity =*/logging::Severity::kERROR, + /*verify_outputs=*/false); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/where_htp_test.cc b/onnxruntime/test/providers/qnn/where_htp_test.cc index bb3e229bbc9f8..95a9f3dac9cb7 100644 --- a/onnxruntime/test/providers/qnn/where_htp_test.cc +++ b/onnxruntime/test/providers/qnn/where_htp_test.cc @@ -86,7 +86,8 @@ static void RunWhereQDQTest(const TestInputDef& condition_def, } // Check that QNN compiles DQ -> Where -> Q as a single unit. -TEST_F(QnnHTPBackendTests, WhereQDQU8) { +// Fails since QNN 2.37.1: Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_WhereQDQU8) { RunWhereQDQTest(TestInputDef({4, 3, 2}, false, {true, false, true, false, true, false, true, false, true, false, true, false, @@ -99,7 +100,8 @@ TEST_F(QnnHTPBackendTests, WhereQDQU8) { // Check that QNN compiles DQ -> Where -> Q as a single unit. // Check QNN Where works with broadcast -TEST_F(QnnHTPBackendTests, WhereBroadcastU8) { +// Fails since QNN 2.37.1: Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_WhereBroadcastU8) { RunWhereQDQTest(TestInputDef({2}, false, {true, false}), TestInputDef({4, 3, 2}, true, -2.0f, 2.0f), TestInputDef({1}, true, {3.0f}), diff --git a/onnxruntime/test/shared_lib/test_allocator.cc b/onnxruntime/test/shared_lib/test_allocator.cc index 29f3dfad0f11d..bf9e54e8b3c7b 100644 --- a/onnxruntime/test/shared_lib/test_allocator.cc +++ b/onnxruntime/test/shared_lib/test_allocator.cc @@ -45,12 +45,10 @@ TEST(CApiTest, DefaultAllocator) { TEST(CApiTest, CustomAllocator) { constexpr PATH_TYPE model_path = TSTR("testdata/mul_1.onnx"); - const auto& api = Ort::GetApi(); - // Case 1: Register a custom allocator. 
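+ // The Ort::Env C++ wrappers used below throw Ort::Exception on failure, so no explicit status checks are needed.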
{ MockedOrtAllocator mocked_allocator; - ASSERT_TRUE(api.RegisterAllocator(*ort_env, &mocked_allocator) == nullptr); + ort_env->RegisterAllocator(&mocked_allocator); Ort::SessionOptions session_options; session_options.AddConfigEntry("session.use_env_allocators", "1"); @@ -62,14 +60,14 @@ TEST(CApiTest, CustomAllocator) { ASSERT_EQ(mocked_allocator.NumAllocations(), std::stoll(stats.GetValue("NumAllocs"))); ASSERT_EQ(mocked_allocator.NumReserveAllocations(), std::stoll(stats.GetValue("NumReserves"))); - ASSERT_TRUE(api.UnregisterAllocator(*ort_env, mocked_allocator.Info()) == nullptr); + ort_env->UnregisterAllocator(mocked_allocator.Info()); } // Case 2: Register a custom allocator with an older API version which does not support GetStats. { MockedOrtAllocator mocked_allocator; mocked_allocator.version = 22; - ASSERT_TRUE(api.RegisterAllocator(*ort_env, &mocked_allocator) == nullptr); + ort_env->RegisterAllocator(&mocked_allocator); Ort::SessionOptions session_options; session_options.AddConfigEntry("session.use_env_allocators", "1"); @@ -81,7 +79,7 @@ TEST(CApiTest, CustomAllocator) { auto stats = allocator.GetStats(); ASSERT_EQ(0, stats.GetKeyValuePairs().size()); - ASSERT_TRUE(api.UnregisterAllocator(*ort_env, mocked_allocator.Info()) == nullptr); + ort_env->UnregisterAllocator(mocked_allocator.Info()); } } #endif diff --git a/onnxruntime/test/shared_lib/test_data_copy.cc b/onnxruntime/test/shared_lib/test_data_copy.cc index 2294bb8d6fdff..e7d9d7715092b 100644 --- a/onnxruntime/test/shared_lib/test_data_copy.cc +++ b/onnxruntime/test/shared_lib/test_data_copy.cc @@ -31,9 +31,6 @@ using StreamUniquePtr = std::unique_ptr<OrtSyncStream, std::function<void(OrtSyncStream*)>>; - const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - #ifdef _WIN32 std::string cuda_lib = "onnxruntime_providers_cuda.dll"; #else @@ -47,10 +44,10 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { // register the provider bridge based CUDA EP so allocator and data transfer is available // not all the CIs have the provider library in the expected place so we allow for that const char* ep_registration_name = "ORT CUDA"; - ASSERT_ORTSTATUS_OK(api->RegisterExecutionProviderLibrary(env, ep_registration_name, - ORT_TSTR("onnxruntime_providers_cuda"))); + ort_env->RegisterExecutionProviderLibrary(ep_registration_name, + ORT_TSTR("onnxruntime_providers_cuda")); - const OrtEpDevice* cuda_device = nullptr; + Ort::ConstEpDevice cuda_device{nullptr}; for (const auto& ep_device : ort_env->GetEpDevices()) { std::string vendor{ep_device.EpVendor()}; std::string name = {ep_device.EpName()}; @@ -70,13 +67,11 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { // we pass in the CUDA cudaStream_t from the OrtSyncStream via provider options so need to create it upfront. // in the future the stream should be an input to the Session Run. - OrtSyncStream* stream = nullptr; - StreamUniquePtr stream_ptr; + Ort::SyncStream stream{nullptr}; if (use_streams) { - ASSERT_ORTSTATUS_OK(api->CreateSyncStreamForEpDevice(cuda_device, /*options*/ nullptr, &stream)); - stream_ptr = StreamUniquePtr(stream, [api](OrtSyncStream* stream) { api->ReleaseSyncStream(stream); }); + stream = cuda_device.CreateSyncStream(); - size_t stream_addr = reinterpret_cast<size_t>(api->SyncStream_GetHandle(stream)); + size_t stream_addr = reinterpret_cast<size_t>(stream.GetHandle()); options.AddConfigEntry("ep.cudaexecutionprovider.user_compute_stream", std::to_string(stream_addr).c_str()); // we explicitly specify user_compute_stream, so why do we also need to set has_user_compute_stream?
options.AddConfigEntry("ep.cudaexecutionprovider.has_user_compute_stream", "1"); @@ -87,24 +82,27 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { size_t num_inputs = session.GetInputCount(); // find the input location so we know which inputs can be provided on device. - std::vector input_locations; - input_locations.resize(num_inputs, nullptr); - ASSERT_ORTSTATUS_OK(api->SessionGetMemoryInfoForInputs(session, input_locations.data(), num_inputs)); + auto input_locations = session.GetMemoryInfoForInputs(); + ASSERT_EQ(session.GetInputCount(), input_locations.size()); + + // Testing coverage + auto input_ep_devices = session.GetEpDeviceForInputs(); + ASSERT_EQ(session.GetInputCount(), input_ep_devices.size()); + + // This is for testing + auto output_locations = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(session.GetOutputCount(), output_locations.size()); std::vector cpu_tensors; // info for device copy - std::vector src_tensor_ptrs; - std::vector dst_tensor_ptrs; - - // values we'll call Run with - std::vector input_tensors; + std::vector device_tensors; ASSERT_EQ(num_inputs, 1); // create cpu based input data. Ort::AllocatorWithDefaultOptions cpu_allocator; - std::vector shape{1, 1, 28, 28}; + constexpr const std::array shape{1, 1, 28, 28}; std::vector input_data(28 * 28, 0.5f); Ort::Value input_value = Ort::Value::CreateTensor(cpu_allocator.GetInfo(), input_data.data(), input_data.size(), @@ -112,15 +110,13 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { cpu_tensors.push_back(std::move(input_value)); for (size_t idx = 0; idx < num_inputs; ++idx) { - const OrtMemoryInfo* mem_info = input_locations[idx]; - OrtDeviceMemoryType mem_type = api->MemoryInfoGetDeviceMemType(mem_info); - OrtMemoryInfoDeviceType device_type; - api->MemoryInfoGetDeviceType(mem_info, &device_type); + auto mem_info = input_locations[idx]; + OrtDeviceMemoryType mem_type = mem_info.GetDeviceMemoryType(); + OrtMemoryInfoDeviceType device_type = mem_info.GetDeviceType(); if (device_type == OrtMemoryInfoDeviceType_GPU && mem_type == OrtDeviceMemoryType_DEFAULT) { // copy to device - OrtAllocator* allocator = nullptr; - ASSERT_ORTSTATUS_OK(api->GetSharedAllocator(env, mem_info, &allocator)); + auto allocator = ort_env->GetSharedAllocator(mem_info); // allocate new on-device memory auto src_shape = cpu_tensors[idx].GetTensorTypeAndShapeInfo().GetShape(); @@ -137,18 +133,12 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &value); */ - src_tensor_ptrs.push_back(cpu_tensors[idx]); - dst_tensor_ptrs.push_back(device_value); - input_tensors.push_back(std::move(device_value)); - } else { - // input is on CPU accessible memory. move to input_tensors - input_tensors.push_back(std::move(cpu_tensors[idx])); + device_tensors.push_back(std::move(device_value)); } } - if (!src_tensor_ptrs.empty()) { - ASSERT_ORTSTATUS_OK(api->CopyTensors(env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), stream, - src_tensor_ptrs.size())); + if (!device_tensors.empty()) { + ASSERT_CXX_ORTSTATUS_OK(ort_env->CopyTensors(cpu_tensors, device_tensors, stream)); // Stream support is still a work in progress. // @@ -160,18 +150,19 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { // iobinding.SynchronizeInputs(); // this doesn't actually require any bound inputs } - std::vector input_names = {"Input3"}; - std::vector output_names = {"Plus214_Output_0"}; + const auto& input_tensors = (!device_tensors.empty()) ? 
device_tensors : cpu_tensors; + + constexpr const std::array input_names = {"Input3"}; + constexpr const std::array output_names = {"Plus214_Output_0"}; Ort::Value output; session.Run(Ort::RunOptions{}, input_names.data(), input_tensors.data(), input_tensors.size(), output_names.data(), &output, 1); - const float* results = nullptr; - ASSERT_ORTSTATUS_OK(api->GetTensorData(output, reinterpret_cast(&results))); + const float* results = output.GetTensorData(); // expected results from the CPU EP. can check/re-create by running with PREFER_CPU. - std::vector expected = { + constexpr const std::array expected = { -0.701670527f, -0.583666623f, 0.0480501056f, @@ -192,7 +183,7 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { run_test(/*use_streams*/ true); run_test(/*use_streams*/ false); - ASSERT_ORTSTATUS_OK(api->UnregisterExecutionProviderLibrary(env, ep_registration_name)); + ort_env->UnregisterExecutionProviderLibrary(ep_registration_name); } #endif // USE_CUDA diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 56cc234a63832..786c0ba713b85 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1973,7 +1973,7 @@ static bool CreateSessionWithQnnEpAndQnnHtpSharedMemoryAllocator(PATH_TYPE model TEST(CApiTest, get_allocator_cpu) { Ort::SessionOptions session_options; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); + session_options.AppendExecutionProvider_CPU(1); Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault); Ort::Allocator cpu_allocator(session, info_cpu); @@ -2018,8 +2018,7 @@ TEST(CApiTest, get_allocator_cpu) { #ifdef USE_CUDA TEST(CApiTest, get_allocator_cuda) { Ort::SessionOptions session_options; - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); @@ -2101,7 +2100,7 @@ TEST(CApiTest, get_allocator_qnn_htp_shared) { TEST(CApiTest, io_binding) { Ort::SessionOptions session_options; - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); + session_options.AppendExecutionProvider_CPU(1); Ort::Session session(*ort_env, MODEL_URI, session_options); Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault); @@ -2175,11 +2174,10 @@ TEST(CApiTest, io_binding) { TEST(CApiTest, io_binding_cuda) { Ort::SessionOptions session_options; #ifdef USE_TENSORRT - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, 0)); + session_options.AppendExecutionProvider_TensorRT({}); #else - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); - session_options.AppendExecutionProvider_CUDA_V2(*options); + Ort::CUDAProviderOptions cuda_options; + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); #endif Ort::Session session(*ort_env, MODEL_URI, session_options); @@ -2376,35 +2374,25 @@ TEST(CApiTest, io_binding_qnn_htp_shared) { #if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) || defined(USE_DML) TEST(CApiTest, basic_cuda_graph) { - const auto& api = Ort::GetApi(); + [[maybe_unused]] const 
auto& api = Ort::GetApi(); Ort::SessionOptions session_options; #if defined(USE_TENSORRT) // Enable cuda graph in TRT provider option. - OrtTensorRTProviderOptionsV2* trt_options; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr - rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); - std::vector keys{"trt_cuda_graph_enable"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); + Ort::TensorRTProviderOptions trt_options; + std::unordered_map trt_options_map = {{"trt_cuda_graph_enable", + "1"}}; + trt_options.Update(trt_options_map); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options), - rel_trt_options.get()) == nullptr); #elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_cuda_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); + Ort::CUDAProviderOptions cuda_options; + std::unordered_map options_map = {{"enable_cuda_graph", + "1"}}; + cuda_options.Update(options_map); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( - static_cast(session_options), - rel_cuda_options.get()) == nullptr); #elif defined(USE_ROCM) // Enable hip graph in rocm provider option. OrtROCMProviderOptions* rocm_options = nullptr; @@ -2699,7 +2687,7 @@ static void RunWithCudaGraphAnnotation(T& cg_data, } TEST(CApiTest, basic_cuda_graph_with_annotation) { - const auto& api = Ort::GetApi(); + [[maybe_unused]] const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; #ifdef USE_DML @@ -2712,17 +2700,11 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { Ort::MemoryInfo info_mem("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault); #elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_cuda_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); + Ort::CUDAProviderOptions cuda_options; + std::unordered_map options_map = {{"enable_cuda_graph", "1"}}; + cuda_options.Update(options_map); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( - static_cast(session_options), - rel_cuda_options.get()) == nullptr); Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); #elif defined(USE_ROCM) // Enable hip graph in rocm provider option. @@ -2771,21 +2753,15 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { #ifndef REDUCED_OPS_BUILD #if defined(USE_CUDA) || defined(USE_TENSORRT) TEST(CApiTest, cuda_graph_with_shape_nodes) { - const auto& api = Ort::GetApi(); + [[maybe_unused]] const auto& api = Ort::GetApi(); // Enable cuda graph in cuda provider option. 
- OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr - rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); - std::vector keys{"enable_cuda_graph"}; - std::vector values{"1"}; - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); + Ort::CUDAProviderOptions cuda_options; + const std::unordered_map options_map = {{"enable_cuda_graph", "1"}}; + cuda_options.Update(options_map); Ort::SessionOptions session_options; - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( - static_cast(session_options), - rel_cuda_options.get()) == nullptr); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); // Successful loading of the ONNX model with shape nodes with cuda graph feature enabled Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); @@ -3316,13 +3292,9 @@ TEST(CApiTest, model_metadata) { } TEST(CApiTest, get_available_providers) { - const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); - int len = 0; - char** providers; - ASSERT_EQ(g_ort->GetAvailableProviders(&providers, &len), nullptr); - ASSERT_GT(len, 0); - ASSERT_STREQ(providers[len - 1], "CPUExecutionProvider"); - ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr); + std::vector providers = Ort::GetAvailableProviders(); + ASSERT_GT(providers.size(), 0); + ASSERT_STREQ(providers.back().c_str(), "CPUExecutionProvider"); } TEST(CApiTest, get_available_providers_cpp) { @@ -3348,8 +3320,6 @@ TEST(CApiTest, get_build_info_string) { } TEST(CApiTest, TestSharedAllocators) { - OrtEnv* env_ptr = (OrtEnv*)(*ort_env); - // prepare inputs std::vector> inputs(1); auto& input = inputs.back(); @@ -3367,28 +3337,17 @@ TEST(CApiTest, TestSharedAllocators) { // Turn on sharing of the allocator between sessions session_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1"); - const auto& api = Ort::GetApi(); - // CASE 1: We test creating and registering an ORT-internal allocator implementation instance // for sharing between sessions { - OrtMemoryInfo* mem_info = nullptr; - ASSERT_TRUE(api.CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info) == nullptr); - std::unique_ptr rel_info(mem_info, api.ReleaseMemoryInfo); - - OrtArenaCfg* arena_cfg = nullptr; - ASSERT_TRUE(api.CreateArenaCfg(0, -1, -1, -1, &arena_cfg) == nullptr); - std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); + auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + Ort::ArenaCfg arena_cfg(0, -1, -1, -1); // This creates an ORT-internal allocator instance and registers it in the environment for sharing // NOTE: On x86 builds arenas are not supported and will default to using non-arena based allocator - ASSERT_TRUE(api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg) == nullptr); - + ort_env->CreateAndRegisterAllocator(mem_info, arena_cfg); // Registration is always a replace operation - std::unique_ptr status_releaser( - api.CreateAndRegisterAllocator(env_ptr, mem_info, arena_cfg), - api.ReleaseStatus); - ASSERT_TRUE(status_releaser.get() == nullptr); + ort_env->CreateAndRegisterAllocator(mem_info, arena_cfg); { // create session 1 @@ -3414,7 +3373,7 @@ TEST(CApiTest, TestSharedAllocators) { // Remove the registered shared allocator for part 2 of this test // where-in we will register a custom allocator for the same device. 
- ASSERT_TRUE(api.UnregisterAllocator(env_ptr, mem_info) == nullptr); + ort_env->UnregisterAllocator(mem_info); } // CASE 2: We test registering a custom allocator implementation @@ -3425,15 +3384,9 @@ TEST(CApiTest, TestSharedAllocators) { // need to be aligned for certain devices/build configurations/math libraries. // See docs/C_API.md for details. MockedOrtAllocator custom_allocator; - ASSERT_TRUE(api.RegisterAllocator(env_ptr, &custom_allocator) == nullptr); - + ort_env->RegisterAllocator(&custom_allocator); // Registration is always a replace operation - std::unique_ptr - status_releaser( - api.RegisterAllocator(env_ptr, &custom_allocator), - api.ReleaseStatus); - ASSERT_TRUE(status_releaser.get() == nullptr); - + ort_env->RegisterAllocator(&custom_allocator); { // Keep this scoped to destroy the underlying sessions after use // This should trigger frees in our custom allocator @@ -3472,7 +3425,7 @@ TEST(CApiTest, TestSharedAllocators) { // Remove the registered shared allocator from the global environment // (common to all tests) to prevent its accidental usage elsewhere - ASSERT_TRUE(api.UnregisterAllocator(env_ptr, custom_allocator.Info()) == nullptr); + ort_env->UnregisterAllocator(custom_allocator.Info()); // Ensure that the registered custom allocator was indeed used for both sessions // We should have seen 2 allocations per session (one for the sole initializer @@ -3488,22 +3441,15 @@ TEST(CApiTest, TestSharedAllocators) { } #ifdef USE_CUDA { - OrtMemoryInfo* cuda_meminfo = nullptr; - ASSERT_TRUE(api.CreateMemoryInfo("Cuda", OrtArenaAllocator, 0, OrtMemTypeDefault, &cuda_meminfo) == nullptr); - std::unique_ptr rel_info(cuda_meminfo, api.ReleaseMemoryInfo); + auto cuda_meminfo = Ort::MemoryInfo("Cuda", OrtArenaAllocator, 0, OrtMemTypeDefault); - OrtArenaCfg* arena_cfg = nullptr; - ASSERT_TRUE(api.CreateArenaCfg(0, -1, -1, -1, &arena_cfg) == nullptr); - std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); - - std::vector keys, values; - ASSERT_TRUE(api.CreateAndRegisterAllocatorV2(env_ptr, onnxruntime::kCudaExecutionProvider, cuda_meminfo, arena_cfg, keys.data(), values.data(), 0) == nullptr); + Ort::ArenaCfg arena_cfg(0, -1, -1, -1); + ort_env->CreateAndRegisterAllocatorV2(onnxruntime::kCudaExecutionProvider, + cuda_meminfo, {}, arena_cfg); // Registration is always a replace operation - std::unique_ptr status_releaser( - api.CreateAndRegisterAllocatorV2(env_ptr, onnxruntime::kCudaExecutionProvider, cuda_meminfo, arena_cfg, keys.data(), values.data(), 0), - api.ReleaseStatus); - ASSERT_TRUE(status_releaser.get() == nullptr); + ort_env->CreateAndRegisterAllocatorV2(onnxruntime::kCudaExecutionProvider, + cuda_meminfo, {}, arena_cfg); { // create session 1 @@ -3530,7 +3476,7 @@ TEST(CApiTest, TestSharedAllocators) { nullptr); } - ASSERT_TRUE(api.UnregisterAllocator(env_ptr, cuda_meminfo) == nullptr); + ort_env->UnregisterAllocator(cuda_meminfo); } #endif } @@ -3558,16 +3504,10 @@ TEST(CApiTest, TestSharingOfInitializerAndItsPrepackedVersion) { Ort::Value val = Ort::Value::CreateTensor(mem_info, data, data_len, shape, shape_len); session_options.AddInitializer("W", val); - const auto& api = Ort::GetApi(); - - OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr; - ASSERT_TRUE(api.CreatePrepackedWeightsContainer(&prepacked_weights_container) == nullptr); - std::unique_ptr - rel_prepacked_weights_container(prepacked_weights_container, api.ReleasePrepackedWeightsContainer); - auto default_allocator = std::make_unique(); // create session 1 (using model path) + 
Ort::PrepackedWeightsContainer prepacked_weights_container; Ort::Session session1(*ort_env, MATMUL_MODEL_URI, session_options, prepacked_weights_container); RunSession(default_allocator.get(), session1, @@ -3664,12 +3604,11 @@ TEST(CApiTest, AllocateInitializersFromNonArenaMemory) { Ort::SessionOptions session_options; #ifdef USE_CUDA - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); #else // arena is enabled but the sole initializer will still be allocated from non-arena memory - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); + session_options.AppendExecutionProvider_CPU(1); #endif // disable using arena for the sole initializer in the model @@ -3685,16 +3624,18 @@ TEST(CApiTest, AllocateInitializersFromNonArenaMemory) { // Usage example showing how to use CreateArenaCfgV2() API to configure the default memory CUDA arena allocator TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) { - const auto& api = Ort::GetApi(); - Ort::SessionOptions session_options; - const char* keys[] = {"max_mem", "arena_extend_strategy", "initial_chunk_size_bytes", "max_dead_bytes_per_chunk", "initial_growth_chunk_size_bytes", "max_power_of_two_extend_bytes"}; - const size_t values[] = {0 /*let ort pick default max memory*/, 0, 1024, 0, 256, 1L << 24}; + const std::unordered_map config_map = { + {"max_mem", 0}, // let ort pick default max memory + {"arena_extend_strategy", 0}, // use default extend strategy + {"initial_chunk_size_bytes", 1024}, // initial chunk size in bytes + {"max_dead_bytes_per_chunk", 0}, // no dead bytes per chunk + {"initial_growth_chunk_size_bytes", 256}, // initial growth chunk size in bytes + {"max_power_of_two_extend_bytes", 1L << 24} // max power of two extend bytes + }; - OrtArenaCfg* arena_cfg = nullptr; - ASSERT_TRUE(api.CreateArenaCfgV2(keys, values, 5, &arena_cfg) == nullptr); - std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); + Ort::ArenaCfg arena_cfg(config_map); OrtCUDAProviderOptions cuda_provider_options = CreateDefaultOrtCudaProviderOptionsWithCustomStream(nullptr); cuda_provider_options.default_memory_arena_cfg = arena_cfg; @@ -3718,24 +3659,16 @@ TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) { #ifdef USE_TENSORRT TEST(TensorrtExecutionProviderTest, ShapeTensorTest) { - const auto& api = Ort::GetApi(); - // Test input tensor which is shape tensor with explicit trt profile shapes Ort::SessionOptions session_options; - OrtTensorRTProviderOptionsV2* trt_options; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr - rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); - - const char* trt_profile_min_shapes = "data:2x2,shape:4x1"; - const char* trt_profile_max_shapes = "data:2x2,shape:4x1"; - const char* trt_profile_opt_shapes = "data:2x2,shape:4x1"; - std::vector keys{"trt_profile_min_shapes", "trt_profile_max_shapes", "trt_profile_opt_shapes"}; - std::vector values{trt_profile_min_shapes, trt_profile_max_shapes, trt_profile_opt_shapes}; - ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options), - rel_trt_options.get()) == nullptr); + Ort::TensorRTProviderOptions trt_options; + + std::unordered_map trt_options_map = { 
+ {"trt_profile_min_shapes", "data:2x2,shape:4x1"}, + {"trt_profile_max_shapes", "data:2x2,shape:4x1"}, + {"trt_profile_opt_shapes", "data:2x2,shape:4x1"}}; + trt_options.Update(trt_options_map); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); auto model_path = ORT_TSTR("testdata/trt_reshape.onnx"); @@ -3758,37 +3691,24 @@ TEST(TensorrtExecutionProviderTest, ShapeTensorTest) { // Test input tensor which is shape tensor with implicit trt profile shapes Ort::SessionOptions session_options_2; - OrtTensorRTProviderOptionsV2* trt_options_2; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options_2) == nullptr); - std::unique_ptr - rel_trt_options_2(trt_options_2, api.ReleaseTensorRTProviderOptions); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options_2), - rel_trt_options_2.get()) == nullptr); + Ort::TensorRTProviderOptions trt_options_2; + session_options_2.AppendExecutionProvider_TensorRT_V2(*trt_options_2); Ort::Session session_2(*ort_env, model_path, session_options_2); - session_2.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names)); + session_2.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, std::size(output_names)); } TEST(CApiTest, TestExternalCUDAStreamWithIOBinding) { - const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; - - OrtTensorRTProviderOptionsV2* trt_options; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr - rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); + Ort::TensorRTProviderOptions trt_options; // updating provider option with user provided compute stream cudaStream_t compute_stream = nullptr; - void* user_compute_stream = nullptr; cudaStreamCreate(&compute_stream); - ASSERT_TRUE(api.UpdateTensorRTProviderOptionsWithValue(rel_trt_options.get(), "user_compute_stream", compute_stream) == nullptr); - ASSERT_TRUE(api.GetTensorRTProviderOptionsByName(rel_trt_options.get(), "user_compute_stream", &user_compute_stream) == nullptr); + trt_options.UpdateWithValue("user_compute_stream", compute_stream); + void* user_compute_stream = trt_options.GetOptionByName("user_compute_stream"); ASSERT_TRUE(user_compute_stream == (void*)compute_stream); - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( - static_cast(session_options), - rel_trt_options.get()) == nullptr); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); Ort::Session session(*ort_env, MODEL_URI, session_options); Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); @@ -3901,36 +3821,30 @@ class CApiTensorRTTest : public testing::Test, public ::testing::WithParamInterf TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) { std::string param = GetParam(); size_t pos = param.find("="); - std::string option_name = param.substr(0, pos); - std::string option_value = param.substr(pos + 1); + const std::string option_name = param.substr(0, pos); + const std::string option_value = param.substr(pos + 1); ASSERT_NE(pos, std::string::npos); - const auto& api = Ort::GetApi(); - OrtTensorRTProviderOptionsV2* trt_options; - OrtAllocator* allocator; - char* trt_options_str; - ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); - std::unique_ptr rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); - const char* engine_cache_path = "./trt_engine_folder"; - 
std::vector keys{"device_id", "has_user_compute_stream", "trt_fp16_enable", "trt_int8_enable", "trt_engine_cache_enable", - "trt_engine_cache_path", option_name.c_str()}; - - std::vector values{"0", "0", "1", "0", "1", - engine_cache_path, option_value.c_str()}; + Ort::TensorRTProviderOptions trt_options; + std::unordered_map trt_options_map = { + {"device_id", "0"}, + {"has_user_compute_stream", "0"}, + {"trt_fp16_enable", "1"}, + {"trt_int8_enable", "0"}, + {"trt_engine_cache_enable", "1"}, + {"trt_engine_cache_path", engine_cache_path}, + {option_name, option_value}}; - ASSERT_TRUE(api.UpdateTensorRTProviderOptions(rel_trt_options.get(), keys.data(), values.data(), keys.size()) == nullptr); + trt_options.Update(trt_options_map); - ASSERT_TRUE(api.GetAllocatorWithDefaultOptions(&allocator) == nullptr); - ASSERT_TRUE(api.GetTensorRTProviderOptionsAsString(rel_trt_options.get(), allocator, &trt_options_str) == nullptr); - std::string s(trt_options_str); - ASSERT_TRUE(s.find(engine_cache_path) != std::string::npos); - ASSERT_TRUE(s.find(param.c_str()) != std::string::npos); - ASSERT_TRUE(api.AllocatorFree(allocator, (void*)trt_options_str) == nullptr); + std::string trt_options_str = trt_options.GetTensorRTProviderOptionsAsString(); + ASSERT_NE(trt_options_str.find(engine_cache_path), std::string::npos); + ASSERT_NE(trt_options_str.find(param), std::string::npos); Ort::SessionOptions session_options; - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(static_cast(session_options), rel_trt_options.get()) == nullptr); + session_options.AppendExecutionProvider_TensorRT_V2(*trt_options); // simple inference test // prepare inputs @@ -3973,40 +3887,30 @@ INSTANTIATE_TEST_SUITE_P(CApiTensorRTTest, CApiTensorRTTest, // This test uses CreateCUDAProviderOptions/UpdateCUDAProviderOptions/UpdateCUDAProviderOptionsWithValue APIs to configure and create a CUDA Execution Provider instance TEST(CApiTest, TestConfigureCUDAProviderOptions) { - const auto& api = Ort::GetApi(); - - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); - std::unique_ptr rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); + Ort::CUDAProviderOptions cuda_options; // Only test updating OrtCUDAProviderOptionsV2 instance with user provided compute stream not running the inference cudaStream_t compute_stream = nullptr; void* user_compute_stream = nullptr; cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking); - ASSERT_TRUE(api.UpdateCUDAProviderOptionsWithValue(rel_cuda_options.get(), "user_compute_stream", compute_stream) == nullptr); - ASSERT_TRUE(api.GetCUDAProviderOptionsByName(rel_cuda_options.get(), "user_compute_stream", &user_compute_stream) == nullptr); + cuda_options.UpdateWithValue("user_compute_stream", compute_stream); + user_compute_stream = cuda_options.GetOptionByName("user_compute_stream"); ASSERT_TRUE(user_compute_stream == (void*)compute_stream); cudaStreamDestroy(compute_stream); - std::vector keys{ - "device_id", "has_user_compute_stream", "gpu_mem_limit", "arena_extend_strategy", - "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d"}; - - std::vector values{ - "0", "0", "1024", "kSameAsRequested", - "DEFAULT", "1", "1"}; - - ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 6) == nullptr); + std::unordered_map cuda_options_map = { + {"device_id", "0"}, + {"has_user_compute_stream", "0"}, + 
{"gpu_mem_limit", "1024"}, + {"arena_extend_strategy", "kSameAsRequested"}, + {"cudnn_conv_algo_search", "DEFAULT"}, + {"do_copy_in_default_stream", "1"}, + {"cudnn_conv_use_max_workspace", "1"}, + {"cudnn_conv1d_pad_to_nc1d", "1"}}; - OrtAllocator* allocator; - ASSERT_TRUE(api.GetAllocatorWithDefaultOptions(&allocator) == nullptr); + cuda_options.Update(cuda_options_map); - char* cuda_options_str = nullptr; - ASSERT_TRUE(api.GetCUDAProviderOptionsAsString(rel_cuda_options.get(), allocator, &cuda_options_str) == nullptr); - std::string s; - if (cuda_options_str != nullptr) { - s = std::string(cuda_options_str, strnlen(cuda_options_str, 2048)); - } + std::string s = cuda_options.GetCUDAProviderOptionsAsString(); ASSERT_TRUE(s.find("device_id=0") != std::string::npos); ASSERT_TRUE(s.find("gpu_mem_limit=1024") != std::string::npos); ASSERT_TRUE(s.find("arena_extend_strategy=kSameAsRequested") != std::string::npos); @@ -4015,10 +3919,8 @@ TEST(CApiTest, TestConfigureCUDAProviderOptions) { ASSERT_TRUE(s.find("cudnn_conv_use_max_workspace=1") != std::string::npos); ASSERT_TRUE(s.find("cudnn_conv1d_pad_to_nc1d") != std::string::npos); - ASSERT_TRUE(api.AllocatorFree(allocator, (void*)cuda_options_str) == nullptr); - Ort::SessionOptions session_options; - ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2(static_cast(session_options), rel_cuda_options.get()) == nullptr); + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); // if session creation passes, model loads fine std::basic_string model_uri = MODEL_URI; @@ -4117,9 +4019,8 @@ TEST(CApiTest, GitHubIssue10179) { auto load_model_thread_fn = []() { try { const auto* model_path = MODEL_URI; - Ort::SessionOptions session_options{}; - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::SessionOptions session_options; + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; } catch (const std::exception& e) { @@ -4150,8 +4051,7 @@ TEST(CApiTest, GitHubIssue10179) { TEST(CApiTest, TestCudaMemcpyToHostWithSequenceTensors) { const auto* model_path = SEQUENCE_MODEL_URI_2; Ort::SessionOptions session_options{}; - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; session_options.AppendExecutionProvider_CUDA_V2(*options); Ort::Session session{*ort_env, model_path, session_options}; diff --git a/onnxruntime/test/shared_lib/test_model_loading.cc b/onnxruntime/test/shared_lib/test_model_loading.cc index 89b12ec61649e..7268c351877f3 100644 --- a/onnxruntime/test/shared_lib/test_model_loading.cc +++ b/onnxruntime/test/shared_lib/test_model_loading.cc @@ -60,8 +60,7 @@ TEST(CApiTest, model_from_array) { create_session(so); #ifdef USE_CUDA - OrtCUDAProviderOptionsV2* options; - Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&options)); + Ort::CUDAProviderOptions options; so.AppendExecutionProvider_CUDA_V2(*options); create_session(so); #endif diff --git a/onnxruntime/test/shared_lib/test_session_options.cc b/onnxruntime/test/shared_lib/test_session_options.cc index 3fbb294e1af49..d12a586f662ac 100644 --- a/onnxruntime/test/shared_lib/test_session_options.cc +++ b/onnxruntime/test/shared_lib/test_session_options.cc @@ -54,20 +54,15 @@ TEST(CApiTest, session_options_provider_interface_fail_add_openvino) { #if defined(USE_CUDA_PROVIDER_INTERFACE) // Test that loading CUDA EP when 
only the interface is built (but not the full EP) fails. TEST(CApiTest, session_options_provider_interface_fail_add_cuda) { - const OrtApi& api = Ort::GetApi(); Ort::SessionOptions session_options; - OrtCUDAProviderOptionsV2* cuda_options = nullptr; - Ort::Status status1 = Ort::Status{api.CreateCUDAProviderOptions(&cuda_options)}; - ASSERT_TRUE(status1.IsOK()); - - Ort::Status status2 = Ort::Status{api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, - cuda_options)}; - ASSERT_FALSE(status2.IsOK()); - EXPECT_EQ(status2.GetErrorCode(), ORT_FAIL); - EXPECT_THAT(status2.GetErrorMessage(), testing::HasSubstr("Failed to load")); - - api.ReleaseCUDAProviderOptions(cuda_options); + Ort::CUDAProviderOptions cuda_options; + try { + session_options.AppendExecutionProvider_CUDA_V2(*cuda_options); + FAIL() << "Appending CUDA options should have thrown an exception"; + } catch (const Ort::Exception& ex) { + ASSERT_THAT(ex.what(), testing::HasSubstr("Failed to load")); + } } #endif // defined(USE_CUDA_PROVIDER_INTERFACE) diff --git a/onnxruntime/test/util/include/api_asserts.h b/onnxruntime/test/util/include/api_asserts.h index 423135f96fbcd..0be3b8bbb0764 100644 --- a/onnxruntime/test/util/include/api_asserts.h +++ b/onnxruntime/test/util/include/api_asserts.h @@ -10,36 +10,29 @@ #include "core/session/onnxruntime_cxx_api.h" // asserts for the public API -#define ASSERT_ORTSTATUS_OK(function) \ - do { \ - OrtStatusPtr _tmp_status = (function); \ - ASSERT_EQ(_tmp_status, nullptr) << Ort::GetApi().GetErrorMessage(_tmp_status); \ - if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \ +#define ASSERT_ORTSTATUS_OK(function) \ + do { \ + Ort::Status _tmp_status{(function)}; \ + ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \ } while (false) -#define EXPECT_ORTSTATUS_OK(api, function) \ - do { \ - OrtStatusPtr _tmp_status = (api->function); \ - EXPECT_EQ(_tmp_status, nullptr) << Ort::GetApi().GetErrorMessage(_tmp_status); \ - if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \ +#define EXPECT_ORTSTATUS_OK(api, function) \ + do { \ + Ort::Status _tmp_status{(api->function)}; \ + EXPECT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \ } while (false) -#define ASSERT_ORTSTATUS_NOT_OK(api, function) \ - do { \ - OrtStatusPtr _tmp_status = (api->function); \ - ASSERT_NE(_tmp_status, nullptr); \ - if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \ +#define ASSERT_ORTSTATUS_NOT_OK(api, function) \ + do { \ + Ort::Status _tmp_status{(api->function)}; \ + ASSERT_FALSE(_tmp_status.IsOK()); \ } while (false) -#define EXPECT_ORTSTATUS_NOT_OK(api, function) \ - do { \ - OrtStatusPtr _tmp_status = (api->function); \ - EXPECT_NE(_tmp_status, nullptr); \ - if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \ +#define EXPECT_ORTSTATUS_NOT_OK(api, function) \ + do { \ + Ort::Status _tmp_status{(api->function)}; \ + EXPECT_FALSE(_tmp_status.IsOK()); \ } while (false) -#define ASSERT_CXX_ORTSTATUS_OK(function) \ - do { \ - Ort::Status _tmp_status = (function); \ - ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \ - } while (false) +#define ASSERT_CXX_ORTSTATUS_OK(function) \ + ASSERT_ORTSTATUS_OK(function) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 0d51f66df33aa..4e7e03af84302 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1511,8 +1511,8 @@ def adb_push(src, dest, **kwargs): def adb_shell(*args, **kwargs): return run_subprocess([sdk_tool_paths.adb, "shell", *args], **kwargs) - def
adb_install(*args, **kwargs): - return run_subprocess([sdk_tool_paths.adb, "install", *args], **kwargs) + def adb_logcat(*args, **kwargs): + return run_subprocess([sdk_tool_paths.adb, "logcat", *args], **kwargs) def run_adb_shell(cmd): # GCOV_PREFIX_STRIP specifies the depth of the directory hierarchy to strip and @@ -1538,6 +1538,17 @@ def run_adb_shell(cmd): ) context_stack.callback(android.stop_emulator, emulator_proc) + all_android_tests_passed = False + + def dump_logs_on_failure(): + if not all_android_tests_passed: + log.warning("Android test failed. Dumping logs.") + adb_logcat("-d") # dump logs + + context_stack.callback(dump_logs_on_failure) + + adb_logcat("-c") # clear logs + adb_push("testdata", device_dir, cwd=cwd) if is_linux() and os.path.exists("/data/onnx"): adb_push("/data/onnx", device_dir + "/test", cwd=cwd) @@ -1589,6 +1600,8 @@ def run_adb_shell(cmd): f"LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{device_dir} {device_dir}/onnxruntime_customopregistration_test" ) + all_android_tests_passed = True + def run_ios_tests(args, source_dir, config, cwd): is_targeting_iphone_simulator = "iphonesimulator" in args.apple_sysroot.lower() @@ -1695,8 +1708,10 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): run_ios_tests(args, source_dir, config, cwd) continue dll_path_list = [] - if args.use_tensorrt or args.use_nv_tensorrt_rtx: + if args.use_tensorrt: dll_path_list.append(os.path.join(args.tensorrt_home, "lib")) + if args.use_nv_tensorrt_rtx: + dll_path_list.append(os.path.join(args.tensorrt_rtx_home, "lib")) dll_path = None if len(dll_path_list) > 0: diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 91f35d2b54033..b062a3b64f6f3 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 028777756352d..40f24b1d2c886 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -55,7 +55,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml index a5eb2ad216998..12cf8349a5575 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml @@ -83,7 +83,7 @@ stages: artifactName: 'onnxruntime-android-qnn-aar' packageName: 'onnxruntime-android-qnn' #TODO: get this information from the setup stage - QnnSDKVersion: '2.36.1.250708' + QnnSDKVersion: '2.37.1.250807' - template: nuget/templates/test_win.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index 
52268af15e776..f22a26cec6d88 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -6,7 +6,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 526ed71df2006..ae2602c77d7a2 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index b4edf78c3b7bd..02aead3b3d3c7 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.36.1.250708 + default: 2.37.1.250807 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 7af5334793c30..a94ceea6354e5 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 - name: build_config displayName: Build Configuration @@ -77,4 +77,4 @@ extends: DoEsrp: ${{ parameters.DoEsrp }} ArtifactName: 'drop-nuget-qnn-arm64x' StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64x' - build_config: ${{ parameters.build_config }} \ No newline at end of file + build_config: ${{ parameters.build_config }} diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index f4a62208059c8..2ada8483ddec3 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' 
- default: 2.36.1.250708 + default: 2.37.1.250807 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 7370910eb1e28..be61f652f7fc5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.1.250708' + default: '2.37.1.250807' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index e4bfe20238770..c1720a2cac257 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.1.250708' + default: '2.37.1.250807' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index bf65b0c54cf27..40511ee871163 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -43,7 +43,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.36.1.250708 + default: 2.37.1.250807 - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 57703239fc594..73c774b9a45e9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.1.250708' + default: '2.37.1.250807' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index d2e401f3f6ab4..8c15fe111593f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.1.250708' + default: '2.37.1.250807' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml index b7fb8a51f28be..0ce6f3ec50a06 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.1.250708' + default: '2.37.1.250807' steps: - bash: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index 2168214527c91..3d662ffbb18dd 100644 --- 
+++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml
@@ -26,7 +26,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: 2.36.1.250708
+  default: 2.37.1.250807

 - name: is1ES
   displayName: 'Whether the pipeline is running in 1ES'
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml
index 3c2ef4741f049..09133499bc23f 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml
@@ -11,7 +11,7 @@ parameters:
 - name: QNN_SDK
   displayName: QNN SDK Version
   type: string
-  default: 2.36.1.250708
+  default: 2.37.1.250807

 - name: ENV_SETUP_SCRIPT
   type: string
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml
index c8d37457a1034..cd6a43a18991e 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml
@@ -7,7 +7,7 @@ parameters:
 - name: QNN_SDK
   displayName: QNN SDK Version
   type: string
-  default: 2.36.1.250708
+  default: 2.37.1.250807

 - name: ENV_SETUP_SCRIPT
   type: string
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
index e631e9d391a67..dd202270768af 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
@@ -7,7 +7,7 @@ parameters:
 - name: QNN_SDK
   displayName: QNN SDK Version
   type: string
-  default: 2.36.1.250708
+  default: 2.37.1.250807

 - name: ENV_SETUP_SCRIPT
   type: string
@@ -137,4 +137,4 @@ jobs:
     - script: |
         7z x *.whl
       workingDirectory: '$(Build.ArtifactStagingDirectory)'
-      displayName: 'unzip the package'
\ No newline at end of file
+      displayName: 'unzip the package'
diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml
index 52d9eb139fab7..eeb8709e0dea2 100644
--- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml
@@ -1,5 +1,5 @@
 parameters:
-  QnnSdk: '2.36.1.250708'
+  QnnSdk: '2.37.1.250807'
   build_config: 'RelWithDebInfo'
   IsReleaseBuild: false
   DoEsrp: false
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
index 66d1cd1687d99..a01e2bc921aea 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
@@ -33,7 +33,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: 2.36.1.250708
+  default: 2.37.1.250807

 jobs:
 - job: 'BUILD_QNN_EP'
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
index ffeb577547f69..c350ba2ce402c 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
@@ -33,7 +33,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: 2.36.1.250708
+  default: 2.37.1.250807

 jobs:
 - job: 'BUILD_QNN_EP'
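
For context, a minimal sketch of how a consuming pipeline could pass the bumped QNN SDK version into the download template changed above. The template path and the QnnSDKVersion parameter name come from this diff; the job name, pool image, and surrounding structure are illustrative assumptions, not code from the repository.

# Hypothetical consumer of templates/jobs/download_win_qnn_sdk.yml.
# Only the template path and parameter name are taken from the diff above.
jobs:
- job: Example_QNN_Build        # assumed job name
  pool:
    vmImage: windows-latest     # assumed pool
  steps:
  - template: templates/jobs/download_win_qnn_sdk.yml
    parameters:
      QnnSDKVersion: '2.37.1.250807'

Pipelines that do not override this parameter pick up the new 2.37.1.250807 default from the templates themselves, which is why every default in this diff is bumped in lockstep.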