Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,88 @@ public struct OrtApi
public IntPtr EpDevice_Device;
public IntPtr GetEpApi;
public IntPtr GetTensorSizeInBytes;

public IntPtr AllocatorGetStats;

public IntPtr CreateMemoryInfo_V2;
public IntPtr MemoryInfoGetDeviceMemType;
public IntPtr MemoryInfoGetVendorId;

public IntPtr ValueInfo_GetValueProducer;
public IntPtr ValueInfo_GetValueNumConsumers;
public IntPtr ValueInfo_GetValueConsumers;
public IntPtr ValueInfo_GetInitializerValue;
public IntPtr ValueInfo_GetExternalInitializerInfo;
public IntPtr ValueInfo_IsRequiredGraphInput;
public IntPtr ValueInfo_IsOptionalGraphInput;
public IntPtr ValueInfo_IsGraphOutput;
public IntPtr ValueInfo_IsConstantInitializer;
public IntPtr ValueInfo_IsFromOuterScope;
public IntPtr Graph_GetName;
public IntPtr Graph_GetModelPath;
public IntPtr Graph_GetOnnxIRVersion;
public IntPtr Graph_GetNumOperatorSets;
public IntPtr Graph_GetOperatorSets;
public IntPtr Graph_GetNumInputs;
public IntPtr Graph_GetInputs;
public IntPtr Graph_GetNumOutputs;
public IntPtr Graph_GetOutputs;
public IntPtr Graph_GetNumInitializers;
public IntPtr Graph_GetInitializers;
public IntPtr Graph_GetNumNodes;
public IntPtr Graph_GetNodes;
public IntPtr Graph_GetParentNode;
public IntPtr Graph_GetGraphView;
public IntPtr Node_GetId;
public IntPtr Node_GetName;
public IntPtr Node_GetOperatorType;
public IntPtr Node_GetDomain;
public IntPtr Node_GetSinceVersion;
public IntPtr Node_GetNumInputs;
public IntPtr Node_GetInputs;
public IntPtr Node_GetNumOutputs;
public IntPtr Node_GetOutputs;
public IntPtr Node_GetNumImplicitInputs;
public IntPtr Node_GetImplicitInputs;
public IntPtr Node_GetNumAttributes;
public IntPtr Node_GetAttributes;
public IntPtr Node_GetAttributeByName;
public IntPtr Node_GetTensorAttributeAsOrtValue;
public IntPtr OpAttr_GetType;
public IntPtr OpAttr_GetName;
public IntPtr Node_GetNumSubgraphs;
public IntPtr Node_GetSubgraphs;
public IntPtr Node_GetGraph;
public IntPtr Node_GetEpName;
public IntPtr ReleaseExternalInitializerInfo;
public IntPtr ExternalInitializerInfo_GetFilePath;
public IntPtr ExternalInitializerInfo_GetFileOffset;
public IntPtr ExternalInitializerInfo_GetByteSize;

public IntPtr GetRunConfigEntry;

public IntPtr EpDevice_MemoryInfo;

public IntPtr CreateSharedAllocator;
public IntPtr GetSharedAllocator;
public IntPtr ReleaseSharedAllocator;

public IntPtr GetTensorData;

public IntPtr GetSessionOptionsConfigEntries;

public IntPtr SessionGetMemoryInfoForInputs;
public IntPtr SessionGetMemoryInfoForOutputs;
public IntPtr SessionGetEpDeviceForInputs;

public IntPtr CreateSyncStreamForEpDevice;
public IntPtr SyncStream_GetHandle;
public IntPtr ReleaseSyncStream;

public IntPtr CopyTensors;

public IntPtr Graph_GetModelMetadata;
public IntPtr GetModelCompatibilityForEpDevices;
}

