Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1caed4e
Adding Get plumbing and supporting infrastructure to call it as part …
Jul 31, 2025
6543edd
Validation stubs for EP factory
Aug 1, 2025
ac1a64c
Add validation method to IExecutionProvider
Aug 1, 2025
3a9e9c0
Attempt at session enfrocement
Aug 5, 2025
08e47fc
Add session option for precompiled-but-suboptimal and enforce in sess…
Aug 5, 2025
0d7febe
Adding unit test cases
Aug 5, 2025
770aece
WIP but tests still broken
Aug 9, 2025
407d0a2
Merge branch 'main' into adrastogi/model-compat
Aug 9, 2025
2603892
Back to where we were
Aug 9, 2025
b022a10
Tests passing
Aug 10, 2025
998da90
Update include/onnxruntime/core/framework/execution_provider.h
adrastogi Aug 14, 2025
297d074
Update onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h
adrastogi Aug 14, 2025
6544617
Lintrunner local run
Aug 14, 2025
741f387
Merge branch 'adrastogi/model-compat' of https://github.com/Microsoft…
Aug 14, 2025
e3551ce
Merge branch 'main' into adrastogi/model-compat
Aug 19, 2025
5ac17d7
PR feedback (add guards to fix CI build error for minimal builds)
Aug 19, 2025
ed2a54f
Ensure GetCompatibilityStatusString isn't used in minimal builds
Aug 19, 2025
52f3884
PR feedback (fix placement of compatibility string generation, loggin…
Aug 19, 2025
0e3ee2e
PR feedback (ensure Validate method is valid, else fall back to base …
Aug 20, 2025
07a3b85
Update onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfac…
adrastogi Aug 20, 2025
bce4093
Merge branch 'main' into adrastogi/model-compat
Aug 21, 2025
009f87b
PR feedback (initialize GraphViewer with ep_graph (*not* graph), and …
Aug 22, 2025
7b73392
PR feedback (early return during validation if there is no custom met…
Aug 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class GraphOptimizerRegistry;
#include "core/framework/framework_provider_common.h"
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"
#include "core/session/onnxruntime_c_api.h"

struct OrtEpDevice;
struct OrtRunOptions;
Expand Down Expand Up @@ -322,6 +323,29 @@ class IExecutionProvider {
virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs);

/**
* Get the compatibility info for a compiled model.
*
* The execution provider determines this value, which denotes the compatibility of the compiled model with the EP.
* This is stored in the model metadata under a key associated with the EP type.
*/
virtual std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const {
// graph_viewer and model_metadata are not used in the default implementation.
ORT_UNUSED_PARAMETER(graph_viewer);
// Default implementation returns empty string
return std::string();
}

/**
* Validate the compatibility of a compiled model with this execution provider.
*/
virtual common::Status ValidateCompiledModelCompatibilityInfo(const std::string& /*compatibility_info*/,
OrtCompiledModelCompatibility& model_compatibility) const {
// Default implementation indicates this EP does not support model compatibility validation
model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
return Status::OK();
}

#endif

void SetLogger(const logging::Logger* logger) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@

// Key for the execution provider version string. This should be available for all plugin EPs.
static const char* const kOrtEpDevice_EpMetadataKey_Version = "version";

// Prefix for execution provider compatibility information stored in model metadata.
// Used when generating EP context models to store compatibility strings for each EP.
// Full key format: "ep_compatibility_info.<EP_TYPE>"
static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info.";
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio
// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME
// Meant to be used with SetEpDynamicOptions
// Specify the type of workload for this session.
// Default: OS determines the scheduling priority and processor performance to service this workload. [Default]
// Efficient: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance.
// "Default": OS determines the scheduling priority and processor performance to service this workload. [Default]
// "Efficient": OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance.
static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type";

// Disables model compilation during session initialization.
Expand All @@ -401,3 +401,10 @@ static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload
// - "0": EP compile is not disabled. [DEFAULT]
// - "1": EP compile is disabled.
static const char* const kOrtSessionOptionsDisableModelCompile = "session.disable_model_compile";

// Controls behavior when compiled model compatibility is SUPPORTED_PREFER_RECOMPILATION.
// "0": Allow execution with suboptimal performance. [DEFAULT]
// "1": Fail session creation to require recompilation for optimal performance.
// Note: UNSUPPORTED models always fail regardless of this setting.
static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel =
"session.fail_on_suboptimal_compiled_model";
28 changes: 28 additions & 0 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "core/graph/model.h"
#include "core/graph/model_saving_options.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h"

// uncomment this line to count non-CUDA ops in ONNX domain
// #define COUNT_NON_CUDA_OPS
Expand Down Expand Up @@ -909,6 +910,33 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
}
}

// Generate EP compatibility strings for OrtEp types and add to model metadata
// At this point, the graph has been populated with all the EPContext nodes
{
const GraphViewer graph_viewer(graph);
for (const auto& ep : execution_providers) {
try {
// Generate the compatibility string for this EP
std::string compatibility_string = ep->GetCompiledModelCompatibilityInfo(graph_viewer);
if (!compatibility_string.empty()) {
// Create a unique key for this EP's compatibility info
// Use format: "ep_compatibility_info.<EP_TYPE>"
// All EPs in a session must have a unique Type() value, so this will be unique for the generated model
std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep->Type();
auto& model_metadata = ep_context_model.MetaData();
auto [it, was_inserted] =
model_metadata.insert_or_assign(metadata_key, compatibility_string);
if (!was_inserted) {
LOGS(logger, WARNING) << "Overwriting existing EP compatibility info for key: " << metadata_key << " (EP: " << ep->Type() << ")";
}
LOGS(logger, VERBOSE) << "Added EP compatibility info for " << ep->Type() << " with key: " << metadata_key;
}
} catch (const std::exception& ex) {
LOGS(logger, WARNING) << "Failed to generate compatibility string for EP " << ep->Type() << ": " << ex.what();
}
}
}

size_t ini_size_threshold = ep_context_gen_options.output_external_initializer_size_threshold;
std::filesystem::path external_ini_path = ep_context_gen_options.output_external_initializers_file_path;
bool force_embed_external_ini = false;
Expand Down
22 changes: 22 additions & 0 deletions onnxruntime/core/graph/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,10 @@ const ModelMetaData& Model::MetaData() const noexcept {
return model_metadata_;
}

ModelMetaData& Model::MetaData() noexcept {
return model_metadata_;
}

Graph& Model::MainGraph() noexcept {
return *graph_;
}
Expand All @@ -377,6 +381,15 @@ ModelProto Model::ToProto() const {
// out dense duplicates of sparse initializers and leave the original
// proto intact.
ModelProto result(model_proto_);

// Sync current model_metadata_ back to protobuf metadata_props
result.clear_metadata_props();
for (const auto& metadata : model_metadata_) {
const gsl::not_null<StringStringEntryProto*> prop{result.add_metadata_props()};
prop->set_key(metadata.first);
prop->set_value(metadata.second);
}

const auto& graph = *graph_;
*(result.mutable_graph()) = graph.ToGraphProto();
return result;
Expand All @@ -386,6 +399,15 @@ ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::pa
const std::filesystem::path& file_path,
const ModelSavingOptions& model_saving_options) const {
ModelProto result(model_proto_);

// Sync current model_metadata_ back to protobuf metadata_props
result.clear_metadata_props();
for (const auto& metadata : model_metadata_) {
const gsl::not_null<StringStringEntryProto*> prop{result.add_metadata_props()};
prop->set_key(metadata.first);
prop->set_value(metadata.second);
}

const auto& graph = *graph_;
*(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name,
file_path,
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/graph/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ class Model {

const ModelMetaData& MetaData() const noexcept;

ModelMetaData& MetaData() noexcept;

// Gets the path from which the model was loaded, if any.
const std::filesystem::path& ModelPath() const noexcept { return model_path_; }

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr<EpFactoryInternalImpl> impl
OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices;
OrtEpFactory::CreateEp = Forward::CreateEp;
OrtEpFactory::ReleaseEp = Forward::ReleaseEp;
OrtEpFactory::ValidateCompiledModelCompatibilityInfo = Forward::ValidateCompiledModelCompatibilityInfo;
OrtEpFactory::CreateAllocator = Forward::CreateAllocator;
OrtEpFactory::ReleaseAllocator = Forward::ReleaseAllocator;
OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer;
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ class EpFactoryInternal : public OrtEpFactory {
return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream);
}

OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept {
return impl_->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility);
}

// Function ORT calls to release an EP instance.
void ReleaseEp(OrtEp* /*ep*/) noexcept {
// we never create an OrtEp so we should never be trying to release one
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ class EpFactoryInternalImpl {
return false;
}

virtual OrtStatus* ValidateCompiledModelCompatibilityInfo(_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* model_compatibility) noexcept {
ORT_UNUSED_PARAMETER(compatibility_info);
// Default implementation: mark as not applicable
*model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
return nullptr;
}

virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/,
_In_opt_ const OrtKeyValuePairs* /*stream_options*/,
_Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,4 +644,35 @@ void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistr
registry.RegisterWaitFn(device_type, OrtDevice::CPU, plugin_ep::Notification::WaitNotificationOnHost);
}
}

std::string PluginExecutionProvider::GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const {
if (ort_ep_->GetCompiledModelCompatibilityInfo == nullptr) {
// Plugin EP did not provide an implementation of this function, so we call a default implementation.
return Base::GetCompiledModelCompatibilityInfo(graph_viewer);
}
std::unique_ptr<EpGraph> ep_graph = nullptr;
auto ort_status = EpGraph::Create(graph_viewer, ep_graph);
if (!ort_status.IsOK()) {
LOGS(*GetLogger(), ERROR) << "Failed to create EpGraph: " << ort_status.ToString();
return {};
}
// Call EP plugin's OrtEp::GenerateCompiledModelCompatibilityInfo() function.
std::string compatibility_info_string;
compatibility_info_string = ort_ep_->GetCompiledModelCompatibilityInfo(ort_ep_.get(), ep_graph.get());
return compatibility_info_string;
}

Status PluginExecutionProvider::ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info,
OrtCompiledModelCompatibility& model_compatibility) const {
if (ep_factory_.ValidateCompiledModelCompatibilityInfo == nullptr) {
// Plugin EP did not provide an implementation of this function, so we call a default implementation.
return Base::ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility);
}
// Delegate to the EP factory's validation method
ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory_.ValidateCompiledModelCompatibilityInfo(&ep_factory_,
compatibility_info.c_str(),
&model_compatibility)));
return Status::OK();
}

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class PluginExecutionProvider : public IExecutionProvider {
// needed based on matching against allocator_mem_infos_.
std::vector<AllocatorPtr> CreatePreferredAllocators() override;

std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override;

Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info,
OrtCompiledModelCompatibility& model_compatibility) const override;

