// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
// new file mode 100644
// index 0000000000000..9f42bf2247529
// --- /dev/null
// +++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
// @@ -0,0 +1,130 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

namespace Microsoft.ML.OnnxRuntime
{
    using System;
    using System.Runtime.InteropServices;

    /// <summary>
    /// This class is used to set options for model compilation, and to produce a compiled model using those options.
    /// See https://onnxruntime.ai/docs/api/c/ for further details of various options.
    /// </summary>
    public class OrtModelCompilationOptions : SafeHandle
    {
        // Pin for the array most recently accepted by SetInputModelFromBuffer.
        // The native side keeps the raw pointer (the bytes are not copied), so the array has to stay
        // at a fixed address for longer than the single interop call that hands it over.
        private GCHandle _pinnedInputModel;

        /// <summary>
        /// Create a new OrtModelCompilationOptions object from SessionOptions.
        /// </summary>
        /// <param name="sessionOptions">SessionOptions instance to read settings from.</param>
        /// <exception cref="ArgumentNullException"><paramref name="sessionOptions"/> is null.</exception>
        public OrtModelCompilationOptions(SessionOptions sessionOptions)
            : base(IntPtr.Zero, true)
        {
            if (sessionOptions == null)
            {
                throw new ArgumentNullException(nameof(sessionOptions));
            }

            NativeApiStatus.VerifySuccess(
                NativeMethods.CompileApi.OrtCreateModelCompilationOptionsFromSessionOptions(
                    OrtEnv.Instance().Handle, sessionOptions.Handle, out handle));
        }

        /// <summary>
        /// Compile the model using the options set in this object.
        /// </summary>
        public void CompileModel()
        {
            NativeApiStatus.VerifySuccess(
                NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, handle));
        }

        /// <summary>
        /// Set the input model to compile.
        /// </summary>
        /// <param name="path">Path to ONNX model to compile.</param>
        public void SetInputModelPath(string path)
        {
            var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path);
            NativeApiStatus.VerifySuccess(
                NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(handle, platformPath));
        }

        /// <summary>
        /// Set the input model to compile to be a byte array.
        /// The input bytes are NOT copied and must remain valid while in use by ORT.
        /// </summary>
        /// <remarks>
        /// The array is pinned by this instance until the instance is released (or another buffer is set),
        /// so the address handed to the native API stays valid across garbage collections.
        /// The caller should still avoid modifying the array while it is in use.
        /// </remarks>
        /// <param name="buffer">Input model bytes.</param>
        /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
        public void SetInputModelFromBuffer(byte[] buffer)
        {
            if (buffer == null)
            {
                throw new ArgumentNullException(nameof(buffer));
            }

            // Passing a byte[] through the delegate only pins it for the duration of the call.
            // Pin it explicitly first so the pointer the native side stores remains valid afterwards.
            GCHandle pin = GCHandle.Alloc(buffer, GCHandleType.Pinned);
            try
            {
                NativeApiStatus.VerifySuccess(
                    NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelFromBuffer(
                        handle, buffer, (UIntPtr)buffer.Length));
            }
            catch
            {
                // The native call did not accept the buffer, so nothing references it.
                pin.Free();
                throw;
            }

            // The new buffer was accepted; a previously supplied buffer no longer needs to stay pinned.
            ReleasePinnedInputModel();
            _pinnedInputModel = pin;
        }

        /// <summary>
        /// Set the path to write the compiled ONNX model to.
        /// </summary>
        /// <param name="path">Path to write compiled model to.</param>
        public void SetOutputModelPath(string path)
        {
            var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path);
            NativeApiStatus.VerifySuccess(
                NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(handle, platformPath));
        }

        /// <summary>
        /// Set the path to a file to write initializers as external data to,
        /// and the threshold that determines when to write an initializer to the external data file.
        /// </summary>
        /// <param name="filePath">Path to file to write external data to.</param>
        /// <param name="threshold">Size at which an initializer will be written to external data.</param>
        public void SetOutputModelExternalInitializersFile(string filePath, ulong threshold)
        {
            var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath);
            NativeApiStatus.VerifySuccess(
                NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelExternalInitializersFile(
                    handle, platformPath, new UIntPtr(threshold)));
        }

        // TODO: In order to use this to create an InferenceSession without copying bytes we need more infrastructure.
        //   - Need something that wraps the allocator, pointer and size and is SafeHandle based.
        //     - When it is disposed we need to use the allocator to release the native buffer.
        //   - Need the 4 InferenceSession ctors that take byte[] for the model to be duplicated to handle this new
        //     wrapper type.
        // Due to that making this API internal so we can test it. We can make it public when the other infrastructure
        // is in place as it will change the signature of the API.
        internal void SetOutputModelBuffer(OrtAllocator allocator,
                                           ref IntPtr outputModelBufferPtr, ref UIntPtr outputModelBufferSizePtr)
        {
            NativeApiStatus.VerifySuccess(
                NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelBuffer(
                    handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr));
        }

        /// <summary>
        /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute
        /// of EPContext nodes.
        /// </summary>
        /// <param name="embed">Enable if true. Default is false.</param>
        public void SetEpContextEmbedMode(bool embed)
        {
            NativeApiStatus.VerifySuccess(
                NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(handle, embed));
        }

        internal IntPtr Handle => handle;

        /// <summary>
        /// Indicates whether the native handle is invalid.
        /// </summary>
        public override bool IsInvalid => handle == IntPtr.Zero;

        /// <summary>
        /// Release the native instance of OrtModelCompilationOptions.
        /// </summary>
        /// <returns>true</returns>
        protected override bool ReleaseHandle()
        {
            NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(handle);
            handle = IntPtr.Zero;

            // Only unpin after the native object that may reference the buffer has been released.
            ReleasePinnedInputModel();
            return true;
        }

        // Unpins the input model array, if one is currently pinned.
        private void ReleasePinnedInputModel()
        {
            if (_pinnedInputModel.IsAllocated)
            {
                _pinnedInputModel.Free();
            }
        }
    }
}
// \ No newline at end of file
// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
// new file mode 100644
// index 0000000000000..3a87f87d124e9
// --- /dev/null
// +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
// @@ -0,0 +1,152 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

namespace Microsoft.ML.OnnxRuntime.CompileApi;

using System;
using System.Runtime.InteropServices;