internal static class NativeMethods
Expand Down Expand Up @@ -704,6 +786,10 @@ static NativeMethods()
(DSessionOptionsSetEpSelectionPolicyDelegate)Marshal.GetDelegateForFunctionPointer(
api_.SessionOptionsSetEpSelectionPolicyDelegate,
typeof(DSessionOptionsSetEpSelectionPolicyDelegate));

OrtGetModelCompatibilityForEpDevices = (DOrtGetModelCompatibilityForEpDevices)Marshal.GetDelegateForFunctionPointer(
api_.GetModelCompatibilityForEpDevices,
typeof(DOrtGetModelCompatibilityForEpDevices));
}

internal class NativeLib
Expand Down Expand Up @@ -2456,6 +2542,18 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,

public static DOrtGetEpDevices OrtGetEpDevices;

/// <summary>
/// Validate compiled model compatibility for the provided EP devices.
/// </summary>
/// <param name="ep_devices">Array of native OrtEpDevice pointers to validate against.</param>
/// <param name="num_ep_devices">Number of entries in <paramref name="ep_devices"/> (native size_t).</param>
/// <param name="compatibility_info">Compatibility information passed to the native const char* parameter
/// as a byte array; the managed caller supplies zero-terminated UTF-8 bytes.</param>
/// <param name="out_status">Receives the native OrtCompiledModelCompatibility value as an int.</param>
/// <returns>Native OrtStatus pointer describing the outcome of the call.</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DOrtGetModelCompatibilityForEpDevices(
IntPtr[] /* const OrtEpDevice* const* */ ep_devices,
UIntPtr /* size_t */ num_ep_devices,
byte[] /* const char* */ compatibility_info,
out int /* OrtCompiledModelCompatibility */ out_status);

// Bound in the NativeMethods static constructor from the GetModelCompatibilityForEpDevices
// function pointer of the OrtApi table.
public static DOrtGetModelCompatibilityForEpDevices OrtGetModelCompatibilityForEpDevices;

/// <summary>
/// Add execution provider devices to the session options.
/// Priority is based on the order of the OrtEpDevice instances. Highest priority first.
Expand Down
37 changes: 37 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Represents the compatibility status of a pre-compiled model with one or more execution provider devices.
/// </summary>
/// <remarks>
/// This enum is used to determine whether a pre-compiled model can be used with specific execution providers
/// and devices, or if recompilation is needed.
/// The integer status written by the native GetModelCompatibilityForEpDevices call is cast directly to this
/// enum, so the numeric values must stay in sync with the native OrtCompiledModelCompatibility enum.
/// </remarks>
public enum OrtCompiledModelCompatibility
{
// NOTE(review): the per-member notes below are inferred from the member names; confirm against the native header.
// Presumably: compatibility validation does not apply to the EP.
EP_NOT_APPLICABLE = 0,
// Presumably: the compiled model is supported as-is.
EP_SUPPORTED_OPTIMAL = 1,
// Presumably: the compiled model is usable, but recompiling is preferred.
EP_SUPPORTED_PREFER_RECOMPILATION = 2,
// Presumably: the compiled model cannot be used with the EP.
EP_UNSUPPORTED = 3,
}

/// <summary>
/// Delegate for logging function callback.
/// Supply your function and register it with the environment to receive logging callbacks via
Expand Down Expand Up @@ -361,6 +376,28 @@ public string[] GetAvailableProviders()
}
}

