diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 1bb7f219c9a45..f54f4a5a6f1ef 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -36,6 +36,7 @@ class GraphOptimizerRegistry; #include "core/framework/framework_provider_common.h" #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +#include "core/session/onnxruntime_c_api.h" struct OrtEpDevice; struct OrtRunOptions; @@ -322,6 +323,29 @@ class IExecutionProvider { virtual common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs); + /** + * Get the compatibility info for a compiled model. + * + * The execution provider determines this value, which denotes the compatibility of the compiled model with the EP. + * This is stored in the model metadata under a key associated with the EP type. + */ + virtual std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const { + // graph_viewer is not used in the default implementation. + ORT_UNUSED_PARAMETER(graph_viewer); + // Default implementation returns an empty string, i.e. this EP provides no compatibility info. + return std::string(); + } + + /** + * Validate the compatibility of a compiled model with this execution provider. 
+ */ + virtual common::Status ValidateCompiledModelCompatibilityInfo(const std::string& /*compatibility_info*/, + OrtCompiledModelCompatibility& model_compatibility) const { + // Default implementation indicates this EP does not support model compatibility validation + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return Status::OK(); + } + #endif void SetLogger(const logging::Logger* logger) { diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index f0992f05f31e5..672103bedc437 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -8,3 +8,8 @@ // Key for the execution provider version string. This should be available for all plugin EPs. static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; + +// Prefix for execution provider compatibility information stored in model metadata. +// Used when generating EP context models to store compatibility strings for each EP. +// Full key format: "ep_compatibility_info.<EP_TYPE>" (this prefix followed by the EP's type string). +static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; \ No newline at end of file diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 314cf76cc8044..7eb5f7659a365 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -382,8 +382,8 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio // THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME // Meant to be used with SetEpDynamicOptions // Specify the type of workload for this session. 
-// “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default] -// “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. +// "Default": OS determines the scheduling priority and processor performance to service this workload. [Default] +// "Efficient": OS treats this workload as efficiency oriented with low scheduling priority and efficient processor performance. static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type"; // Disables model compilation during session initialization. @@ -401,3 +401,10 @@ static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload // - "0": EP compile is not disabled. [DEFAULT] // - "1": EP compile is disabled. static const char* const kOrtSessionOptionsDisableModelCompile = "session.disable_model_compile"; + +// Controls behavior when compiled model compatibility is SUPPORTED_PREFER_RECOMPILATION. +// "0": Allow execution with suboptimal performance. [DEFAULT] +// "1": Fail session creation to require recompilation for optimal performance. +// Note: UNSUPPORTED models always fail regardless of this setting. 
+static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel = + "session.fail_on_suboptimal_compiled_model"; diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index efc12ef8dd0e8..421e5a6db51b7 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -22,6 +22,7 @@ #include "core/graph/model.h" #include "core/graph/model_saving_options.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -909,6 +910,34 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } } + // Generate EP compatibility strings for each execution provider and add the non-empty ones to model metadata + // At this point, the graph has been populated with all the EPContext nodes + { + ORT_RETURN_IF_ERROR(ep_graph.Resolve()); + const GraphViewer graph_viewer(ep_graph); + for (const auto& ep : execution_providers) { + try { + // Generate the compatibility string for this EP + std::string compatibility_string = ep->GetCompiledModelCompatibilityInfo(graph_viewer); + if (!compatibility_string.empty()) { + // Create a unique key for this EP's compatibility info + // Use format: "ep_compatibility_info.<EP_TYPE>" 
+ // All EPs in a session must have a unique Type() value, so this will be unique for the generated model + std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep->Type(); + auto& model_metadata = ep_context_model.MetaData(); + auto [it, was_inserted] = + model_metadata.insert_or_assign(metadata_key, compatibility_string); + if (!was_inserted) { + LOGS(logger, WARNING) << "Overwriting existing EP compatibility info for key: " << metadata_key << " (EP: " << ep->Type() << ")"; + } + LOGS(logger, VERBOSE) << "Added EP compatibility info for " << ep->Type() << " with key: " << metadata_key; + } + } catch (const std::exception& ex) { + LOGS(logger, WARNING) << "Failed to generate compatibility string for EP " << ep->Type() << ": " << ex.what(); + } + } + } + size_t ini_size_threshold = ep_context_gen_options.output_external_initializer_size_threshold; std::filesystem::path external_ini_path = ep_context_gen_options.output_external_initializers_file_path; bool force_embed_external_ini = false; diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 436af7115eb1a..eb5e1e89e2f9c 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -361,6 +361,10 @@ const ModelMetaData& Model::MetaData() const noexcept { return model_metadata_; } +ModelMetaData& Model::MetaData() noexcept { + return model_metadata_; +} + Graph& Model::MainGraph() noexcept { return *graph_; } @@ -377,6 +381,15 @@ ModelProto Model::ToProto() const { // out dense duplicates of sparse initializers and leave the original // proto intact. 
ModelProto result(model_proto_); + + // Sync current model_metadata_ back to protobuf metadata_props + result.clear_metadata_props(); + for (const auto& metadata : model_metadata_) { + const gsl::not_null prop{result.add_metadata_props()}; + prop->set_key(metadata.first); + prop->set_value(metadata.second); + } + const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProto(); return result; @@ -386,6 +399,15 @@ ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::pa const std::filesystem::path& file_path, const ModelSavingOptions& model_saving_options) const { ModelProto result(model_proto_); + + // Sync current model_metadata_ back to protobuf metadata_props + result.clear_metadata_props(); + for (const auto& metadata : model_metadata_) { + const gsl::not_null prop{result.add_metadata_props()}; + prop->set_key(metadata.first); + prop->set_value(metadata.second); + } + const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name, file_path, diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 70f82bcfb160b..e8722f6f5c0b2 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -189,6 +189,8 @@ class Model { const ModelMetaData& MetaData() const noexcept; + ModelMetaData& MetaData() noexcept; + // Gets the path from which the model was loaded, if any. 
const std::filesystem::path& ModelPath() const noexcept { return model_path_; } diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 3610b0f797a46..f3e30caf07e81 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -23,6 +23,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; OrtEpFactory::ReleaseEp = Forward::ReleaseEp; + OrtEpFactory::ValidateCompiledModelCompatibilityInfo = Forward::ValidateCompiledModelCompatibilityInfo; OrtEpFactory::CreateAllocator = Forward::CreateAllocator; OrtEpFactory::ReleaseAllocator = Forward::ReleaseAllocator; OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer; diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 0e34fef0ff74c..23e5e95af2903 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -80,6 +80,11 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); } + OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + return impl_->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + } + // Function ORT calls to release an EP instance. 
void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index bd0b76b21511f..6c55730d83979 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -62,6 +62,14 @@ class EpFactoryInternalImpl { return false; } + virtual OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info, + _Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept { + ORT_UNUSED_PARAMETER(compatibility_info); + // Default implementation: mark as not applicable + *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return nullptr; + } + virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, _In_opt_ const OrtKeyValuePairs* /*stream_options*/, _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 2aac1e1c21cc7..3bfca62a4d011 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -644,4 +644,35 @@ void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistr registry.RegisterWaitFn(device_type, OrtDevice::CPU, plugin_ep::Notification::WaitNotificationOnHost); } } + +std::string PluginExecutionProvider::GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const { + if (ort_ep_->GetCompiledModelCompatibilityInfo == nullptr) { + // Plugin EP did not provide an implementation of this function, so we call a default implementation. 
+ return Base::GetCompiledModelCompatibilityInfo(graph_viewer); + } + std::unique_ptr<EpGraph> ep_graph = nullptr; + auto ort_status = EpGraph::Create(graph_viewer, ep_graph); + if (!ort_status.IsOK()) { + LOGS(*GetLogger(), ERROR) << "Failed to create EpGraph: " << ort_status.ToString(); + return {}; + } + // Call the plugin EP's OrtEp::GetCompiledModelCompatibilityInfo() function. + // The C API returns a raw C string; guard against null because constructing std::string from nullptr is undefined behavior. + const char* compatibility_info = ort_ep_->GetCompiledModelCompatibilityInfo(ort_ep_.get(), ep_graph.get()); + return compatibility_info != nullptr ? std::string(compatibility_info) : std::string(); +} + +Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const { + if (ep_factory_.ValidateCompiledModelCompatibilityInfo == nullptr) { + // Plugin EP did not provide an implementation of this function, so we call a default implementation. + return Base::ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + } + // Delegate to the EP factory's validation method + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory_.ValidateCompiledModelCompatibilityInfo(&ep_factory_, + compatibility_info.c_str(), + &model_compatibility))); + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 728f959ad67cb..622bbb3f97b24 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -101,6 +101,11 @@ class PluginExecutionProvider : public IExecutionProvider { // needed based on matching against allocator_mem_infos_. 
std::vector CreatePreferredAllocators() override; + std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override; + + Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const override; + private: struct FusedNodeState { FusedNodeState() = default; diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 67b22779395ec..29793b503c9d1 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -45,6 +45,12 @@ struct ForwardToFactoryImpl { session_options, logger, ep); } + static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfo(OrtEpFactory* this_ptr, + const char* compatibility_info, + OrtCompiledModelCompatibility* model_compatibility) noexcept { + return static_cast(this_ptr)->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility); + } + static OrtStatus* ORT_API_CALL CreateAllocator(_In_ OrtEpFactory* this_ptr, _In_ const OrtMemoryInfo* memory_info, _In_opt_ const OrtKeyValuePairs* allocator_options, diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index f90ace95d6e58..d4041dfce5a7a 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -17,6 +17,7 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/session/plugin_ep/ep_factory_internal.h" @@ -206,6 +207,117 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, return CreateSessionAndLoadModelImpl(options, env->GetEnvironment(), model_path, model_data, model_data_length, sess); } +#if 
!defined(ORT_MINIMAL_BUILD) +static const char* GetCompatibilityStatusString(OrtCompiledModelCompatibility status) { + switch (status) { + case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL: + return "SUPPORTED_OPTIMAL"; + case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION: + return "SUPPORTED_PREFER_RECOMPILATION"; + case OrtCompiledModelCompatibility_EP_UNSUPPORTED: + return "UNSUPPORTED"; + case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE: + return "NOT_APPLICABLE"; + default: + return "UNKNOWN"; + } +} + +static Status ValidateCompiledModelCompatibility(InferenceSession& sess) { + // Get model metadata + auto [status, model_metadata] = sess.GetModelMetadata(); + if (!status.IsOK() || !model_metadata) { + // No metadata available, skip validation + return Status::OK(); + } + + const auto& custom_metadata = model_metadata->custom_metadata_map; + if (custom_metadata.empty()) { + // No custom metadata available, skip validation + return Status::OK(); + } + + // Check if user wants to fail on suboptimal models + bool fail_on_suboptimal = sess.GetSessionOptions().config_options.GetConfigEntry( + kOrtSessionOptionsFailOnSuboptimalCompiledModel) == "1"; + + const auto& registered_provider_types = sess.GetRegisteredProviderTypes(); + + // Access the execution providers through the session state (available after Initialize) + const auto& execution_providers = sess.GetSessionState().GetExecutionProviders(); + + for (const auto& ep_type : registered_provider_types) { + // Construct the full metadata key using the prefix + EP type + const std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + + auto metadata_it = custom_metadata.find(metadata_key); + if (metadata_it != custom_metadata.end()) { + const std::string& compatibility_info = metadata_it->second; + + // Get the actual EP instance to call validation + const IExecutionProvider* ep = execution_providers.Get(ep_type); + + if (ep != nullptr) { + // 
Call the EP's validation method (virtual, with a default implementation); the out-param is pre-initialized so an EP that returns OK without writing it is treated as NOT_APPLICABLE + OrtCompiledModelCompatibility compatibility_status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + Status validation_result = ep->ValidateCompiledModelCompatibilityInfo( + compatibility_info, compatibility_status); + + if (validation_result.IsOK()) { + // Log the compatibility status + const char* status_str = GetCompatibilityStatusString(compatibility_status); + LOGS(*sess.GetLogger(), INFO) + << "EP " << ep_type << " compiled model compatibility: " << status_str; + + // Enforce compatibility based on status + switch (compatibility_status) { + case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE: + case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL: + // Continue execution + break; + + case OrtCompiledModelCompatibility_EP_UNSUPPORTED: + // Always fail for unsupported models + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Compiled model is not supported by execution provider: " + ep_type); + + case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION: + // Behavior depends on user setting + if (fail_on_suboptimal) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Compiled model is suboptimal for execution provider: " + ep_type + + ". Recompilation recommended for better performance."); + } + // Otherwise continue with warning + LOGS(*sess.GetLogger(), WARNING) + << "EP " << ep_type << " reports compiled model is supported but suboptimal. 
" + << "Consider recompiling for better performance."; + break; + + default: + // Handle any unknown status values + LOGS(*sess.GetLogger(), WARNING) + << "EP " << ep_type << " returned unknown compatibility status: " << compatibility_status; + break; + } + } else { + // Validation failed - this should cause session initialization to fail + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to validate compiled model compatibility for EP " + ep_type + + ": " + validation_result.ErrorMessage()); + } + } + } else { + // No compatibility info found for this EP - normal for non-compiled models + LOGS(*sess.GetLogger(), VERBOSE) + << "No compiled model compatibility info found for EP " << ep_type; + } + } + + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, _In_ onnxruntime::InferenceSession& sess, _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) { @@ -253,6 +365,12 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize()); +#if !defined(ORT_MINIMAL_BUILD) + // Validate compiled model compatibility for all registered execution providers + // This must be done after Initialize() so the session state is available + ORT_API_RETURN_IF_STATUS_NOT_OK(ValidateCompiledModelCompatibility(sess)); +#endif // !defined(ORT_MINIMAL_BUILD) + return nullptr; } diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc new file mode 100644 index 0000000000000..be97cf2620881 --- /dev/null +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -0,0 +1,410 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "core/framework/execution_provider.h" +#include "core/framework/compute_capability.h" +#include "core/framework/kernel_registry.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" +#include "core/session/utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/abi_session_options_impl.h" +#include "core/framework/error_code_helper.h" +#include "dummy_provider.h" +#include "test_utils.h" +#include "test/test_environment.h" +#include "test/providers/provider_test_utils.h" + +using namespace onnxruntime; +using namespace onnxruntime::test; + +namespace { + +// Test execution provider that extends IExecutionProvider with compatibility string functionality +class TestCompatibilityExecutionProvider : public IExecutionProvider { + public: + static constexpr const char* kTestCompatibilityExecutionProviderType = "TestCompatibilityExecutionProvider"; + + TestCompatibilityExecutionProvider() : IExecutionProvider(kTestCompatibilityExecutionProviderType) { + } + + std::shared_ptr GetKernelRegistry() const override { + return std::make_shared(); + } + + std::vector CreatePreferredAllocators() override { + return {}; + } + + // Configurable mock behavior + void SetMockCompatibilityString(const std::string& str) { + mock_compatibility_string_ = str; + } + + void SetMockCompatibilityStatus(OrtCompiledModelCompatibility status) { + mock_compatibility_status_ = status; + } + + void SetShouldFailValidation(bool should_fail) { + should_fail_validation_ = should_fail; + } + + // Override compatibility methods + std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override { + ORT_UNUSED_PARAMETER(graph_viewer); + return 
mock_compatibility_string_; + } + + common::Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const override { + if (should_fail_validation_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mock validation failure"); + } + + // Simple validation logic for testing + // If the mock status is explicitly set to NOT_APPLICABLE, always return that + if (mock_compatibility_status_ == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE) { + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + } else if (compatibility_info.empty()) { + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + } else if (compatibility_info == mock_compatibility_string_) { + model_compatibility = mock_compatibility_status_; + } else { + model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + } + + return Status::OK(); + } + + private: + std::string mock_compatibility_string_ = "default_test_compatibility_v1.0"; + OrtCompiledModelCompatibility mock_compatibility_status_ = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; + bool should_fail_validation_ = false; +}; + +// Helper class to create test models +class ModelBuilderWithCompatibility { + public: + static std::unique_ptr CreateSimpleTestModel() { + // Create a simple model with a single Add operation + std::unordered_map domain_to_version; + domain_to_version[onnxruntime::kOnnxDomain] = 7; + + auto p_model = std::make_unique("test_model", true, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), + DefaultLoggingManager().DefaultLogger()); + + onnxruntime::Graph& graph = p_model->MainGraph(); + + // Define tensor type + ONNX_NAMESPACE::TypeProto tensor_float; + tensor_float.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + 
tensor_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + // Create input and output node args + auto& input_arg_a = graph.GetOrCreateNodeArg("A", &tensor_float); + auto& input_arg_b = graph.GetOrCreateNodeArg("B", &tensor_float); + auto& output_arg = graph.GetOrCreateNodeArg("C", &tensor_float); + + // Create Add node + std::vector input_defs = {&input_arg_a, &input_arg_b}; + std::vector output_defs = {&output_arg}; + graph.AddNode("add_node", "Add", "Add two tensors", input_defs, output_defs, nullptr, onnxruntime::kOnnxDomain); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + return p_model; + } + + static std::unique_ptr CreateModelWithCompatibilityMetadata( + const std::map& ep_compatibility_info) { + auto model = CreateSimpleTestModel(); + + // Add compatibility metadata + auto& metadata = model->MetaData(); + for (const auto& [ep_type, compatibility_string] : ep_compatibility_info) { + std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + metadata[metadata_key] = compatibility_string; + } + + return model; + } +}; + +// Helper class to create test sessions +class SessionBuilderWithCompatibility { + public: + static std::unique_ptr CreateTestSession(std::unique_ptr model, bool fail_on_suboptimal = false) { + SessionOptions so; + so.session_logid = "EpCompatibilityTest"; + so.session_log_verbosity_level = 1; + + if (fail_on_suboptimal) { + EXPECT_TRUE(so.config_options.AddConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, "1").IsOK()); + } + + // Convert Model to ModelProto and serialize + auto model_proto = model->ToProto(); + std::string model_data; + EXPECT_TRUE(model_proto.SerializeToString(&model_data)); + std::stringstream model_stream(model_data); + + // Create session with basic constructor + auto session = std::make_unique(so, GetEnvironment()); + + // Load the model from the stream and validate the status + auto load_status = 
session->Load(model_stream); + EXPECT_TRUE(load_status.IsOK()) << "Failed to load model: " << load_status.ErrorMessage(); + + return session; + } +}; + +// Helper function to initialize session using the proper validation pathway +Status InitializeSessionWithValidation(InferenceSession& session) { + // Create OrtSessionOptions from the session's SessionOptions to use the proper initialization path + OrtSessionOptions ort_session_options; + ort_session_options.value = session.GetSessionOptions(); + + // Call the InitializeSession function from utils.cc which includes validation + OrtStatus* ort_status = InitializeSession(&ort_session_options, session, nullptr); + + // Convert OrtStatus to Status using the proper helper function + return ToStatusAndRelease(ort_status); +} + +} // anonymous namespace + +class EpCompatibilityTest : public ::testing::Test { + protected: + void SetUp() override { + test_model_ = ModelBuilderWithCompatibility::CreateSimpleTestModel(); + } + + protected: + std::unique_ptr test_model_; +}; + +// Test basic compatibility string generation during compilation +TEST_F(EpCompatibilityTest, TestCompatibilityStringGeneration) { + const std::string expected_compatibility_string = "test_ep_v1.0_compatibility_data"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityString(expected_compatibility_string); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(test_model_)); + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); + + // Note: In the actual implementation, we would need to trigger EP context model creation + // to see the compatibility strings stored. For now, this tests that the methods are called + // without error during session initialization. 
+}
+
+// Test compatibility string storage in model metadata.
+// Builds a model whose metadata is pre-populated for a single EP type and checks
+// that the value can be read back under the key
+// kOrtModelMetadata_EpCompatibilityInfoPrefix + <ep_type>.
+TEST_F(EpCompatibilityTest, TestCompatibilityStringStorage) {
+  const std::string ep_type = "TestCompatibilityExecutionProvider";
+  const std::string expected_compatibility_string = "stored_compatibility_v2.0";
+
+  // Create model with pre-populated compatibility metadata
+  std::map<std::string, std::string> compatibility_info = {
+      {ep_type, expected_compatibility_string}};
+  auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
+
+  // Verify metadata was stored correctly
+  const auto& metadata = model_with_metadata->MetaData();
+  std::string expected_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type;
+
+  auto it = metadata.find(expected_key);
+  // ASSERT (not EXPECT) so that `it` is never dereferenced when the key is missing.
+  ASSERT_NE(it, metadata.end()) << "Expected compatibility metadata key not found: " << expected_key;
+  EXPECT_EQ(it->second, expected_compatibility_string);
+}
+
+// Test multiple EPs generating different compatibility strings.
+// Each EP type gets its own prefixed metadata key; every entry must be present
+// with the exact string that was supplied for it.
+TEST_F(EpCompatibilityTest, TestMultipleEpCompatibilityStrings) {
+  std::map<std::string, std::string> compatibility_info = {
+      {"EP_A", "ep_a_compatibility_v1.0"},
+      {"EP_B", "ep_b_compatibility_v2.1"},
+      {"EP_C", "ep_c_compatibility_v1.5"}};
+
+  auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
+
+  // Verify all compatibility strings are stored
+  const auto& metadata = model_with_metadata->MetaData();
+  for (const auto& [ep_type, expected_string] : compatibility_info) {
+    std::string expected_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type;
+    auto it = metadata.find(expected_key);
+    ASSERT_NE(it, metadata.end()) << "Expected compatibility metadata key not found: " << expected_key;
+    EXPECT_EQ(it->second, expected_string);
+  }
+}
+
+// Test empty compatibility string handling.
+// The mock EP reports an empty compatibility string and the fixture's model is
+// used as-is; session initialization is expected to succeed.
+TEST_F(EpCompatibilityTest, TestEmptyCompatibilityString) {
+  auto test_ep = std::make_unique<TestCompatibilityExecutionProvider>();
+  test_ep->SetMockCompatibilityString("");  // Empty string
+
+  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(test_model_));
+  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
+  ASSERT_STATUS_OK(InitializeSessionWithValidation(*session));  // Should succeed even with empty compatibility string
+}
+
+// Test compatibility validation with optimal status.
+// The mock EP reports EP_SUPPORTED_OPTIMAL for a model carrying a matching
+// compatibility string; session initialization is expected to succeed.
+TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Optimal) {
+  const std::string ep_type = "TestCompatibilityExecutionProvider";
+  const std::string compatibility_string = "optimal_compatibility_v1.0";
+
+  auto test_ep = std::make_unique<TestCompatibilityExecutionProvider>();
+  test_ep->SetMockCompatibilityString(compatibility_string);
+  test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL);
+
+  // Create model with matching compatibility metadata
+  std::map<std::string, std::string> compatibility_info = {{ep_type, compatibility_string}};
+  auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
+
+  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata));
+  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
+  ASSERT_STATUS_OK(InitializeSessionWithValidation(*session));  // Should succeed with optimal compatibility
+}
+
+// Test compatibility validation with suboptimal status (default session settings).
+// The mock EP reports EP_SUPPORTED_PREFER_RECOMPILATION and the session is created
+// with the fail-on-suboptimal flag off; initialization is expected to succeed.
+TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Suboptimal_DefaultSettings) {
+  const std::string ep_type = "TestCompatibilityExecutionProvider";
+  const std::string compatibility_string = "suboptimal_compatibility_v1.0";
+
+  auto test_ep = std::make_unique<TestCompatibilityExecutionProvider>();
+  test_ep->SetMockCompatibilityString(compatibility_string);
+  test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION);
+
+  std::map<std::string, std::string> compatibility_info = {{ep_type, compatibility_string}};
+  auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
+
+  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata), false);  // Don't fail on suboptimal
+  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
+  ASSERT_STATUS_OK(InitializeSessionWithValidation(*session));  // Should succeed by default with suboptimal compatibility
+}
+
+// Test compatibility validation with suboptimal status (fail on suboptimal enabled).
+// Same setup as the default-settings case, but the session is created with the
+// fail-on-suboptimal flag on; initialization must fail with a message that
+// contains "suboptimal".
+TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Suboptimal_FailEnabled) {
+  const std::string ep_type = "TestCompatibilityExecutionProvider";
+  const std::string compatibility_string = "suboptimal_compatibility_v1.0";
+
+  auto test_ep = std::make_unique<TestCompatibilityExecutionProvider>();
+  test_ep->SetMockCompatibilityString(compatibility_string);
+  test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION);
+
+  std::map<std::string, std::string> compatibility_info = {{ep_type, compatibility_string}};
+  auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
+
+  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata), true);  // Fail on suboptimal
+  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
+
+  // Should fail during initialization due to suboptimal compatibility
+  auto status = InitializeSessionWithValidation(*session);
+  EXPECT_FALSE(status.IsOK());
+  EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("suboptimal"));
+}
+
+// Test compatibility validation with unsupported status.
+// The mock EP reports EP_UNSUPPORTED; initialization must fail with a message
+// that contains "not supported" even though fail-on-suboptimal is off.
+TEST_F(EpCompatibilityTest, TestCompatibilityValidation_Unsupported) {
+  const std::string ep_type = "TestCompatibilityExecutionProvider";
+  const std::string stored_compatibility_string = "old_compatibility_v1.0";
+  const std::string current_compatibility_string = "new_compatibility_v2.0";
+
+  auto test_ep = std::make_unique<TestCompatibilityExecutionProvider>();
+  test_ep->SetMockCompatibilityString(current_compatibility_string);
+  test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_UNSUPPORTED);
+
+  // Model has old compatibility string, EP has new one -> unsupported
+  std::map<std::string, std::string> compatibility_info = {{ep_type, stored_compatibility_string}};
+  auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
+
+  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata), false);  // Even with fail_on_suboptimal=false
+  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
+
+  // Should fail during initialization due to unsupported compatibility
+  auto status = InitializeSessionWithValidation(*session);
+  EXPECT_FALSE(status.IsOK());
+  EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("not supported"));
+}
+
+// Test compatibility validation with not applicable status.
+// The model carries a compatibility string but the mock EP reports
+// EP_NOT_APPLICABLE; initialization is expected to succeed.
+TEST_F(EpCompatibilityTest, TestCompatibilityValidation_NotApplicable) {
+  const std::string ep_type = "TestCompatibilityExecutionProvider";
+
+  auto test_ep = std::make_unique<TestCompatibilityExecutionProvider>();
+  test_ep->SetMockCompatibilityString("");  // Empty compatibility string
+  test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_NOT_APPLICABLE);
+
+  // Model has some compatibility string, but EP returns not applicable
+  std::map<std::string, std::string> compatibility_info = {{ep_type, "some_compatibility_string"}};
+  auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
+
+  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata));
+  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
+  ASSERT_STATUS_OK(InitializeSessionWithValidation(*session));  // Should succeed with not applicable status
+}
+
+// Test missing compatibility info in model metadata.
+// The mock EP has a compatibility string but the fixture's model is used without
+// any compatibility metadata added; initialization is expected to succeed.
+TEST_F(EpCompatibilityTest, TestMissingCompatibilityInfo) {
+  auto test_ep = std::make_unique<TestCompatibilityExecutionProvider>();
+  test_ep->SetMockCompatibilityString("some_compatibility_string");
+
+  // Use model without any compatibility metadata
+  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(test_model_));
+  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
+  ASSERT_STATUS_OK(InitializeSessionWithValidation(*session));  // Should succeed when no compatibility info is present
+}
+
+// Test EP validation failure.
+// The mock EP is told to fail validation; initialization must return a non-OK
+// status whose message contains "Mock validation failure".
+TEST_F(EpCompatibilityTest, TestEpValidationFailure) {
+  const std::string ep_type = "TestCompatibilityExecutionProvider";
+  const std::string compatibility_string = "test_compatibility_v1.0";
+
+  auto test_ep = std::make_unique<TestCompatibilityExecutionProvider>();
+  test_ep->SetMockCompatibilityString(compatibility_string);
+  test_ep->SetShouldFailValidation(true);  // Force validation failure
+
+  std::map<std::string, std::string> compatibility_info = {{ep_type, compatibility_string}};
+  auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info);
+
+  auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata));
+  ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep)));
+
+  // Should handle EP validation failure gracefully
+  auto status = InitializeSessionWithValidation(*session);
+  EXPECT_FALSE(status.IsOK());
+  EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Mock validation failure"));
+}
+
+// Test session option configuration for fail on suboptimal.
+// Exercises only SessionOptions::config_options: the key is absent on a fresh
+// SessionOptions, and reads back "1" and then "0" after each AddConfigEntry call.
+TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) {
+  SessionOptions so;
+
+  // Test default value
+  std::string config_value;
+  bool has_config = so.config_options.TryGetConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, config_value);
+  EXPECT_FALSE(has_config);  // Should not be set by default
+
+  // Test setting the option
+  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, "1"));
+  has_config = so.config_options.TryGetConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, config_value);
+  EXPECT_TRUE(has_config);
+  EXPECT_EQ(config_value, "1");
+
+  // Test setting to disabled
+  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, "0"));
+  has_config = so.config_options.TryGetConfigEntry(kOrtSessionOptionsFailOnSuboptimalCompiledModel, config_value);
+  EXPECT_TRUE(has_config);
+  EXPECT_EQ(config_value, "0");
+}