// NOTE: The order of the APIs in this struct should match exactly that in OrtCompileApi
// See onnxruntime/core/session/compile_api.cc.
/// <summary>
/// Managed mirror of the native compile API function table: one function pointer per entry.
/// The field order is the layout (LayoutKind.Sequential), so fields must not be reordered.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct OrtCompileApi
{
    public IntPtr ReleaseModelCompilationOptions;
    public IntPtr CreateModelCompilationOptionsFromSessionOptions;
    public IntPtr ModelCompilationOptions_SetInputModelPath;
    public IntPtr ModelCompilationOptions_SetInputModelFromBuffer;
    public IntPtr ModelCompilationOptions_SetOutputModelPath;
    public IntPtr ModelCompilationOptions_SetOutputModelExternalInitializersFile;
    public IntPtr ModelCompilationOptions_SetOutputModelBuffer;
    public IntPtr ModelCompilationOptions_SetEpContextEmbedMode;
    public IntPtr CompileModel;
}

/// <summary>
/// Holds one managed delegate per compile API entry point.
/// </summary>
internal class NativeMethods
{
    // Copy of the function table obtained through the delegate passed to the constructor.
    private static OrtCompileApi _compileApi;

    //
    // Each entry point below is a delegate type describing the native signature, followed by
    // an instance member that holds the delegate marshaled from the function pointer.
    //
    // The members are populated in the constructor of this class.
    //
    // The C# code calls the C++ API through the delegate instances held in these members.
    //

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate void DOrtReleaseModelCompilationOptions(
        IntPtr /* OrtModelCompilationOptions* */ options);

    public DOrtReleaseModelCompilationOptions OrtReleaseModelCompilationOptions;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */ DOrtCreateModelCompilationOptionsFromSessionOptions(
        IntPtr /* const OrtEnv* */ env,
        IntPtr /* const OrtSessionOptions* */ sessionOptions,
        out IntPtr /* OrtModelCompilationOptions** */ outOptions);

    public DOrtCreateModelCompilationOptionsFromSessionOptions
        OrtCreateModelCompilationOptionsFromSessionOptions;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelPath(
        IntPtr /* OrtModelCompilationOptions* */ options,
        byte[] /* const ORTCHAR_T* */ inputModelPath);

    public DOrtModelCompilationOptions_SetInputModelPath OrtModelCompilationOptions_SetInputModelPath;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelFromBuffer(
        IntPtr /* OrtModelCompilationOptions* */ options,
        byte[] /* const void* */ inputModelData,
        UIntPtr /* size_t */ inputModelDataSize);

    public DOrtModelCompilationOptions_SetInputModelFromBuffer
        OrtModelCompilationOptions_SetInputModelFromBuffer;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelPath(
        IntPtr /* OrtModelCompilationOptions* */ options,
        byte[] /* const ORTCHAR_T* */ outputModelPath);

    public DOrtModelCompilationOptions_SetOutputModelPath OrtModelCompilationOptions_SetOutputModelPath;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile(
        IntPtr /* OrtModelCompilationOptions* */ options,
        byte[] /* const ORTCHAR_T* */ externalInitializersFilePath,
        UIntPtr /* size_t */ externalInitializerSizeThreshold);

    public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile
        OrtModelCompilationOptions_SetOutputModelExternalInitializersFile;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelBuffer(
        IntPtr /* OrtModelCompilationOptions* */ options,
        IntPtr /* OrtAllocator* */ allocator,
        ref IntPtr /* void** */ outputModelBufferPtr,
        ref UIntPtr /* size_t* */ outputModelBufferSizePtr);

    public DOrtModelCompilationOptions_SetOutputModelBuffer OrtModelCompilationOptions_SetOutputModelBuffer;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextEmbedMode(
        IntPtr /* OrtModelCompilationOptions* */ options,
        bool embedEpContextInModel);

    public DOrtModelCompilationOptions_SetEpContextEmbedMode
        OrtModelCompilationOptions_SetEpContextEmbedMode;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */ DOrtCompileModel(
        IntPtr /* const OrtEnv* */ env,
        IntPtr /* const OrtModelCompilationOptions* */ modelOptions);

    public DOrtCompileModel OrtCompileModel;

    /// <summary>
    /// Reads the compile API function table and creates a managed delegate for every entry.
    /// </summary>
    /// <param name="getCompileApi">Delegate that returns the native compile API table.</param>
    internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi)
    {
#if NETSTANDARD2_0
        // The delegate returns a raw pointer on this target, so the table is copied out of native memory.
        var table = (OrtCompileApi)Marshal.PtrToStructure(getCompileApi(), typeof(OrtCompileApi));
#else
        // The delegate returns a reference to the table; assigning to a local takes a copy.
        OrtCompileApi table = getCompileApi();
#endif
        _compileApi = table;

        OrtReleaseModelCompilationOptions =
            Bind<DOrtReleaseModelCompilationOptions>(table.ReleaseModelCompilationOptions);
        OrtCreateModelCompilationOptionsFromSessionOptions =
            Bind<DOrtCreateModelCompilationOptionsFromSessionOptions>(
                table.CreateModelCompilationOptionsFromSessionOptions);
        OrtModelCompilationOptions_SetInputModelPath =
            Bind<DOrtModelCompilationOptions_SetInputModelPath>(
                table.ModelCompilationOptions_SetInputModelPath);
        OrtModelCompilationOptions_SetInputModelFromBuffer =
            Bind<DOrtModelCompilationOptions_SetInputModelFromBuffer>(
                table.ModelCompilationOptions_SetInputModelFromBuffer);
        OrtModelCompilationOptions_SetOutputModelPath =
            Bind<DOrtModelCompilationOptions_SetOutputModelPath>(
                table.ModelCompilationOptions_SetOutputModelPath);
        OrtModelCompilationOptions_SetOutputModelExternalInitializersFile =
            Bind<DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile>(
                table.ModelCompilationOptions_SetOutputModelExternalInitializersFile);
        OrtModelCompilationOptions_SetOutputModelBuffer =
            Bind<DOrtModelCompilationOptions_SetOutputModelBuffer>(
                table.ModelCompilationOptions_SetOutputModelBuffer);
        OrtModelCompilationOptions_SetEpContextEmbedMode =
            Bind<DOrtModelCompilationOptions_SetEpContextEmbedMode>(
                table.ModelCompilationOptions_SetEpContextEmbedMode);
        OrtCompileModel = Bind<DOrtCompileModel>(table.CompileModel);
    }

    // Wraps a native function pointer from the API table in a managed delegate of the requested type.
    private static TDelegate Bind<TDelegate>(IntPtr functionPointer)
        where TDelegate : Delegate
    {
        return (TDelegate)Marshal.GetDelegateForFunctionPointer(functionPointer, typeof(TDelegate));
    }
}
// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
// index 77c35aac65b92..620c13b8641b5 100644
// --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
// +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
// @@ -336,12 +336,43 @@ public struct OrtApi
        public IntPtr GetModelEditorApi;
        public IntPtr CreateTensorWithDataAndDeleterAsOrtValue;
        public IntPtr SessionOptionsSetLoadCancellationFlag;

        public IntPtr GetCompileApi;

        public IntPtr CreateKeyValuePairs;
        public IntPtr AddKeyValuePair;
        public IntPtr GetKeyValue;
        public IntPtr GetKeyValuePairs;
        public IntPtr RemoveKeyValuePair;
        public IntPtr ReleaseKeyValuePairs;

        public IntPtr RegisterExecutionProviderLibrary;
        public IntPtr UnregisterExecutionProviderLibrary;

        public IntPtr GetEpDevices;

        public IntPtr SessionOptionsAppendExecutionProvider_V2;
        public IntPtr SessionOptionsSetEpSelectionPolicy;

        public IntPtr HardwareDevice_Type;
        public IntPtr