/// <summary>
/// Validate a compiled model's compatibility information for one or more EP devices.
/// </summary>
/// <param name="epDevices">The list of EP devices to validate against. Must be non-null, non-empty and must not contain null entries.</param>
/// <param name="compatibilityInfo">The opaque compatibility information string from the precompiled model to validate. Must not be null.</param>
/// <returns>OrtCompiledModelCompatibility enum value denoting the compatibility status</returns>
/// <exception cref="ArgumentException">
/// <paramref name="epDevices"/> is null, empty, or contains a null entry.
/// </exception>
/// <exception cref="ArgumentNullException"><paramref name="compatibilityInfo"/> is null.</exception>
public OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices(IReadOnlyList<OrtEpDevice> epDevices, string compatibilityInfo)
{
    if (epDevices == null || epDevices.Count == 0)
    {
        throw new ArgumentException("epDevices must be non-empty", nameof(epDevices));
    }

    // Fail with the caller's parameter name instead of relying on the UTF-8 helper to reject null.
    if (compatibilityInfo == null)
    {
        throw new ArgumentNullException(nameof(compatibilityInfo));
    }

    // Collect the native handles into an array that is passed to the native call.
    var devicePtrs = new IntPtr[epDevices.Count];
    for (int i = 0; i < devicePtrs.Length; ++i)
    {
        var device = epDevices[i];
        if (device == null)
        {
            // Report a clear argument error rather than a NullReferenceException on .Handle.
            throw new ArgumentException($"epDevices[{i}] must not be null", nameof(epDevices));
        }

        devicePtrs[i] = device.Handle;
    }

    var infoUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(compatibilityInfo);
    NativeApiStatus.VerifySuccess(NativeMethods.OrtGetModelCompatibilityForEpDevices(devicePtrs, (UIntPtr)devicePtrs.Length, infoUtf8, out int status));

    // The native call reports the status as an int holding an OrtCompiledModelCompatibility value.
    return (OrtCompiledModelCompatibility)status;
}


/// <summary>
/// Get/Set log level property of OrtEnv instance
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// not supported on mobile platforms
#if !(ANDROID || IOS)

namespace Microsoft.ML.OnnxRuntime.Tests;

using System;
using System.Linq;
using Xunit;
using System.Collections.Generic;

/// <summary>
/// Tests for <c>OrtEnv.GetModelCompatibilityForEpDevices</c>.
/// </summary>
public class EpCompatibilityTests
{
    private readonly OrtEnv ortEnvInstance = OrtEnv.Instance();

    /// <summary>
    /// Returns the EP devices reported by the environment, asserting that there is at least one.
    /// </summary>
    private IReadOnlyList<OrtEpDevice> GetDevices()
    {
        var epDevices = ortEnvInstance.GetEpDevices();
        Assert.NotNull(epDevices);
        Assert.NotEmpty(epDevices);
        return epDevices;
    }

    /// <summary>
    /// A null or empty device list must be rejected with an ArgumentException.
    /// </summary>
    [Fact]
    public void GetEpCompatibility_InvalidArgs()
    {
        Assert.Throws<ArgumentException>(() => ortEnvInstance.GetModelCompatibilityForEpDevices(null, "info"));
        Assert.Throws<ArgumentException>(() => ortEnvInstance.GetModelCompatibilityForEpDevices(new List<OrtEpDevice>(), "info"));
    }

    /// <summary>
    /// Validating an arbitrary compatibility string against the CPU EP device yields EP_NOT_APPLICABLE.
    /// </summary>
    [Fact]
    public void GetEpCompatibility_SingleDeviceCpuProvider()
    {
        var devices = GetDevices();
        var someInfo = "arbitrary-compat-string";

        // Use CPU device. FirstOrDefault (rather than First) so that a missing CPU EP is reported by the
        // assertion below instead of an InvalidOperationException thrown before the assertion can run.
        var cpu = devices.FirstOrDefault(d => d.EpName == "CPUExecutionProvider");
        Assert.NotNull(cpu);
        var selected = new List<OrtEpDevice> { cpu };
        var status = ortEnvInstance.GetModelCompatibilityForEpDevices(selected, someInfo);

        // CPU defaults to not applicable in this scenario
        Assert.Equal(OrtCompiledModelCompatibility.EP_NOT_APPLICABLE, status);
    }
}
#endif
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,14 @@ struct EpDevice : detail::EpDeviceImpl<OrtEpDevice> {
ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {});
};