private:
struct FusedNodeState {
FusedNodeState() = default;
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ struct ForwardToFactoryImpl {
session_options, logger, ep);
}

static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfo(OrtEpFactory* this_ptr,
const char* compatibility_info,
OrtCompiledModelCompatibility* model_compatibility) noexcept {
return static_cast<TFactory*>(this_ptr)->ValidateCompiledModelCompatibilityInfo(compatibility_info, model_compatibility);
}

static OrtStatus* ORT_API_CALL CreateAllocator(_In_ OrtEpFactory* this_ptr,
_In_ const OrtMemoryInfo* memory_info,
_In_opt_ const OrtKeyValuePairs* allocator_options,
Expand Down
113 changes: 113 additions & 0 deletions onnxruntime/core/session/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/ort_apis.h"
#include "core/session/ort_env.h"
#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "core/session/plugin_ep/ep_factory_internal.h"
Expand Down Expand Up @@ -206,6 +207,112 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options,
return CreateSessionAndLoadModelImpl(options, env->GetEnvironment(), model_path, model_data, model_data_length, sess);
}

#if !defined(ORT_MINIMAL_BUILD)
static const char* GetCompatibilityStatusString(OrtCompiledModelCompatibility status) {
switch (status) {
case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL:
return "SUPPORTED_OPTIMAL";
case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION:
return "SUPPORTED_PREFER_RECOMPILATION";
case OrtCompiledModelCompatibility_EP_UNSUPPORTED:
return "UNSUPPORTED";
case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE:
return "NOT_APPLICABLE";
default:
return "UNKNOWN";
}
}