HardwareDevice_VendorId; + public IntPtr HardwareDevice_Vendor; + public IntPtr HardwareDevice_DeviceId; + public IntPtr HardwareDevice_Metadata; + + public IntPtr EpDevice_EpName; + public IntPtr EpDevice_EpVendor; + public IntPtr EpDevice_EpMetadata; + public IntPtr EpDevice_EpOptions; + public IntPtr EpDevice_Device; } internal static class NativeMethods { static OrtApi api_; + static internal CompileApi.NativeMethods CompileApi; + #if NETSTANDARD2_0 [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr DOrtGetApi(UInt32 version); @@ -582,6 +613,85 @@ static NativeMethods() typeof(DReleaseLoraAdapter)); OrtRunOptionsAddActiveLoraAdapter = (DOrtRunOptionsAddActiveLoraAdapter)Marshal.GetDelegateForFunctionPointer( api_.RunOptionsAddActiveLoraAdapter, typeof(DOrtRunOptionsAddActiveLoraAdapter)); + + OrtGetCompileApi = (DOrtGetCompileApi)Marshal.GetDelegateForFunctionPointer( + api_.GetCompileApi, typeof(DOrtGetCompileApi)); + + // populate the CompileApi struct now that we have the delegate to get the compile API pointer. 
+ CompileApi = new CompileApi.NativeMethods(OrtGetCompileApi); + + OrtCreateKeyValuePairs = (DOrtCreateKeyValuePairs)Marshal.GetDelegateForFunctionPointer( + api_.CreateKeyValuePairs, typeof(DOrtCreateKeyValuePairs)); + + OrtAddKeyValuePair = (DOrtAddKeyValuePair)Marshal.GetDelegateForFunctionPointer( + api_.AddKeyValuePair, typeof(DOrtAddKeyValuePair)); + + OrtGetKeyValue = (DOrtGetKeyValue)Marshal.GetDelegateForFunctionPointer( + api_.GetKeyValue, typeof(DOrtGetKeyValue)); + + OrtGetKeyValuePairs = (DOrtGetKeyValuePairs)Marshal.GetDelegateForFunctionPointer( + api_.GetKeyValuePairs, typeof(DOrtGetKeyValuePairs)); + + OrtRemoveKeyValuePair = (DOrtRemoveKeyValuePair)Marshal.GetDelegateForFunctionPointer( + api_.RemoveKeyValuePair, typeof(DOrtRemoveKeyValuePair)); + + OrtReleaseKeyValuePairs = (DOrtReleaseKeyValuePairs)Marshal.GetDelegateForFunctionPointer( + api_.ReleaseKeyValuePairs, typeof(DOrtReleaseKeyValuePairs)); + + OrtHardwareDevice_Type = (DOrtHardwareDevice_Type)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_Type, typeof(DOrtHardwareDevice_Type)); + + OrtHardwareDevice_VendorId = (DOrtHardwareDevice_VendorId)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_VendorId, typeof(DOrtHardwareDevice_VendorId)); + + OrtHardwareDevice_Vendor = (DOrtHardwareDevice_Vendor)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_Vendor, typeof(DOrtHardwareDevice_Vendor)); + + OrtHardwareDevice_DeviceId = (DOrtHardwareDevice_DeviceId)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_DeviceId, typeof(DOrtHardwareDevice_DeviceId)); + + OrtHardwareDevice_Metadata = (DOrtHardwareDevice_Metadata)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_Metadata, typeof(DOrtHardwareDevice_Metadata)); + + + OrtEpDevice_EpName = (DOrtEpDevice_EpName)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpName, typeof(DOrtEpDevice_EpName)); + + OrtEpDevice_EpVendor = (DOrtEpDevice_EpVendor)Marshal.GetDelegateForFunctionPointer( + 
api_.EpDevice_EpVendor, typeof(DOrtEpDevice_EpVendor)); + + OrtEpDevice_EpMetadata = (DOrtEpDevice_EpMetadata)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpMetadata, typeof(DOrtEpDevice_EpMetadata)); + + OrtEpDevice_EpOptions = (DOrtEpDevice_EpOptions)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpOptions, typeof(DOrtEpDevice_EpOptions)); + + OrtEpDevice_Device = (DOrtEpDevice_Device)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_Device, typeof(DOrtEpDevice_Device)); + + OrtRegisterExecutionProviderLibrary = + (DOrtRegisterExecutionProviderLibrary)Marshal.GetDelegateForFunctionPointer( + api_.RegisterExecutionProviderLibrary, + typeof(DOrtRegisterExecutionProviderLibrary)); + + OrtUnregisterExecutionProviderLibrary = + (DOrtUnregisterExecutionProviderLibrary)Marshal.GetDelegateForFunctionPointer( + api_.UnregisterExecutionProviderLibrary, + typeof(DOrtUnregisterExecutionProviderLibrary)); + + OrtGetEpDevices = (DOrtGetEpDevices)Marshal.GetDelegateForFunctionPointer( + api_.GetEpDevices, + typeof(DOrtGetEpDevices)); + + OrtSessionOptionsAppendExecutionProvider_V2 = + (DOrtSessionOptionsAppendExecutionProvider_V2)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsAppendExecutionProvider_V2, + typeof(DOrtSessionOptionsAppendExecutionProvider_V2)); + + OrtSessionOptionsSetEpSelectionPolicy = + (DSessionOptionsSetEpSelectionPolicy)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsSetEpSelectionPolicy, + typeof(DSessionOptionsSetEpSelectionPolicy)); } internal class NativeLib @@ -823,7 +933,7 @@ internal class NativeLib IntPtr /* (OrtEnv*) */ environment, //[MarshalAs(UnmanagedType.LPStr)]string modelPath byte[] modelPath, - IntPtr /* (OrtSessionOptions*) */ sessopnOptions, + IntPtr /* (OrtSessionOptions*) */ sessionOptions, out IntPtr /**/ session); public static DOrtCreateSession OrtCreateSession; @@ -1350,7 +1460,7 @@ out IntPtr lora_adapter #endregion - #region RunOptions API +#region RunOptions API 
[UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateRunOptions(out IntPtr /* OrtRunOptions** */ runOptions); @@ -2153,7 +2263,168 @@ out IntPtr lora_adapter #endregion -#region Misc API +#region Compile API + +#if NETSTANDARD2_0 + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtGetCompileApi(); +#else + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate ref CompileApi.OrtCompileApi DOrtGetCompileApi(); +#endif + public static DOrtGetCompileApi OrtGetCompileApi; +#endregion + +#region Auto EP API related + // + // OrtKeyValuePairs + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtCreateKeyValuePairs(out IntPtr /* OrtKeyValuePairs** */ kvps); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, + byte[] /* const char* */ key, + byte[] /* const char* */ value); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtGetKeyValue(IntPtr /* const OrtKeyValuePairs* */ kvps, + byte[] /* const char* */ key); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtGetKeyValuePairs(IntPtr /* const OrtKeyValuePairs* */ kvps, + out IntPtr /* const char* const** */ keys, + out IntPtr /* const char* const** */ values, + out UIntPtr /* size_t* */ numEntries); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, + byte[] /* const char* */ key); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseKeyValuePairs(IntPtr /* OrtKeyValuePairs* */ kvps); + + + public static DOrtCreateKeyValuePairs OrtCreateKeyValuePairs; + public static DOrtAddKeyValuePair OrtAddKeyValuePair; + public static DOrtGetKeyValue OrtGetKeyValue; + public static DOrtGetKeyValuePairs 
OrtGetKeyValuePairs; + public static DOrtRemoveKeyValuePair OrtRemoveKeyValuePair; + public static DOrtReleaseKeyValuePairs OrtReleaseKeyValuePairs; + + + // + // OrtHardwareDevice + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate int /* OrtHardwareDeviceType */ DOrtHardwareDevice_Type( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate uint /* uint32_t */ DOrtHardwareDevice_VendorId( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtHardwareDevice_Vendor( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate uint /* uint32_t */ DOrtHardwareDevice_DeviceId( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtKeyValuePairs* */ DOrtHardwareDevice_Metadata( + IntPtr /* const OrtHardwareDevice* */ device); + + + public static DOrtHardwareDevice_Type OrtHardwareDevice_Type; + public static DOrtHardwareDevice_VendorId OrtHardwareDevice_VendorId; + public static DOrtHardwareDevice_Vendor OrtHardwareDevice_Vendor; + public static DOrtHardwareDevice_DeviceId OrtHardwareDevice_DeviceId; + public static DOrtHardwareDevice_Metadata OrtHardwareDevice_Metadata; + + // + // OrtEpDevice + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtEpDevice_EpName(IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtEpDevice_EpVendor(IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtKeyValuePairs* */ DOrtEpDevice_EpMetadata( + IntPtr /* const OrtEpDevice* */ ep_device); + + 
[UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtKeyValuePairs* */ DOrtEpDevice_EpOptions( + IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtHardwareDevice* */ DOrtEpDevice_Device( + IntPtr /* const OrtEpDevice* */ ep_device); + + + public static DOrtEpDevice_EpName OrtEpDevice_EpName; + public static DOrtEpDevice_EpVendor OrtEpDevice_EpVendor; + public static DOrtEpDevice_EpMetadata OrtEpDevice_EpMetadata; + public static DOrtEpDevice_EpOptions OrtEpDevice_EpOptions; + public static DOrtEpDevice_Device OrtEpDevice_Device; + + // + // Auto Selection EP registration and selection customization + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtRegisterExecutionProviderLibrary( + IntPtr /* OrtEnv* */ env, + byte[] /* const char* */ registration_name, + byte[] /* const ORTCHAR_T* */ path); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtUnregisterExecutionProviderLibrary( + IntPtr /* OrtEnv* */ env, + byte[] /* const char* */ registration_name); + + public static DOrtRegisterExecutionProviderLibrary OrtRegisterExecutionProviderLibrary; + public static DOrtUnregisterExecutionProviderLibrary OrtUnregisterExecutionProviderLibrary; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetEpDevices( + IntPtr /* const OrtEnv* */ env, + out IntPtr /* const OrtEpDevice* const** */ ep_devices, + out UIntPtr /* size_t* */ num_ep_devices); + + public static DOrtGetEpDevices OrtGetEpDevices; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtSessionOptionsAppendExecutionProvider_V2( + IntPtr /* OrtSessionOptions* */ sess_options, + IntPtr /* OrtEnv* */ env, + IntPtr[] /* const OrtEpDevice* const* */ ep_devices, + UIntPtr /* size_t */ num_ep_devices, 
+ IntPtr /* const char* const* */ ep_option_keys, // use OrtKeyValuePairs.GetKeyValuePairHandles + IntPtr /* const char* const* */ ep_option_vals, + UIntPtr /* size_t */ num_ep_options); + + public static DOrtSessionOptionsAppendExecutionProvider_V2 OrtSessionOptionsAppendExecutionProvider_V2; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtEpSelectionDelegate( + IntPtr /* OrtEpDevice** */ epDevices, + uint numDevices, + IntPtr /* OrtKeyValuePairs* */ modelMetadata, + IntPtr /* OrtKeyValuePairs* */ runtimeMetadata, + IntPtr /* OrtEpDevice** */ selected, + uint maxSelected, + out UIntPtr numSelected + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DSessionOptionsSetEpSelectionPolicy( + IntPtr /* OrtSessionOptions* */ session_options, + int /* OrtExecutionProviderDevicePolicy */ policy, + IntPtr /* DOrtEpSelectionDelegate* */ selection_delegate); + public static DSessionOptionsSetEpSelectionPolicy OrtSessionOptionsSetEpSelectionPolicy; + + + #endregion + #region Misc API /// /// Queries all the execution providers supported in the native onnxruntime shared library diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index f4b2649f8d055..5c70808b82be1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntime @@ -376,6 +377,68 @@ public OrtLoggingLevel EnvLogLevel } } + /// + /// Register an execution provider library with the OrtEnv instance. + /// A registered execution provider library can be used by all sessions created with the OrtEnv instance. 
+ /// Devices the execution provider can utilize are added to the values returned by GetEpDevices() and can + /// be used in SessionOptions.AppendExecutionProvider to select an execution provider for a device. + /// + /// Coming: A selection policy can be specified and ORT will automatically select the best execution providers + /// and devices for the model. + /// + /// The name to register the library under. + /// The path to the library to register. + /// + /// + public void RegisterExecutionProviderLibrary(string registrationName, string libraryPath) + { + var registrationNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(registrationName); + var pathUtf8 = NativeOnnxValueHelper.GetPlatformSerializedString(libraryPath); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtRegisterExecutionProviderLibrary(handle, registrationNameUtf8, pathUtf8)); + } + + /// + /// Unregister an execution provider library from the OrtEnv instance. + /// + /// The name the library was registered under. + public void UnregisterExecutionProviderLibrary(string registrationName) + { + var registrationNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(registrationName); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtUnregisterExecutionProviderLibrary(handle, registrationNameUtf8)); + } + + /// + /// Get the list of all execution provider and device combinations that are available. + /// These can be used to select the execution provider and device for a session. 
+ /// + /// + /// + /// + public IReadOnlyList GetEpDevices() + { + IntPtr epDevicesPtr; + UIntPtr numEpDevices; + + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetEpDevices(handle, out epDevicesPtr, out numEpDevices)); + + int count = (int)numEpDevices; + var epDevices = new List(count); + + IntPtr[] epDevicePtrs = new IntPtr[count]; + Marshal.Copy(epDevicesPtr, epDevicePtrs, 0, count); + + foreach (var ptr in epDevicePtrs) + { + epDevices.Add(new OrtEpDevice(ptr)); + } + + return epDevices.AsReadOnly(); + } + #endregion #region SafeHandle overrides diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs new file mode 100644 index 0000000000000..e3947d900214e --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.OnnxRuntime +{ + /// + /// Represents the combination of an execution provider and a hardware device + /// that the execution provider can utilize. + /// + public class OrtEpDevice : SafeHandle + { + /// + /// Construct an OrtEpDevice from an existing native OrtEpDevice instance. + /// + /// Native OrtEpDevice handle. + internal OrtEpDevice(IntPtr epDeviceHandle) + : base(epDeviceHandle, ownsHandle: false) + { + } + + internal IntPtr Handle => handle; + + /// + /// The name of the execution provider. + /// + public string EpName + { + get + { + IntPtr namePtr = NativeMethods.OrtEpDevice_EpName(handle); + return NativeOnnxValueHelper.StringFromNativeUtf8(namePtr); + } + } + + /// + /// The vendor who owns the execution provider. + /// + public string EpVendor + { + get + { + IntPtr vendorPtr = NativeMethods.OrtEpDevice_EpVendor(handle); + return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr); + } + } + + /// + /// Execution provider metadata. 
+ /// + public OrtKeyValuePairs EpMetadata + { + get + { + return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpMetadata(handle)); + } + } + + /// + /// Execution provider options. + /// + public OrtKeyValuePairs EpOptions + { + get + { + return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpOptions(handle)); + } + } + + /// + /// The hardware device that the execution provider can utilize. + /// + public OrtHardwareDevice HardwareDevice + { + get + { + IntPtr devicePtr = NativeMethods.OrtEpDevice_Device(handle); + return new OrtHardwareDevice(devicePtr); + } + } + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid => handle == IntPtr.Zero; + + /// + /// No-op. OrtEpDevice is always read-only as the instance is owned by native ORT. + /// + /// True + protected override bool ReleaseHandle() + { + return true; + } + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs new file mode 100644 index 0000000000000..8e7caae90ff79 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Runtime.InteropServices; + + /// + /// Represents the type of hardware device. + /// Matches OrtHardwareDeviceType in the ORT C API. + /// + public enum OrtHardwareDeviceType + { + CPU = 0, + GPU = 1, + NPU = 2, + } + + /// + /// Represents a hardware device that is available on the current system. + /// + public class OrtHardwareDevice : SafeHandle + { + + /// + /// Construct an OrtHardwareDevice for a native OrtHardwareDevice instance. + /// + /// Native OrtHardwareDevice handle. 
+ internal OrtHardwareDevice(IntPtr deviceHandle) + : base(deviceHandle, ownsHandle: false) + { + } + + /// + /// Get the type of hardware device. + /// + public OrtHardwareDeviceType Type + { + get + { + return (OrtHardwareDeviceType)NativeMethods.OrtHardwareDevice_Type(handle); + } + } + + /// + /// Get the vendor ID of the hardware device if known. + /// + /// + /// For PCIe devices the vendor ID is the PCIe vendor ID. See https://pcisig.com/membership/member-companies. + /// + public uint VendorId + { + get + { + return NativeMethods.OrtHardwareDevice_VendorId(handle); + } + } + + /// + /// The vendor (manufacturer) of the hardware device. + /// + public string Vendor + { + get + { + IntPtr vendorPtr = NativeMethods.OrtHardwareDevice_Vendor(handle); + return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr); + } + } + + /// + /// Get the device ID of the hardware device if known. + /// + /// + /// This is the identifier of the device model. + /// PCIe device IDs can be looked up at https://www.pcilookup.com/ when combined with the VendorId. + /// It is NOT a unique identifier for the device in the current system. + /// + public uint DeviceId + { + get + { + return NativeMethods.OrtHardwareDevice_DeviceId(handle); + } + } + + /// + /// Get device metadata. + /// This may include information such as whether a GPU is discrete or integrated. + /// The available metadata will differ by platform and device type. + /// + public OrtKeyValuePairs Metadata + { + get + { + return new OrtKeyValuePairs(NativeMethods.OrtHardwareDevice_Metadata(handle)); + } + } + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid => handle == IntPtr.Zero; + + /// + /// No-op. OrtHardwareDevice is always read-only as the instance is owned by native ORT. 
+ /// + /// True + protected override bool ReleaseHandle() + { + return true; + } + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs new file mode 100644 index 0000000000000..6a8d1037d9017 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Collections.Generic; + using System.Runtime.InteropServices; + + /// + /// Class to manage key-value pairs. + /// These are most often used for options and metadata. + /// + /// + /// + /// + public class OrtKeyValuePairs : SafeHandle + { + private readonly bool _createdHandle; + + // cache the values here for convenience. + // we could force a call to the C API every time in case something was changed in the background. + private Dictionary _keyValuePairs; + + /// + /// Create a new OrtKeyValuePairs instance. + /// + /// + /// A backing native instance is created and kept in sync with the C# content. + /// + public OrtKeyValuePairs() + : base(IntPtr.Zero, ownsHandle: true) + { + NativeMethods.OrtCreateKeyValuePairs(out handle); + _createdHandle = true; + _keyValuePairs = new Dictionary(); + } + + /// + /// Create a new OrtKeyValuePairs instance from an existing native OrtKeyValuePairs handle. + /// + /// Native OrtKeyValuePairs handle. + /// + /// The instance is read-only, so calling Add or Remove will throw an InvalidOperationException. + /// + internal OrtKeyValuePairs(IntPtr constHandle) + : base(constHandle, ownsHandle: false) + { + _createdHandle = false; + _keyValuePairs = GetLatest(); + } + + /// + /// Create a new OrtKeyValuePairs instance from a dictionary. + /// + /// Key-value pairs to add. + /// + /// A backing native instance is created and kept in sync with the C# content. 
+ /// + public OrtKeyValuePairs(IReadOnlyDictionary keyValuePairs) + : base(IntPtr.Zero, ownsHandle: true) + { + NativeMethods.OrtCreateKeyValuePairs(out handle); + _createdHandle = true; + _keyValuePairs = new Dictionary(keyValuePairs != null ? keyValuePairs.Count : 0); + + if (keyValuePairs != null && keyValuePairs.Count > 0) + { + foreach (var kvp in keyValuePairs) + { + Add(kvp.Key, kvp.Value); + } + } + } + + /// + /// Current key-value pair entries. + /// + /// + /// Call Refresh() to update the cached values with the latest from the backing native instance. + /// In general that should not be required as it's not expected an OrtKeyValuePairs instance would be + /// updated by both native and C# code. + /// + public IReadOnlyDictionary Entries => _keyValuePairs; + + /// + /// Adds a key-value pair. Overrides any existing value for the key. + /// + /// Key to add. Must not be null or empty. + /// Value to add. May be empty. Must not be null. + public void Add(string key, string value) + { + if (!_createdHandle) + { + throw new InvalidOperationException($"{nameof(Add)} can only be called on instances you created."); + } + + var keyPtr = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(key); + var valuePtr = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(value); + NativeMethods.OrtAddKeyValuePair(handle, keyPtr, valuePtr); + _keyValuePairs[key] = value; // update the cached value + } + + /// + /// Update the cached values with the latest from the backing native instance as that is the source of truth. + /// + public void Refresh() + { + // refresh the cached values. + _keyValuePairs = GetLatest(); + } + + /// + /// Removes a key-value pair by key. Ignores keys that do not exist. + /// + /// Key to remove. 
+ public void Remove(string key) + { + if (!_createdHandle) + { + throw new InvalidOperationException($"{nameof(Remove)} can only be called on instances you created."); + } + + var keyPtr = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(key); + NativeMethods.OrtRemoveKeyValuePair(handle, keyPtr); + + _keyValuePairs.Remove(key); // update the cached value + } + + // for internal usage to pass into the call to OrtSessionOptionsAppendExecutionProvider_V2 + // from SessionOptions::AppendExecutionProvider + internal void GetKeyValuePairHandles(out IntPtr keysHandle, out IntPtr valuesHandle, out UIntPtr numEntries) + { + if (IsInvalid) + { + throw new InvalidOperationException($"{nameof(GetKeyValuePairHandles)}: Invalid instance."); + } + + NativeMethods.OrtGetKeyValuePairs(handle, out keysHandle, out valuesHandle, out numEntries); + } + + /// + /// Fetch all the key/value pairs to make sure we are in sync with the C API. + /// + private Dictionary GetLatest() + { + var dict = new Dictionary(); + if (IsInvalid) + { + return dict; + } + + IntPtr keys, values; + UIntPtr numEntries; + NativeMethods.OrtGetKeyValuePairs(handle, out keys, out values, out numEntries); + + ulong count = numEntries.ToUInt64(); + int offset = 0; + for (ulong i = 0; i < count; i++, offset += IntPtr.Size) + { + IntPtr keyPtr = Marshal.ReadIntPtr(keys, offset); + IntPtr valuePtr = Marshal.ReadIntPtr(values, offset); + var key = NativeOnnxValueHelper.StringFromNativeUtf8(keyPtr); + var value = NativeOnnxValueHelper.StringFromNativeUtf8(valuePtr); + dict.Add(key, value); + } + + return dict; + } + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + /// + /// Release the native instance of OrtKeyValuePairs if we own it. 
+ /// + /// true + protected override bool ReleaseHandle() + { + if (_createdHandle) + { + NativeMethods.OrtReleaseKeyValuePairs(handle); + handle = IntPtr.Zero; + } + + return true; + } + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 7a5c3aaa19eac..f3c0287d2bf9d 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -689,8 +689,8 @@ public static OrtValue CreateTensorValueFromMemory(T[] data, long[] shape) wh /// The method will attempt to pin managed memory so no copying occurs when data is passed down /// to native code. /// - /// Tensor object - /// discovered tensor element type + /// + /// Tensor object /// And instance of OrtValue constructed on top of the object [Experimental("SYSLIB5001")] public static OrtValue CreateTensorValueFromSystemNumericsTensorObject(SystemNumericsTensors.Tensor tensor) where T : unmanaged diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 9b0f183f03681..de6189e105f78 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -32,6 +32,21 @@ public enum ExecutionMode ORT_PARALLEL = 1, } + /// + /// Controls the execution provider selection when using automatic EP selection. + /// Execution providers must be registered with the OrtEnv to be available for selection. + /// + public enum ExecutionProviderDevicePolicy + { + DEFAULT = 0, + PREFER_CPU = 1, + PREFER_NPU, + PREFER_GPU, + MAX_PERFORMANCE, + MAX_EFFICIENCY, + MIN_OVERALL_POWER, + } + /// /// Holds the options for creating an InferenceSession /// It forces the instantiation of the OrtEnv singleton. 
@@ -408,6 +423,82 @@ public void AppendExecutionProvider(string providerName, Dictionary + /// Select execution providers from the list of available execution providers and devices returned by + /// GetEpDevices. + /// + /// One or more OrtEpDevice instances may be provided in epDevices, but must all be for the same + /// execution provider. + /// + /// Make multiple calls to AppendExecutionProvider if you wish to use multiple execution providers. + /// + /// e.g. + /// - if execution provider 'A' has an OrtEpDevice for NPU and one for GPU and you wish to use it for + /// both devices, pass the two OrtEpDevice instances in the epDevices list in one call. + /// - if you wish to use execution provider 'B' for GPU and execution provider 'C' for CPU, + /// make two calls to AppendExecutionProvider, with one OrtEpDevice in the epDevices list in each call. + /// + /// The priority of the execution providers is set by the order in which they are appended. + /// Highest priority is first. + /// + /// OrtEnv that provided the OrtEpDevice instances via a call to GetEpDevices. + /// One or more OrtEpDevice instances to append. + /// These must all have the same EpName value. + /// Optional options to configure the execution provider. May be null. + /// epDevices was null or empty. 
+ /// + public void AppendExecutionProvider(OrtEnv env, IReadOnlyList epDevices, + IReadOnlyDictionary epOptions) + { + if (epDevices == null || epDevices.Count == 0) + { + throw new ArgumentException("No execution provider devices were specified."); + } + + // Convert EpDevices to native pointers + IntPtr[] epDevicePtrs = new IntPtr[epDevices.Count]; + for (int i = 0; i < epDevices.Count; i++) + { + epDevicePtrs[i] = epDevices[i].Handle; + } + + if (epOptions != null && epOptions.Count > 0) + { + // this creates an OrtKeyValuePairs instance with a backing native instance + using var kvps = new OrtKeyValuePairs(epOptions); + + // get the native key/value handles so we can pass those straight through to the C API + // and not have to do any special marshaling here. + IntPtr epOptionsKeys, epOptionsValues; + UIntPtr epOptionsCount; + kvps.GetKeyValuePairHandles(out epOptionsKeys, out epOptionsValues, out epOptionsCount); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsAppendExecutionProvider_V2( + handle, + env.Handle, + epDevicePtrs, + (UIntPtr)epDevices.Count, + epOptionsKeys, + epOptionsValues, + epOptionsCount)); + } + else + { + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsAppendExecutionProvider_V2( + handle, + env.Handle, + epDevicePtrs, + (UIntPtr)epDevices.Count, + IntPtr.Zero, // EP options keys + IntPtr.Zero, // EP options values + UIntPtr.Zero)); // EP options count + } + + } + #endregion //ExecutionProviderAppends #region Public Methods @@ -452,8 +543,8 @@ public void RegisterCustomOpLibraryV2(string libraryPath, out IntPtr libraryHand // End result of that is // SessionOptions.RegisterCustomOpLibrary calls NativeMethods.OrtRegisterCustomOpsLibrary_V2 // SessionOptions.RegisterCustomOpLibraryV2 calls NativeMethods.OrtRegisterCustomOpsLibrary - var utf8Path = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath); - NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, utf8Path, + var 
platformPath = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath); + NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, platformPath, out libraryHandle)); } @@ -536,6 +627,18 @@ public void AddFreeDimensionOverrideByName(string dimName, long dimValue) var utf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimName); NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverrideByName(handle, utf8, dimValue)); } + + /// + /// Set the execution provider selection policy if using automatic execution provider selection. + /// Execution providers must be registered with the OrtEnv to be available for selection. + /// + /// Policy to use. + public void SetEpSelectionPolicy(ExecutionProviderDevicePolicy policy) + { + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsSetEpSelectionPolicy(handle, (int)policy, IntPtr.Zero)); + } + #endregion internal IntPtr Handle diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs new file mode 100644 index 0000000000000..72c165df56418 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// not supported on mobile platforms +#if !(ANDROID || IOS) + +namespace Microsoft.ML.OnnxRuntime.Tests; + +using System; +using System.Globalization; +using System.Runtime.InteropServices; +using Xunit; + + +public class CompileApiTests +{ + private OrtEnv ortEnvInstance = OrtEnv.Instance(); + + + [Fact] + public void BasicUsage() + { + var so = new SessionOptions(); + using (var compileOptions = new OrtModelCompilationOptions(so)) + { + // mainly checking these don't throw which ensures all the plumbing for the binding works. 
+ compileOptions.SetInputModelPath("model.onnx"); + compileOptions.SetOutputModelPath("compiled_model.onnx"); + + compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512); + compileOptions.SetEpContextEmbedMode(true); + + } + + // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer + using (var compileOptions = new OrtModelCompilationOptions(so)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + compileOptions.SetInputModelFromBuffer(model); + + // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile. + // Due to that we need to allocate an IntPtr and UIntPtr here. + IntPtr bytePtr = new IntPtr(); + UIntPtr bytesSize = new UIntPtr(); + var allocator = OrtAllocator.DefaultInstance; + compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize); + + compileOptions.CompileModel(); + + Assert.NotEqual(IntPtr.Zero, bytePtr); + Assert.NotEqual(UIntPtr.Zero, bytesSize); + + byte[] compiledBytes = new byte[bytesSize.ToUInt64()]; + Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32()); + + // Check the compiled model is valid + using (var session = new InferenceSession(compiledBytes, so)) + { + Assert.NotNull(session); + } + + allocator.FreeMemory(bytePtr); + } + } +} + +#endif \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs new file mode 100644 index 0000000000000..1aa4db15d275c --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// not supported on mobile platforms +#if !(ANDROID || IOS) + +namespace Microsoft.ML.OnnxRuntime.Tests; + +using System; +using System.Linq; +using System.IO; +using System.Runtime.InteropServices; +using Xunit; +using System.Collections.Generic; + +/// +/// Tests for auto ep selection/registration. +/// Includes testing of OrtHardwareDevice and OrtEpDevice as those only come from auto ep related code and we only +/// get read-only access to them (i.e. we can't directly create instances of them to test). +/// +public class OrtAutoEpTests +{ + private OrtEnv ortEnvInstance = OrtEnv.Instance(); + + private void ReadHardwareDeviceValues(OrtHardwareDevice device) + { + Assert.True(device.Type == OrtHardwareDeviceType.CPU || + device.Type == OrtHardwareDeviceType.GPU || + device.Type == OrtHardwareDeviceType.NPU); + if (device.Type == OrtHardwareDeviceType.CPU) + { + Assert.NotEmpty(device.Vendor); + } + else + { + Assert.True(device.VendorId != 0); + Assert.True(device.DeviceId != 0); + } + + var metadata = device.Metadata; + Assert.NotNull(metadata); + foreach (var kvp in metadata.Entries) + { + Assert.NotEmpty(kvp.Key); + // Assert.NotEmpty(kvp.Value); this is allowed + } + } + + [Fact] + public void GetEpDevices() + { + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotNull(epDevices); + Assert.NotEmpty(epDevices); + foreach (var ep_device in epDevices) + { + Assert.NotEmpty(ep_device.EpName); + Assert.NotEmpty(ep_device.EpVendor); + var metadata = ep_device.EpMetadata; + Assert.NotNull(metadata); + var options = ep_device.EpOptions; + Assert.NotNull(options); + ReadHardwareDeviceValues(ep_device.HardwareDevice); + } + } + + [Fact] + public void RegisterUnregisterLibrary() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), "example_plugin_ep.dll"); + Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist."); + + // example plugin ep uses the 
registration name as the ep name + const string epName = "csharp_ep"; + + // register. shouldn't throw + ortEnvInstance.RegisterExecutionProviderLibrary(epName, libFullPath); + + // check OrtEpDevice was found + var epDevices = ortEnvInstance.GetEpDevices(); + var found = epDevices.Any(d => string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)); + Assert.True(found); + + // unregister + ortEnvInstance.UnregisterExecutionProviderLibrary(epName); + } + } + + [Fact] + public void AppendToSessionOptionsV2() + { + var runTest = (Func> getEpOptions) => + { + SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + + // cpu ep ignores the provider options so we can use any value in epOptions and it won't break. + List selectedEpDevices = epDevices.Where(d => d.EpName == "CPUExecutionProvider").ToList(); + + Dictionary epOptions = getEpOptions(); + sessionOptions.AppendExecutionProvider(ortEnvInstance, selectedEpDevices, epOptions); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should load successfully + using (var session = new InferenceSession(model)) + { + Assert.NotNull(session); + } + }; + + runTest(() => + { + // null options + return null; + }); + + runTest(() => + { + // empty options + return new Dictionary(); + }); + + runTest(() => + { + // dummy options + return new Dictionary + { + { "random_key", "value" }, + }; + }); + } + + [Fact] + public void SetEpSelectionPolicy() + { + SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotEmpty(epDevices); + + // doesn't matter what the value is. 
should fallback to ORT CPU EP + sessionOptions.SetEpSelectionPolicy(ExecutionProviderDevicePolicy.PREFER_GPU); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should load successfully + using (var session = new InferenceSession(model)) + { + Assert.NotNull(session); + } + } +} +#endif \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtKeyValuePairsTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtKeyValuePairsTests.cs new file mode 100644 index 0000000000000..b89b970688d5f --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtKeyValuePairsTests.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using Xunit; + +namespace Microsoft.ML.OnnxRuntime.Tests; + +public class OrtKeyValuePairsTests +{ + private OrtEnv ortEnvInstance = OrtEnv.Instance(); + + + [Fact] + public void CRUD() + { + using var kvp = new OrtKeyValuePairs(); + kvp.Add("key1", "value1"); + kvp.Add("key2", "value2"); + kvp.Add("key3", ""); // allowed + + Assert.Equal("value1", kvp.Entries["key1"]); + Assert.Equal("value2", kvp.Entries["key2"]); + Assert.Equal("", kvp.Entries["key3"]); + + kvp.Remove("key1"); + Assert.False(kvp.Entries.ContainsKey("key1")); + + kvp.Remove("invalid_key"); // shouldn't break + + Assert.Equal(2, kvp.Entries.Count); + + // refresh from the C API to make sure everything is in sync + kvp.Refresh(); + Assert.Equal(2, kvp.Entries.Count); + Assert.Equal("value2", kvp.Entries["key2"]); + Assert.Equal("", kvp.Entries["key3"]); + } +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj index a8abcd2b4aa1c..ee3c8c69aa2ae 100644 --- 
a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj @@ -70,7 +70,8 @@ + $(NativeBuildOutputDir)\custom_op_library*.dll; + $(NativeBuildOutputDir)\example_plugin_ep.dll"> PreserveNewest false diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 7928f9b822cf0..7be518a39480f 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -562,8 +562,9 @@ ORT_API(void, GetKeyValuePairs, _In_ const OrtKeyValuePairs* kvps, ORT_API(void, RemoveKeyValuePair, _In_ OrtKeyValuePairs* kvps, _In_ const char* key); ORT_API(void, ReleaseKeyValuePairs, _Frees_ptr_opt_ OrtKeyValuePairs*); -ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* ep_name, const ORTCHAR_T* path); -ORT_API_STATUS_IMPL(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* ep_name); +ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name, + const ORTCHAR_T* path); +ORT_API_STATUS_IMPL(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name); ORT_API_STATUS_IMPL(GetEpDevices, _In_ const OrtEnv* env, _Outptr_ const OrtEpDevice* const** ep_devices, _Out_ size_t* num_ep_devices);