/** \brief Validate a compiled model's compatibility for one or more EP devices.
 *
 * \param ep_devices The EP devices to validate against. Must not be empty; an empty vector is rejected
 *                   with an ORT_INVALID_ARGUMENT error.
 * \param compatibility_info Compatibility information string, forwarded unchanged to the C API.
 * \return The resulting compatibility status.
 *
 * Throws on error.
 */
OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices(
const std::vector<ConstEpDevice>& ep_devices,
const char* compatibility_info);

/** \brief The Env (Environment)
*
* The Env holds the logging state used by all other objects.
Expand Down
20 changes: 20 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,26 @@ inline void CustomOpDomain::Add(const OrtCustomOp* op) {
ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
}

// Validates `compatibility_info` against the given EP devices via the C API and returns the reported status.
// Throws if `ep_devices` is empty or if the C API call reports an error.
inline OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices(
    const std::vector<ConstEpDevice>& ep_devices,
    const char* compatibility_info) {
  // Reject an empty device list before calling into the C API.
  if (ep_devices.empty()) {
    ORT_CXX_API_THROW("ep_devices is empty", ORT_INVALID_ARGUMENT);
  }

  // Unwrap the C++ handles into a contiguous array of raw pointers for the C API.
  const size_t num_devices = ep_devices.size();
  std::vector<const OrtEpDevice*> raw_devices(num_devices, nullptr);
  for (size_t i = 0; i < num_devices; ++i) {
    raw_devices[i] = ep_devices[i];
  }

  // `const OrtEpDevice**` converts implicitly to `const OrtEpDevice* const*`; no cast is needed.
  const OrtEpDevice* const* device_array = raw_devices.data();

  OrtCompiledModelCompatibility result = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
  ThrowOnError(GetApi().GetModelCompatibilityForEpDevices(device_array, num_devices, compatibility_info, &result));
  return result;
}

inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
OrtAllocator* allocator) {
OrtLoraAdapter* p;
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,17 @@ void addGlobalMethods(py::module& m) {
R"pbdoc(Get the list of available OrtEpDevice instances.)pbdoc",
py::return_value_policy::reference);

// Python binding: validates `compatibility_info` against the given EP devices through the C API and
// returns the reported OrtCompiledModelCompatibility. A failing C API status is raised via Ort::ThrowOnError.
m.def(
    "get_model_compatibility_for_ep_devices",
    [](const std::vector<const OrtEpDevice*>& ep_devices,
       const std::string& compatibility_info) -> OrtCompiledModelCompatibility {
      OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
      Ort::ThrowOnError(Ort::GetApi().GetModelCompatibilityForEpDevices(
          ep_devices.data(), ep_devices.size(), compatibility_info.c_str(), &status));
      return status;
    },
    // Fixed: the docstring previously began with a stray double quote inside the raw string literal.
    R"pbdoc(Validate a compiled model's compatibility information for one or more EP devices.)pbdoc");

#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE)
m.def(
"get_available_openvino_device_ids", []() -> std::vector<std::string> {
Expand Down Expand Up @@ -1759,6 +1770,12 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
.value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED)
.value("MEMORY_EFFICIENT", ExecutionOrder::MEMORY_EFFICIENT);

// Expose the C API OrtCompiledModelCompatibility values to Python. This is the return type of the
// get_model_compatibility_for_ep_devices binding.
py::enum_<OrtCompiledModelCompatibility>(m, "OrtCompiledModelCompatibility")
.value("EP_NOT_APPLICABLE", OrtCompiledModelCompatibility_EP_NOT_APPLICABLE)
.value("EP_SUPPORTED_OPTIMAL", OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL)
.value("EP_SUPPORTED_PREFER_RECOMPILATION", OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION)
.value("EP_UNSUPPORTED", OrtCompiledModelCompatibility_EP_UNSUPPORTED);

py::enum_<OrtAllocatorType>(m, "OrtAllocatorType")
.value("INVALID", OrtInvalidAllocator)
.value("ORT_DEVICE_ALLOCATOR", OrtDeviceAllocator)
Expand Down
29 changes: 29 additions & 0 deletions onnxruntime/test/framework/ep_compatibility_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h"
#include "core/session/utils.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/abi_session_options_impl.h"
#include "core/framework/error_code_helper.h"
#include "dummy_provider.h"
Expand Down Expand Up @@ -499,3 +500,31 @@ TEST(EpCompatibilityCapiTest, CpuEpReturnsNotApplicableIfNoValidation) {

api->ReleaseEnv(env);
}

// -----------------------------
// C++ API unit tests
// -----------------------------

// Checks that the C++ wrapper reports EP_NOT_APPLICABLE for the CPU EP device.
TEST(EpCompatibilityCxxApiTest, SingleDeviceCpuProvider) {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpCompatCxx"};
  auto all_devices = env.GetEpDevices();
  ASSERT_FALSE(all_devices.empty());

  // Select the first device that belongs to the CPU execution provider.
  std::vector<Ort::ConstEpDevice> cpu_only;
  for (size_t i = 0; i < all_devices.size() && cpu_only.empty(); ++i) {
    const std::string ep_name{all_devices[i].EpName()};
    if (ep_name == "CPUExecutionProvider") {
      cpu_only.push_back(all_devices[i]);
    }
  }

  ASSERT_FALSE(cpu_only.empty());

  // Start from a status the CPU EP would never return, so the check below proves the call wrote a value.
  OrtCompiledModelCompatibility result = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
  ASSERT_NO_FATAL_FAILURE({
    result = Ort::GetModelCompatibilityForEpDevices(cpu_only, "arbitrary-compat-string");
  });

  ASSERT_EQ(result, OrtCompiledModelCompatibility_EP_NOT_APPLICABLE);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import platform
import sys
import unittest

from onnxruntime.capi.onnxruntime_pybind11_state import (
OrtCompiledModelCompatibility,
get_ep_devices,
get_model_compatibility_for_ep_devices,
)

# From Python 3.8 on, loading a DLL from the current directory on Windows has to be explicitly allowed.
# Compare version_info as a tuple: checking major and minor separately gives the wrong answer for any
# future release whose minor number is below 8 (e.g. 4.0).
if platform.system() == "Windows" and sys.version_info >= (3, 8):
    os.add_dll_directory(os.getcwd())


class TestEpCompatibility(unittest.TestCase):
    """Smoke tests for the ``get_model_compatibility_for_ep_devices`` binding."""

    def test_invalid_args(self):
        """An empty device list and a ``None`` info string are both rejected."""
        # Empty device list.
        self.assertRaises(RuntimeError, get_model_compatibility_for_ep_devices, [], "info")
        # A None compatibility info should raise TypeError before the native call.
        self.assertRaises(
            TypeError,
            lambda: get_model_compatibility_for_ep_devices(get_ep_devices(), None),  # type: ignore[arg-type]
        )

    def test_basic_smoke(self):
        """The CPU EP device reports EP_NOT_APPLICABLE for an arbitrary info string."""
        available = list(get_ep_devices())
        if len(available) == 0:
            self.skipTest("No EP devices available in this build")

        # Always use CPUExecutionProvider; skip when the build does not provide it.
        cpu_device = next(
            (dev for dev in available if getattr(dev, "ep_name", None) == "CPUExecutionProvider"),
            None,
        )
        if cpu_device is None:
            self.skipTest("CPUExecutionProvider not available in this build")

        # The API requires all devices to belong to the same EP; only one is passed here.
        result = get_model_compatibility_for_ep_devices([cpu_device], "arbitrary-compat-string")
        self.assertEqual(result, OrtCompiledModelCompatibility.EP_NOT_APPLICABLE)


if __name__ == "__main__":
unittest.main()
Loading