static Status ValidateCompiledModelCompatibility(InferenceSession& sess) {
// Get model metadata
auto [status, model_metadata] = sess.GetModelMetadata();
if (!status.IsOK() || !model_metadata) {
// No metadata available, skip validation
return Status::OK();
}

// Check if user wants to fail on suboptimal models
bool fail_on_suboptimal = sess.GetSessionOptions().config_options.GetConfigEntry(
kOrtSessionOptionsFailOnSuboptimalCompiledModel) == "1";

const auto& custom_metadata = model_metadata->custom_metadata_map;
const auto& registered_provider_types = sess.GetRegisteredProviderTypes();

// Access the execution providers through the session state (available after Initialize)
const auto& execution_providers = sess.GetSessionState().GetExecutionProviders();

for (const auto& ep_type : registered_provider_types) {
// Construct the full metadata key using the prefix + EP type
const std::string metadata_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type;

auto metadata_it = custom_metadata.find(metadata_key);
if (metadata_it != custom_metadata.end()) {
const std::string& compatibility_info = metadata_it->second;

// Get the actual EP instance to call validation
const IExecutionProvider* ep = execution_providers.Get(ep_type);

if (ep != nullptr) {
// Call the EP's validation method (virtual method with default implementation)
OrtCompiledModelCompatibility compatibility_status;
Status validation_result = ep->ValidateCompiledModelCompatibilityInfo(
compatibility_info, compatibility_status);

if (validation_result.IsOK()) {
// Log the compatibility status
const char* status_str = GetCompatibilityStatusString(compatibility_status);
LOGS(*sess.GetLogger(), INFO)
<< "EP " << ep_type << " compiled model compatibility: " << status_str;

// Enforce compatibility based on status
switch (compatibility_status) {
case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE:
case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL:
// Continue execution
break;

case OrtCompiledModelCompatibility_EP_UNSUPPORTED:
// Always fail for unsupported models
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Compiled model is not supported by execution provider: " + ep_type);

case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION:
// Behavior depends on user setting
if (fail_on_suboptimal) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Compiled model is suboptimal for execution provider: " + ep_type +
". Recompilation recommended for better performance.");
}
// Otherwise continue with warning
LOGS(*sess.GetLogger(), WARNING)
<< "EP " << ep_type << " reports compiled model is supported but suboptimal. "
<< "Consider recompiling for better performance.";
break;

default:
// Handle any unknown status values
LOGS(*sess.GetLogger(), WARNING)
<< "EP " << ep_type << " returned unknown compatibility status: " << compatibility_status;
break;
}
} else {
// Validation failed - this should cause session initialization to fail
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Failed to validate compiled model compatibility for EP " + ep_type +
": " + validation_result.ErrorMessage());
}
}
} else {
// No compatibility info found for this EP - normal for non-compiled models
LOGS(*sess.GetLogger(), VERBOSE)
<< "No compiled model compatibility info found for EP " << ep_type;
}
}

return Status::OK();
}
#endif // !defined(ORT_MINIMAL_BUILD)

OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options,
_In_ onnxruntime::InferenceSession& sess,
_Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) {
Expand Down Expand Up @@ -253,6 +360,12 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options,

ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize());

#if !defined(ORT_MINIMAL_BUILD)
// Validate compiled model compatibility for all registered execution providers
// This must be done after Initialize() so the session state is available
ORT_API_RETURN_IF_STATUS_NOT_OK(ValidateCompiledModelCompatibility(sess));
#endif // !defined(ORT_MINIMAL_BUILD)

return nullptr;
}

Expand Down
Loading
Loading