diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
index 2f1532d0643ae..3189b64349898 100644
--- a/cmake/external/onnxruntime_external_deps.cmake
+++ b/cmake/external/onnxruntime_external_deps.cmake
@@ -603,10 +603,6 @@ if(NOT (onnx_FOUND OR ONNX_FOUND)) # building ONNX from source
endif()
endif()
-if (onnxruntime_RUN_ONNX_TESTS)
- add_definitions(-DORT_RUN_EXTERNAL_ONNX_TESTS)
-endif()
-
if(onnxruntime_ENABLE_DLPACK)
message(STATUS "dlpack is enabled.")
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 24cecf07e8e36..07e61fb210036 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -108,6 +108,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/eltwise_kernel_neon.h
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
+ ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
)
set(mlas_platform_preprocess_srcs
@@ -429,12 +430,16 @@ else()
${MLAS_SRC_DIR}/softmax_kernel_neon.cpp
${MLAS_SRC_DIR}/eltwise_kernel_neon.h
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
+ ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
)
if (onnxruntime_USE_KLEIDIAI)
setup_kleidiai()
endif()
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
+ set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
+ PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm")
+
if (NOT APPLE)
set(mlas_platform_srcs
${mlas_platform_srcs}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
index c348184658e7e..bde39d9c6e6cc 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
@@ -4,6 +4,7 @@
namespace Microsoft.ML.OnnxRuntime
{
using System;
+ using System.Diagnostics;
using System.Runtime.InteropServices;
///
@@ -22,18 +23,19 @@ public enum OrtCompileApiFlags : uint
/// This class is used to set options for model compilation, and to produce a compiled model using those options.
/// See https://onnxruntime.ai/docs/api/c/ for further details of various options.
///
- public class OrtModelCompilationOptions : SafeHandle
+ public class OrtModelCompilationOptions : IDisposable
{
///
/// Create a new OrtModelCompilationOptions object from SessionOptions.
///
+ /// By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use SetGraphOptimizationLevel()
+ /// to enable graph optimizations.
/// SessionOptions instance to read settings from.
public OrtModelCompilationOptions(SessionOptions sessionOptions)
- : base(IntPtr.Zero, true)
{
NativeApiStatus.VerifySuccess(
NativeMethods.CompileApi.OrtCreateModelCompilationOptionsFromSessionOptions(
- OrtEnv.Instance().Handle, sessionOptions.Handle, out handle));
+ OrtEnv.Instance().Handle, sessionOptions.Handle, out _handle));
}
///
@@ -41,7 +43,7 @@ public OrtModelCompilationOptions(SessionOptions sessionOptions)
///
public void CompileModel()
{
- NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, handle));
+ NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, _handle));
}
@@ -53,7 +55,7 @@ public void SetInputModelPath(string path)
{
var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path);
NativeApiStatus.VerifySuccess(
- NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(handle, platformPath));
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(_handle, platformPath));
}
///
@@ -65,7 +67,7 @@ public void SetInputModelFromBuffer(byte[] buffer)
{
NativeApiStatus.VerifySuccess(
NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelFromBuffer(
- handle, buffer, (UIntPtr)buffer.Length));
+ _handle, buffer, (UIntPtr)buffer.Length));
}
///
@@ -76,7 +78,7 @@ public void SetOutputModelPath(string path)
{
var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path);
NativeApiStatus.VerifySuccess(
- NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(handle, platformPath));
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(_handle, platformPath));
}
@@ -91,7 +93,7 @@ public void SetOutputModelExternalInitializersFile(string filePath, ulong thresh
var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath);
NativeApiStatus.VerifySuccess(
NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelExternalInitializersFile(
- handle, platformPath, new UIntPtr(threshold)));
+ _handle, platformPath, new UIntPtr(threshold)));
}
// TODO: In order to use this to create an InferenceSession without copying bytes we need more infrastructure.
@@ -106,7 +108,7 @@ internal void SetOutputModelBuffer(OrtAllocator allocator,
{
NativeApiStatus.VerifySuccess(
NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelBuffer(
- handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr));
+ _handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr));
}
///
@@ -117,7 +119,7 @@ internal void SetOutputModelBuffer(OrtAllocator allocator,
public void SetEpContextEmbedMode(bool embed)
{
NativeApiStatus.VerifySuccess(
- NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(handle, embed));
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(_handle, embed));
}
///
@@ -127,26 +129,379 @@ public void SetEpContextEmbedMode(bool embed)
public void SetFlags(OrtCompileApiFlags flags)
{
NativeApiStatus.VerifySuccess(
- NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags));
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(_handle, (uint)flags));
}
- internal IntPtr Handle => handle;
+ ///
+ /// Sets information related to EP context binary file. The Ep uses this information to decide the
+ /// location and context binary file name when compiling with both the input and output models
+ /// stored in buffers.
+ ///
+ /// Path to the model directory.
+ /// The name of the model.
+ public void SetEpContextBinaryInformation(string outputDirectory, string modelName)
+ {
+ var platformOutputDirectory = NativeOnnxValueHelper.GetPlatformSerializedString(outputDirectory);
+ var platformModelName = NativeOnnxValueHelper.GetPlatformSerializedString(modelName);
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextBinaryInformation(
+ _handle, platformOutputDirectory, platformModelName));
+ }
+
+ ///
+ /// Sets the graph optimization level. Defaults to ORT_DISABLE_ALL if not specified.
+ ///
+ /// The graph optimization level to set.
+ public void SetGraphOptimizationLevel(GraphOptimizationLevel graphOptimizationLevel)
+ {
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetGraphOptimizationLevel(
+ _handle, graphOptimizationLevel));
+ }
+
+ ///
+ /// Delegate to write/save a buffer containing ONNX model bytes to a custom destination. The delegate
+ /// may be called repeatedly until the entire output model has been written out. Each call to the delegate
+ /// is expected to consume the entire buffer.
+ ///
+ /// The buffer to write out.
+ ///
+ public delegate void WriteBufferToDestinationDelegate(ReadOnlySpan<byte> buffer);
+
+ ///
+ /// Sets a delegate that is called by ORT to write out the output model's serialized ONNX bytes.
+ /// The provided delegate may be called repeatedly until the entire output model has been written out.
+ /// Each call to the delegate is expected to consume/handle the entire input buffer.
+ ///
+ /// The delegate called by ORT to write out the model.
+ public void SetOutputModelWriteDelegate(WriteBufferToDestinationDelegate writeBufferDelegate)
+ {
+ _writeBufferToDestinationDelegateState?.Dispose();
+ _writeBufferToDestinationDelegateState =
+ new DelegateResources<WriteBufferToDestinationConnector, NativeMethods.DOrtWriteBufferToDestinationDelegate>(
+ new WriteBufferToDestinationConnector(writeBufferDelegate),
+ new NativeMethods.DOrtWriteBufferToDestinationDelegate(
+ WriteBufferToDestinationConnector.WriteBufferToDestinationDelegateWrapper));
+
+ IntPtr funcPtr = _writeBufferToDestinationDelegateState.GetFunctionPointerForDelegate();
+
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelWriteFunc(
+ _handle,
+ funcPtr,
+ _writeBufferToDestinationDelegateState.GetConnectorHandleAsPointer()));
+ }
+
+ ///
+ /// Delegate called by ORT for every initializer when generating the compiled model.
+ /// The delegate allows the user to determine whether the initializer should be stored within the compiled
+ /// model or externally in a file. If the delegate chooses to store an initializer externally, the delegate
+ /// implementation is responsible for writing the initializer data to a file.
+ ///
+ /// The initializer's name.
+ /// The readonly OrtValue instance containing the data, type, and
+ /// shape of the initializer.
+ /// May be null. If the initializer is originally stored externally,
+ /// this contains the file path, file offset, and data size. Otherwise, this is null.
+ /// A new OrtExternalInitializerInfo indicating the new location of the initializer.
+ /// Returns null if the initializer should be stored within the generated compiled model.
+ /// The return value may be null.
+ ///
+ public delegate OrtExternalInitializerInfo GetInitializerLocationDelegate(
+ string initializerName,
+ IReadOnlyOrtValue initializerValue,
+ IReadOnlyExternalInitializerInfo originalInitializerLocation);
+
+ ///
+ /// Sets a delegate that is called by ORT for every initializer when generating the compiled model.
+ /// The delegate allows the user to determine whether the initializer should be stored within the compiled
+ /// model or externally in a file. If the delegate chooses to store an initializer externally, the delegate
+ /// implementation is responsible for writing the initializer data to a file.
+ ///
+ /// The delegate called by ORT for every initializer.
+ public void SetOutputModelGetInitializerLocationDelegate(
+ GetInitializerLocationDelegate getInitializerLocationDelegate)
+ {
+ _getInitializerLocationDelegateState?.Dispose();
+ _getInitializerLocationDelegateState =
+ new DelegateResources<GetInitializerLocationConnector, NativeMethods.DOrtGetInitializerLocationDelegate>(
+ new GetInitializerLocationConnector(getInitializerLocationDelegate),
+ new NativeMethods.DOrtGetInitializerLocationDelegate(
+ GetInitializerLocationConnector.GetInitializerLocationDelegateWrapper));
+
+ IntPtr funcPtr = _getInitializerLocationDelegateState.GetFunctionPointerForDelegate();
+
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc(
+ _handle,
+ funcPtr,
+ _getInitializerLocationDelegateState.GetConnectorHandleAsPointer()));
+ }
+
+ #region Delegate helpers
+ ///
+ /// Class to bridge the C# and native worlds for the "write buffer to destination" delegate
+ ///
+ private class WriteBufferToDestinationConnector
+ {
+ private readonly WriteBufferToDestinationDelegate _userDelegate;
+
+ internal WriteBufferToDestinationConnector(WriteBufferToDestinationDelegate writeBufferDelegate)
+ {
+ _userDelegate = writeBufferDelegate;
+ }
+
+ public static IntPtr WriteBufferToDestinationDelegateWrapper(IntPtr /* void* */ state,
+ IntPtr /* const void* */ buffer,
+ UIntPtr /* size_t */ bufferNumBytes)
+ {
+ try
+ {
+
+ WriteBufferToDestinationConnector connector = (WriteBufferToDestinationConnector)
+ GCHandle.FromIntPtr(state).Target;
+ ReadOnlySpan<byte> bufferSpan;
+
+ unsafe
+ {
+ // NOTE: A Span can only view 2GB of data. This is fine because ORT does not write out
+ // chunks that large. However, if we ever need to, the solution is to just write a loop here
+ // that repeatedly calls the delegate with smaller chunks of data.
+ bufferSpan = new ReadOnlySpan<byte>(buffer.ToPointer(), checked((int)bufferNumBytes));
+ }
+
+ connector._userDelegate(bufferSpan);
+ }
+ catch (Exception ex)
+ {
+ var error = $"The C# WriteBufferToDestination delegate threw an exception: {ex.Message}";
+ IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail,
+ NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error));
+ return status;
+ }
+
+ return IntPtr.Zero;
+ }
+ }
+
+ ///
+ /// Class to bridge the C# and native worlds for the "get initializer location" delegate
+ ///
+ private class GetInitializerLocationConnector
+ {
+ private readonly GetInitializerLocationDelegate _userDelegate;
+
+ internal GetInitializerLocationConnector(GetInitializerLocationDelegate getInitializerLocationDelegate)
+ {
+ _userDelegate = getInitializerLocationDelegate;
+ }
+
+ public static IntPtr GetInitializerLocationDelegateWrapper(
+ IntPtr /* void* */ state,
+ IntPtr /* const char* */ initializerName,
+ IntPtr /* const OrtValue* */ initializerValue,
+ IntPtr /* const OrtExternalInitializerInfo* */ originalInitializerLocation,
+ out IntPtr /* OrtExternalInitializerInfo** */ newInitializerLocationOutput)
+ {
+ newInitializerLocationOutput = IntPtr.Zero;
+
+ try
+ {
+
+ GetInitializerLocationConnector connector = (GetInitializerLocationConnector)GCHandle.
+ FromIntPtr(state).Target;
+ string utf8InitializerName = NativeOnnxValueHelper.StringFromNativeUtf8(initializerName);
+ IReadOnlyOrtValue readOnlyInitializerValue = new OrtValue(initializerValue, owned: false);
+ IReadOnlyExternalInitializerInfo readOnlyOriginalInitializerLocation = null;
+
+ if (originalInitializerLocation != IntPtr.Zero)
+ {
+ readOnlyOriginalInitializerLocation = new OrtExternalInitializerInfo(
+ originalInitializerLocation, ownsHandle: false);
+ }
+ // Call user's delegate, which may return the new location of the initializer.
+ OrtExternalInitializerInfo newInitializerLocation = connector._userDelegate(
+ utf8InitializerName, readOnlyInitializerValue, readOnlyOriginalInitializerLocation);
+
+ if (newInitializerLocation != null)
+ {
+ // Delegate returned info about a new location for the initializer.
+ // Can't guarantee that the new external info returned by user's delegate is not referenced
+ // by other C# code. ORT expects to own the new external info, so create a copy here and
+ // give it to ORT.
+ string newFilePath = newInitializerLocation.GetFilePath();
+ byte[] newFilePathBytes = NativeOnnxValueHelper.GetPlatformSerializedString(newFilePath);
+
+ IntPtr status = NativeMethods.OrtCreateExternalInitializerInfo(
+ newFilePathBytes,
+ newInitializerLocation.GetFileOffset(),
+ (UIntPtr)newInitializerLocation.GetByteSize(),
+ out newInitializerLocationOutput);
+
+ if (status != IntPtr.Zero)
+ {
+ return status;
+ }
+ }
+ else
+ {
+ // User's delegate did not return a new location for the initializer. ORT will store initializer
+ // within the generated compiled model.
+ newInitializerLocationOutput = IntPtr.Zero;
+ }
+ }
+ catch (Exception ex)
+ {
+ var error = $"The C# GetInitializerLocation delegate threw an exception: {ex.Message}";
+ IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail,
+ NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error));
+ return status;
+ }
+
+ return IntPtr.Zero;
+ }
+ }
///
- /// Indicates whether the native handle is invalid.
+ /// Disposable class that stores resources for a delegate provided by the user.
///
- public override bool IsInvalid => handle == IntPtr.Zero;
+ /// The type of the connector class
+ /// (e.g., WriteBufferToDestinationConnector)
+ /// The type of the native delegate.
+ private class DelegateResources<Connector, Delegate> : IDisposable
+ where Connector : class
+ where Delegate : class
+ {
+ public DelegateResources(Connector connector, Delegate @delegate)
+ {
+ _connector = connector;
+ _delegate = @delegate;
+ _connectorHandle = GCHandle.Alloc(_connector);
+ _delegateHandle = GCHandle.Alloc(_delegate);
+ }
+ internal IntPtr GetFunctionPointerForDelegate()
+ {
+ return Marshal.GetFunctionPointerForDelegate(_delegate);
+ }
+
+ internal IntPtr GetConnectorHandleAsPointer()
+ {
+ return GCHandle.ToIntPtr(_connectorHandle);
+ }
+
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ protected virtual void Dispose(bool disposing)
+ {
+ if (_disposed)
+ {
+ return;
+ }
+
+ if (disposing)
+ {
+ // Dispose other children disposables. We have none.
+ }
+
+ if (_connectorHandle.IsAllocated)
+ {
+ _connectorHandle.Free();
+ _connector = null;
+ }
+
+ if (_delegateHandle.IsAllocated)
+ {
+ _delegateHandle.Free();
+ _delegate = null;
+ }
+
+ _disposed = true;
+ }
+
+ ~DelegateResources()
+ {
+ Dispose(false);
+ }
+
+ private Connector _connector = null;
+ private Delegate _delegate = null;
+ private GCHandle _connectorHandle = default;
+ private GCHandle _delegateHandle = default;
+ private bool _disposed = false;
+ }
+ #endregion
+
+ #region IDispose implementation
///
- /// Release the native instance of OrtModelCompilationOptions.
+ /// IDispose implementation.
///
- /// true
- protected override bool ReleaseHandle()
+ public void Dispose()
{
- NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(handle);
- handle = IntPtr.Zero;
- return true;
+ Dispose(true);
+ GC.SuppressFinalize(this);
}
+
+ ///
+ /// IDispose implementation
+ ///
+ /// True if Dispose() has been called by the user-side code. False if
+ /// called by the runtime from inside the finalizer.
+ protected virtual void Dispose(bool disposing)
+ {
+ if (_disposed)
+ {
+ return;
+ }
+
+ if (disposing)
+ {
+ _writeBufferToDestinationDelegateState?.Dispose();
+ _getInitializerLocationDelegateState?.Dispose();
+ }
+
+ Debug.Assert(_handle != IntPtr.Zero);
+ NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(_handle);
+ _handle = IntPtr.Zero;
+ _disposed = true;
+ }
+
+ ///
+ /// Finalizer that releases the native handle if not already released by Dispose().
+ ///
+ ~OrtModelCompilationOptions()
+ {
+ Dispose(false);
+ }
+ #endregion
+
+ ///
+ /// Handle to the native OrtModelCompilationOptions object.
+ ///
+ private IntPtr _handle;
+
+ ///
+ /// True if this OrtModelCompilationOptions instance has already been disposed.
+ ///
+ private bool _disposed = false;
+
+ ///
+ /// Stores delegate state for the "write buffer to destination" delegate.
+ ///
+ private DelegateResources<WriteBufferToDestinationConnector, NativeMethods.DOrtWriteBufferToDestinationDelegate>
+ _writeBufferToDestinationDelegateState = null;
+
+ ///
+ /// Stores delegate state for the "get initializer location" delegate.
+ ///
+ private DelegateResources<GetInitializerLocationConnector, NativeMethods.DOrtGetInitializerLocationDelegate>
+ _getInitializerLocationDelegateState = null;
}
-}
\ No newline at end of file
+}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
index 3edc25b307a21..84020d84c9e73 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
@@ -21,6 +21,10 @@ public struct OrtCompileApi
public IntPtr ModelCompilationOptions_SetEpContextEmbedMode;
public IntPtr CompileModel;
public IntPtr ModelCompilationOptions_SetFlags;
+ public IntPtr ModelCompilationOptions_SetEpContextBinaryInformation;
+ public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel;
+ public IntPtr ModelCompilationOptions_SetOutputModelWriteFunc;
+ public IntPtr ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc;
}
internal class NativeMethods
@@ -101,6 +105,37 @@ public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile
uint flags);
public DOrtModelCompilationOptions_SetFlags OrtModelCompilationOptions_SetFlags;
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextBinaryInformation(
+ IntPtr /* OrtModelCompilationOptions* */ options,
+ byte[] /* const ORTCHAR_T* */ outputDirectory,
+ byte[] /* const ORTCHAR_T* */ modelName);
+ public DOrtModelCompilationOptions_SetEpContextBinaryInformation
+ OrtModelCompilationOptions_SetEpContextBinaryInformation;
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetGraphOptimizationLevel(
+ IntPtr /* OrtModelCompilationOptions* */ options,
+ GraphOptimizationLevel graphOptimizationLevel);
+ public DOrtModelCompilationOptions_SetGraphOptimizationLevel
+ OrtModelCompilationOptions_SetGraphOptimizationLevel;
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelWriteFunc(
+ IntPtr /* OrtModelCompilationOptions* */ options,
+ IntPtr /* DOrtWriteBufferDelegate */ writeFunc,
+ IntPtr /* void* */ state);
+ public DOrtModelCompilationOptions_SetOutputModelWriteFunc
+ OrtModelCompilationOptions_SetOutputModelWriteFunc;
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc(
+ IntPtr /* OrtModelCompilationOptions* */ options,
+ IntPtr /* DOrtHandleInitializerDataDelegate */ handleInitializerFunc,
+ IntPtr /* void* */ state);
+ public DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc
+ OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc;
+
internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi)
{
@@ -161,6 +196,27 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi
_compileApi.ModelCompilationOptions_SetFlags,
typeof(DOrtModelCompilationOptions_SetFlags));
+ OrtModelCompilationOptions_SetEpContextBinaryInformation =
+ (DOrtModelCompilationOptions_SetEpContextBinaryInformation)Marshal.GetDelegateForFunctionPointer(
+ _compileApi.ModelCompilationOptions_SetEpContextBinaryInformation,
+ typeof(DOrtModelCompilationOptions_SetEpContextBinaryInformation));
+
+ OrtModelCompilationOptions_SetGraphOptimizationLevel =
+ (DOrtModelCompilationOptions_SetGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer(
+ _compileApi.ModelCompilationOptions_SetGraphOptimizationLevel,
+ typeof(DOrtModelCompilationOptions_SetGraphOptimizationLevel));
+
+ OrtModelCompilationOptions_SetOutputModelWriteFunc =
+ (DOrtModelCompilationOptions_SetOutputModelWriteFunc)Marshal.GetDelegateForFunctionPointer(
+ _compileApi.ModelCompilationOptions_SetOutputModelWriteFunc,
+ typeof(DOrtModelCompilationOptions_SetOutputModelWriteFunc));
+
+ OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc =
+ (DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc)Marshal.
+ GetDelegateForFunctionPointer(
+ _compileApi.ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
+ typeof(DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc));
+
}
}
}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 3c92400715740..53880308da261 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -450,6 +450,7 @@ public struct OrtApi
public IntPtr Graph_GetModelMetadata;
public IntPtr GetModelCompatibilityForEpDevices;
+ public IntPtr CreateExternalInitializerInfo;
}
internal static class NativeMethods
@@ -787,9 +788,35 @@ static NativeMethods()
api_.SessionOptionsSetEpSelectionPolicyDelegate,
typeof(DSessionOptionsSetEpSelectionPolicyDelegate));
+ OrtReleaseExternalInitializerInfo =
+ (DOrtReleaseExternalInitializerInfo)Marshal.GetDelegateForFunctionPointer(
+ api_.ReleaseExternalInitializerInfo,
+ typeof(DOrtReleaseExternalInitializerInfo));
+
+ OrtExternalInitializerInfo_GetFilePath =
+ (DOrtExternalInitializerInfo_GetFilePath)Marshal.GetDelegateForFunctionPointer(
+ api_.ExternalInitializerInfo_GetFilePath,
+ typeof(DOrtExternalInitializerInfo_GetFilePath));
+
+ OrtExternalInitializerInfo_GetFileOffset =
+ (DOrtExternalInitializerInfo_GetFileOffset)Marshal.GetDelegateForFunctionPointer(
+ api_.ExternalInitializerInfo_GetFileOffset,
+ typeof(DOrtExternalInitializerInfo_GetFileOffset));
+
+ OrtExternalInitializerInfo_GetByteSize =
+ (DOrtExternalInitializerInfo_GetByteSize)Marshal.GetDelegateForFunctionPointer(
+ api_.ExternalInitializerInfo_GetByteSize,
+ typeof(DOrtExternalInitializerInfo_GetByteSize));
+
OrtGetModelCompatibilityForEpDevices = (DOrtGetModelCompatibilityForEpDevices)Marshal.GetDelegateForFunctionPointer(
api_.GetModelCompatibilityForEpDevices,
typeof(DOrtGetModelCompatibilityForEpDevices));
+
+ OrtCreateExternalInitializerInfo =
+ (DOrtCreateExternalInitializerInfo)Marshal.GetDelegateForFunctionPointer(
+ api_.CreateExternalInitializerInfo,
+ typeof(DOrtCreateExternalInitializerInfo));
+
}
internal class NativeLib
@@ -2382,6 +2409,70 @@ out IntPtr lora_adapter
public delegate ref CompileApi.OrtCompileApi DOrtGetCompileApi();
#endif
public static DOrtGetCompileApi OrtGetCompileApi;
+
+ ///
+ /// Delegate called by ORT to write a buffer (ONNX model bytes) to a custom destination (e.g., file or stream).
+ ///
+ /// State that was provided in when the delegate was registered.
+ /// The buffer to write.
+ /// The size of the buffer in bytes.
+ /// OrtStatus*
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtWriteBufferToDestinationDelegate(
+ IntPtr /* void* */ state,
+ IntPtr /* const void* */ buffer,
+ UIntPtr /* size_t */ bufferNumBytes
+ );
+
+ ///
+ /// Function called by ORT to allow user to specify how an initializer should be saved while compiling
+ /// a model, that is, either written to an external file or stored within the model. ORT calls this function
+ /// for every initializer.
+ ///
+ /// State that was provided when the delegate was registered.
+ /// The initializer's name.
+ /// The OrtValue containing the initializer's data, type, and shape
+ /// The original initializer's location in an external file, or NULL.
+ /// Output parameter set to a new OrtExternalInitializerInfo instance
+ /// indicating the location where the function implementation stored the initializer data. If the function
+ /// implementation sets `newExternalInfo` to NULL, ORT stores the initializer within the generated model.
+ ///
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtGetInitializerLocationDelegate(
+ IntPtr /* void* */ state,
+ IntPtr /* const char* */ initializerName,
+ IntPtr /* const OrtValue* */ initializerValue,
+ IntPtr /* const OrtExternalInitializerInfo* */ externalInfo,
+ out IntPtr /* OrtExternalInitializerInfo** */ newExternalInfo
+ );
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate void DOrtReleaseExternalInitializerInfo(IntPtr /* OrtExternalInitializerInfo* */ info);
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtCreateExternalInitializerInfo(
+ byte[] /* const ORTCHAR_T* */ filePath,
+ long /* int64_t */ fileOffset,
+ UIntPtr /* size_t */ byteSize,
+ out IntPtr /* OrtExternalInitializerInfo** */ outInfo);
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* const ORTCHAR_T* */ DOrtExternalInitializerInfo_GetFilePath(
+ IntPtr /* const OrtExternalInitializerInfo* */ info);
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate long /* int64_t */ DOrtExternalInitializerInfo_GetFileOffset(
+ IntPtr /* const OrtExternalInitializerInfo* */ info);
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate UIntPtr /* size_t */ DOrtExternalInitializerInfo_GetByteSize(
+ IntPtr /* const OrtExternalInitializerInfo* */ info);
+
+ public static DOrtReleaseExternalInitializerInfo OrtReleaseExternalInitializerInfo;
+ public static DOrtCreateExternalInitializerInfo OrtCreateExternalInitializerInfo;
+ public static DOrtExternalInitializerInfo_GetFilePath OrtExternalInitializerInfo_GetFilePath;
+ public static DOrtExternalInitializerInfo_GetFileOffset OrtExternalInitializerInfo_GetFileOffset;
+ public static DOrtExternalInitializerInfo_GetByteSize OrtExternalInitializerInfo_GetByteSize;
#endregion
#region Auto EP API related
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
index fc14be00ee47b..4611428ea12ef 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs
@@ -150,6 +150,45 @@ internal static byte[] GetPlatformSerializedString(string str)
else
return StringToZeroTerminatedUtf8(str);
}
+
+ ///
+ /// Converts a null-terminated path string that is pointed to by the given IntPtr handle into
+ /// a C# UTF-16 string.
+ ///
+ /// A path string on Windows is utf-16, but utf-8 on other operating systems.
+ ///
+ ///
+ internal static string StringFromNativePathString(IntPtr strPtr)
+ {
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ {
+ if (strPtr == IntPtr.Zero)
+ {
+ return string.Empty;
+ }
+
+ // Get length of utf16 string by checking for two 0 bytes in a row.
+ int length = 0;
+ while (Marshal.ReadInt16(strPtr, length * 2) != 0)
+ {
+ length += 1;
+ }
+
+ if (length == 0)
+ {
+ return string.Empty;
+ }
+
+ unsafe
+ {
+ return System.Text.Encoding.Unicode.GetString((byte*)strPtr, length * 2);
+ }
+ }
+ else
+ {
+ return StringFromNativeUtf8(strPtr);
+ }
+ }
}
// Guards an array of disposable objects on stack and disposes them in reverse order
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs
new file mode 100644
index 0000000000000..aca16e939ce21
--- /dev/null
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs
@@ -0,0 +1,136 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+
+namespace Microsoft.ML.OnnxRuntime
+{
+ using System;
+ using System.Diagnostics;
+ using System.Runtime.InteropServices;
+
+ ///
+ /// Class that stores information about the file location where an "external" initializer is stored.
+ ///
+ ///
+ public class OrtExternalInitializerInfo : SafeHandle, IReadOnlyExternalInitializerInfo
+ {
+ // Set to false when constructed with an externally managed constant handle owned by ORT.
+ private readonly bool _ownsHandle = true;
+
+ ///
+ /// Create a new OrtExternalInitializerInfo instance.
+ ///
+ /// The path to the file that stores the initializer data.
+ /// The byte offset in the file where the data is stored.
+ /// The size of the data (in bytes) within the file.
+ public OrtExternalInitializerInfo(string filePath, long fileOffset, long byteSize)
+ : base(IntPtr.Zero, ownsHandle: true)
+ {
+ var platformFilePath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath);
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.OrtCreateExternalInitializerInfo(platformFilePath, fileOffset, (UIntPtr)byteSize, out handle));
+ _ownsHandle = true;
+ }
+
+ ///
+ /// Create a new OrtExternalInitializerInfo instance from an existing native OrtExternalInitializerInfo handle.
+ ///
+ /// Native OrtExternalInitializerInfo handle.
+ /// True if the OrtExternalInitializerInfo instance owns the native handle.
+ /// Defaults to false.
+ internal OrtExternalInitializerInfo(IntPtr constHandle, bool ownsHandle = false)
+ : base(IntPtr.Zero, ownsHandle)
+ {
+ Debug.Assert(constHandle != IntPtr.Zero);
+ SetHandle(constHandle);
+ _ownsHandle = ownsHandle;
+ }
+
+ ///
+ /// Get the file path to the file that stores the initializer's data.
+ ///
+ ///
+ /// The path is relative to the filesystem directory where the ONNX model was stored.
+ ///
+ /// The file path.
+ public string GetFilePath()
+ {
+ IntPtr filePathPtr = NativeMethods.OrtExternalInitializerInfo_GetFilePath(handle);
+ if (filePathPtr == IntPtr.Zero)
+ {
+ return string.Empty;
+ }
+
+ return NativeOnnxValueHelper.StringFromNativePathString(filePathPtr);
+ }
+
+ ///
+ /// Get the byte offset within the file where the initializer's data is stored.
+ ///
+ /// The file offset location.
+ public long GetFileOffset()
+ {
+ return NativeMethods.OrtExternalInitializerInfo_GetFileOffset(handle);
+ }
+
+ ///
+ /// Get the size in bytes of the initializer's data within the file.
+ ///
+ /// The size in bytes of the initializer data.
+ public long GetByteSize()
+ {
+ UIntPtr byteSize = NativeMethods.OrtExternalInitializerInfo_GetByteSize(handle);
+ return checked((long)byteSize);
+ }
+
+ ///
+ /// Indicates whether the native handle is invalid.
+ ///
+ public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
+
+ ///
+ /// Release the native instance of OrtExternalInitializerInfo if we own it.
+ ///
+ /// true on success and false on error.
+ protected override bool ReleaseHandle()
+ {
+ if (!_ownsHandle)
+ {
+ // Return false to indicate an error.
+ // ReleaseHandle() should not be called on a const handle that this class does not own.
+ return false;
+ }
+
+ NativeMethods.OrtReleaseExternalInitializerInfo(handle);
+ handle = IntPtr.Zero;
+ return true;
+ }
+ }
+
+ ///
+ /// Interface for all readonly methods implemented by OrtExternalInitializerInfo.
+ ///
+ public interface IReadOnlyExternalInitializerInfo
+ {
+ ///
+ /// Get the file path to the file that stores the initializer's data.
+ ///
+ ///
+ /// The path is relative to the filesystem directory where the ONNX model was stored.
+ ///
+ /// The file path.
+ string GetFilePath();
+
+ ///
+ /// Get the byte offset within the file where the initializer's data is stored.
+ ///
+ /// The file offset location.
+ long GetFileOffset();
+
+ ///
+ /// Get the size in bytes of the initializer's data within the file.
+ ///
+ /// The size in bytes of the initializer data.
+ long GetByteSize();
+ }
+}
\ No newline at end of file
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
index 01ee3aa5ae753..d848c63450ec1 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
@@ -33,6 +33,147 @@ public enum OnnxValueType
ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKNOWN)
}
+ ///
+ /// Interface for all readonly methods implemented by OrtValue.
+ ///
+ public interface IReadOnlyOrtValue
+ {
+ ///
+ /// Get the ONNX value type for the OrtValue (e.g., OnnxValueType.ONNX_TYPE_TENSOR).
+ ///
+ /// OnnxValueType
+ OnnxValueType OnnxType { get; }
+
+ ///
+ /// Returns true if OrtValue contains a tensor
+ ///
+ /// true if tensor
+ bool IsTensor { get; }
+
+ ///
+ /// Returns true if OrtValue contains a sparse tensor
+ ///
+ /// true if sparse tensor
+ bool IsSparseTensor { get; }
+
+ ///
+ /// Returns type information about the contained OnnxValue.
+ ///
+ /// a disposable instance of OrtTypeInfo
+ OrtTypeInfo GetTypeInfo();
+
+ ///
+ /// Obtains Tensor And Type Information from the OrtValue iff it contains a tensor.
+ /// Valid only for OrtValues that contain a tensor.
+ ///
+ /// A disposable instance of OrtTensorTypeAndShapeInfo
+ OrtTensorTypeAndShapeInfo GetTensorTypeAndShape();
+
+ ///
+ /// Returns the size of the tensor data in bytes.
+ ///
+ /// size of the tensor data in bytes
+ long GetTensorSizeInBytes();
+
+ ///
+ /// Returns OrtMemoryInfo iff this OrtValue contains a tensor or a sparse tensor.
+ ///
+ /// OrtMemoryInfo that describes the underlying memory allocation
+ ///
+ OrtMemoryInfo GetTensorMemoryInfo();
+
+ ///
+ /// Returns a ReadOnlySpan over tensor native buffer that
+ /// provides a read-only view.
+ ///
+ /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
+ /// To get memory descriptor use GetTensorMemoryInfo().
+ ///
+ /// OrtValue must contain a non-string tensor.
+ /// The span is valid as long as the OrtValue instance is alive (not disposed).
+ ///
+ ///
+ /// ReadOnlySpan
+ ///
+ ReadOnlySpan GetTensorDataAsSpan() where T : unmanaged;
+
+#if NET8_0_OR_GREATER
+ ///
+ /// Returns a ReadOnlyTensorSpan over tensor native buffer that
+ /// provides a read-only view.
+ ///
+ /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
+ /// To get memory descriptor use GetTensorMemoryInfo().
+ ///
+ /// OrtValue must contain a non-string tensor.
+ /// The span is valid as long as the OrtValue instance is alive (not disposed).
+ ///
+ ///
+ /// ReadOnlySpan
+ ///
+ [Experimental("SYSLIB5001")]
+ SystemNumericsTensors.ReadOnlyTensorSpan GetTensorDataAsTensorSpan() where T : unmanaged;
+#endif
+
+ ///
+ /// Valid for composite ML types like map, sequence.
+ /// Returns 2 for map (keys, values) and N for sequence, where N is the number of elements
+ /// in the sequence.
+ ///
+ /// Element count
+ int GetValueCount();
+
+ ///
+ /// For non-tensor values, return the OrtValue element at the specified index.
+ /// For maps only indices 0 and 1 are valid. For sequences, [0..N) are valid.
+ /// See GetValueCount() to determine the valid range.
+ ///
+ ///
+ /// allocator to use
+ /// OrtValue disposable instance that points to the corresponding element of the composite type
+ OrtValue GetValue(int index, OrtAllocator allocator);
+
+ ///
+ /// Fetch string tensor element buffer pointer at the specified index,
+ /// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance.
+ ///
+ /// Obtain TensorTypeAndShape to get shape and element count.
+ ///
+ /// flat string tensor element index
+ /// ReadOnlyMemory{char} backed by a managed char[]. Its lifespan is not
+ /// tied to the native buffer of OrtValue.
+ ReadOnlyMemory GetStringElementAsMemory(int index);
+
+ ///
+ /// Fetch string tensor element buffer pointer at the specified index,
+ /// copy/convert UTF-8 into a UTF-16 string and return it.
+ ///
+ /// Obtain TensorTypeAndShape to get shape and element count.
+ ///
+ /// flat string tensor element index
+ /// UTF-16 string instance
+ string GetStringElement(int index);
+
+ ///
+ /// Get a span over the native memory of the string tensor element.
+ /// The span is valid as long as the OrtValue is valid.
+ ///
+ /// This is useful if you want to perform your own UTF-8 decoding or
+ /// you do not care about decoding.
+ /// Obtain TensorTypeAndShape to get shape and element count.
+ ///
+ /// flat element index
+ /// ReadOnlySpan over UTF-8 bytes of the string tensor element
+ ReadOnlySpan GetStringElementAsSpan(int index);
+
+ ///
+ /// Convenience method to obtain all string tensor elements as a string array.
+ ///
+ /// string[]
+ ///
+ string[] GetStringTensorAsArray();
+ }
+
///
/// Represents a disposable OrtValue.
/// This class exposes a native instance of OrtValue.
@@ -44,7 +185,7 @@ public enum OnnxValueType
/// disposed properly, the pinned memory will continue to be pinned and interfere
/// with GC operation.
///
- public class OrtValue : IOrtValueOwner, IDisposable
+ public class OrtValue : IOrtValueOwner, IDisposable, IReadOnlyOrtValue
{
// OrtValues that are members of Sequences or Maps that map. They potentially map managed memory and we need to keep them around.
// this exists only when we deal with compose ML types.
@@ -52,11 +193,20 @@ public class OrtValue : IOrtValueOwner, IDisposable
private IntPtr _handle;
private MemoryHandle? _memHandle; // Present when the OrtValue is created on top of managed memory
private bool _disposed;
+ private bool _owned = true;
- internal OrtValue(IntPtr handle)
+ ///
+ /// Constructs OrtValue from a native handle. If `owned` is true, the OrtValue instance takes
+ /// ownership of the native handle and disposes it when the OrtValue instance is disposed.
+ ///
+ /// The native OrtValue handle.
+ /// True if this class instance owns the handle. If false, the handle
+ /// will not be released. Defaults to true.
+ internal OrtValue(IntPtr handle, bool owned = true)
{
_handle = handle;
InitOnnxType();
+ _owned = owned;
}
///
@@ -1464,7 +1614,10 @@ protected virtual void Dispose(bool disposing)
}
Debug.Assert(_handle != IntPtr.Zero);
- NativeMethods.OrtReleaseValue(_handle);
+ if (_owned)
+ {
+ NativeMethods.OrtReleaseValue(_handle);
+ }
_handle = IntPtr.Zero;
_disposed = true;
}
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
index bf576b54d8b45..fe2cab57658c8 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
@@ -21,102 +21,249 @@ public class CompileApiTests
[Fact]
public void BasicUsage()
{
- var so = new SessionOptions();
- using (var compileOptions = new OrtModelCompilationOptions(so))
+ using (var sessionOptions = new SessionOptions())
{
- // mainly checking these don't throw which ensures all the plumbing for the binding works.
- compileOptions.SetInputModelPath("model.onnx");
- compileOptions.SetOutputModelPath("compiled_model.onnx");
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
+ {
+ // mainly checking these don't throw which ensures all the plumbing for the binding works.
+ compileOptions.SetInputModelPath("model.onnx");
+ compileOptions.SetOutputModelPath("compiled_model.onnx");
- compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512);
- compileOptions.SetEpContextEmbedMode(true);
+ compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512);
+ compileOptions.SetEpContextEmbedMode(true);
+ compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC);
- }
+ }
- // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer
- using (var compileOptions = new OrtModelCompilationOptions(so))
- {
- var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
- compileOptions.SetInputModelFromBuffer(model);
+ // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+ compileOptions.SetInputModelFromBuffer(model);
- // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile.
- // Due to that we need to allocate an IntPtr and UIntPtr here.
- IntPtr bytePtr = new IntPtr();
- UIntPtr bytesSize = new UIntPtr();
- var allocator = OrtAllocator.DefaultInstance;
- compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize);
+ // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile.
+ // Due to that we need to allocate an IntPtr and UIntPtr here.
+ IntPtr bytePtr = new IntPtr();
+ UIntPtr bytesSize = new UIntPtr();
+ var allocator = OrtAllocator.DefaultInstance;
+ compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize);
+ compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx");
- compileOptions.CompileModel();
+ compileOptions.CompileModel();
- Assert.NotEqual(IntPtr.Zero, bytePtr);
- Assert.NotEqual(UIntPtr.Zero, bytesSize);
+ Assert.NotEqual(IntPtr.Zero, bytePtr);
+ Assert.NotEqual(UIntPtr.Zero, bytesSize);
- byte[] compiledBytes = new byte[bytesSize.ToUInt64()];
- Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32());
+ byte[] compiledBytes = new byte[bytesSize.ToUInt64()];
+ Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32());
- // Check the compiled model is valid
- using (var session = new InferenceSession(compiledBytes, so))
- {
- Assert.NotNull(session);
+ // Check the compiled model is valid
+ using (var session = new InferenceSession(compiledBytes, sessionOptions))
+ {
+ Assert.NotNull(session);
+ }
+
+ allocator.FreeMemory(bytePtr);
}
- allocator.FreeMemory(bytePtr);
- }
+ // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate
+ // any compiled EPContext nodes, so expect an ORT_FAIL error.
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+ var output_model_file = "should_not_generate.onnx";
+ compileOptions.SetInputModelFromBuffer(model);
+ compileOptions.SetOutputModelPath(output_model_file);
+ compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED);
- // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate
- // any compiled EPContext nodes, so expect an ORT_FAIL error.
- using (var compileOptions = new OrtModelCompilationOptions(so))
- {
- var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
- var output_model_file = "should_not_generate.onnx";
- compileOptions.SetInputModelFromBuffer(model);
- compileOptions.SetOutputModelPath(output_model_file);
- compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED);
+ // compile should fail
+ try
+ {
+ compileOptions.CompileModel();
+ Assert.Fail("CompileModel() should have thrown an exception");
+ }
+ catch (OnnxRuntimeException ex)
+ {
+ Assert.Contains("Unable to compile any nodes", ex.Message);
+ }
- // compile should fail
+ Assert.False(File.Exists(output_model_file)); // Output file should not be generated.
+ }
+
+ // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS.
+ var outputModelFile = "squeezenet_ctx.onnx";
try
{
- compileOptions.CompileModel();
- Assert.Fail("CompileModel() should have thrown an exception");
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+
+ // Compile and generate an output model.
+ compileOptions.SetInputModelFromBuffer(model);
+ compileOptions.SetOutputModelPath(outputModelFile);
+ compileOptions.CompileModel();
+ Assert.True(File.Exists(outputModelFile));
+
+ // Try to compile again with flag that prevents replacing an existing file.
+ // Expect failure.
+ compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS);
+
+ // compile should fail
+ try
+ {
+ compileOptions.CompileModel();
+ Assert.Fail("CompileModel() should have thrown an exception");
+ }
+ catch (OnnxRuntimeException ex)
+ {
+ Assert.Contains("exists already", ex.Message);
+ }
+ }
}
- catch (OnnxRuntimeException ex)
+ finally
{
- Assert.Contains("Unable to compile any nodes", ex.Message);
+ if (File.Exists(outputModelFile))
+ {
+ // This file is created by ORT, so we delete it manually in finally block.
+ File.Delete(outputModelFile);
+ }
}
-
- Assert.False(File.Exists(output_model_file)); // Output file should not be generated.
}
+ }
+
+ [Fact]
+ public void WriteOutModelWithDelegate()
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+ var outputModelFilePath = "squeezenet_write_delegate_ctx.onnx";
- // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS.
- using (var compileOptions = new OrtModelCompilationOptions(so))
+ using (FileStream fs = new FileStream(outputModelFilePath, FileMode.Create, FileAccess.Write, FileShare.None,
+ 4096, FileOptions.DeleteOnClose))
+ using (var sessionOptions = new SessionOptions())
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
{
- var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
- var output_model_file = "squeezenet_ctx.onnx";
+ void BasicWriteBufferDelegate(ReadOnlySpan buffer)
+ {
+ Assert.True(buffer.Length > 0);
+ fs.Write(buffer.ToArray(), 0, buffer.Length); // Write it out to a file
+ }
// Compile and generate an output model.
compileOptions.SetInputModelFromBuffer(model);
- compileOptions.SetOutputModelPath(output_model_file);
+ compileOptions.SetOutputModelWriteDelegate(BasicWriteBufferDelegate);
compileOptions.CompileModel();
- Assert.True(File.Exists(output_model_file));
+ Assert.True(File.Exists(outputModelFilePath));
+ }
+ }
- // Try to compile again with flag that prevents replacing an existing file.
- // Expect failure.
- compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS);
+ [Fact]
+ public void BasicGetInitializerLocationDelegate()
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+ var outputModelFilePath = "squeezenet_handle_initializer_delegate_ctx.onnx";
+ var initializersFilePath = "squeezenet_handle_initializer_delegate_ctx.bin";
- // compile should fail
- try
+ try
+ {
+ using (FileStream fs = new FileStream(initializersFilePath, FileMode.Create, FileAccess.Write,
+ FileShare.None, 4096, FileOptions.DeleteOnClose))
+ using (var sessionOptions = new SessionOptions())
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
{
+ // Custom delegate that stores large initializers in a new file.
+ OrtExternalInitializerInfo BasicHandleInitializer(
+ string initializerName, IReadOnlyOrtValue initializerValue,
+ IReadOnlyExternalInitializerInfo originalInitializerLocation)
+ {
+ Assert.True(initializerName.Length > 0);
+
+ var byteSize = initializerValue.GetTensorSizeInBytes();
+ if (byteSize <= 64)
+ {
+ // Keep small initializers stored within model.
+ return null;
+ }
+
+ long byteOffset = fs.Position;
+ ReadOnlySpan dataSpan = initializerValue.GetTensorDataAsSpan();
+ fs.Write(dataSpan.ToArray(), 0, dataSpan.Length); // Write it out to a file
+
+ // Return the data's new location.
+ return new OrtExternalInitializerInfo(initializersFilePath, byteOffset, byteSize);
+ }
+
+ // Compile and generate an output model.
+ compileOptions.SetInputModelFromBuffer(model);
+ compileOptions.SetOutputModelPath(outputModelFilePath);
+ compileOptions.SetOutputModelGetInitializerLocationDelegate(BasicHandleInitializer);
compileOptions.CompileModel();
- Assert.Fail("CompileModel() should have thrown an exception");
+ Assert.True(File.Exists(outputModelFilePath));
}
- catch (OnnxRuntimeException ex)
+ }
+ finally
+ {
+ if (File.Exists(outputModelFilePath))
{
- Assert.Contains("exists already", ex.Message);
+ // This file is created by ORT, so we delete it manually in finally block.
+ File.Delete(outputModelFilePath);
}
+ }
+ }
+
+ [Fact]
+ public void GetInitializerLocationDelegateThatReusesExternalInitializers()
+ {
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("conv_qdq_external_ini.onnx");
+ var outputModelFilePath = "conv_qdq_external_ini.reuse.ctx.onnx";
+ bool reusedExternalInitializers = false;
+
+ try
+ {
+ using (var sessionOptions = new SessionOptions())
+ using (var compileOptions = new OrtModelCompilationOptions(sessionOptions))
+ {
+ // Custom delegate that reuses the original external initializer file.
+ OrtExternalInitializerInfo ReuseExternalInitializers(
+ string initializerName, IReadOnlyOrtValue initializerValue,
+ IReadOnlyExternalInitializerInfo originalInitializerLocation)
+ {
+ Assert.True(initializerName.Length > 0);
+
+ if (originalInitializerLocation != null)
+ {
+ reusedExternalInitializers = true; // For test assertion only
+ string originalFilePath = originalInitializerLocation.GetFilePath();
+ long originalFileOffset = originalInitializerLocation.GetFileOffset();
+ long originalByteSize = originalInitializerLocation.GetByteSize();
+
+ Assert.True(originalFilePath.Length > 0);
+ Assert.True(originalFileOffset >= 0);
+ Assert.True(originalByteSize > 0);
- if (File.Exists(output_model_file))
+ // This initializer comes from an external file. Reuse it for compiled model.
+ return new OrtExternalInitializerInfo(originalFilePath, originalFileOffset, originalByteSize);
+ }
+
+ // Otherwise, embed initializers that were not originally external.
+ return null;
+ }
+
+ // Compile and generate an output model.
+ compileOptions.SetInputModelFromBuffer(model);
+ compileOptions.SetOutputModelPath(outputModelFilePath);
+ compileOptions.SetOutputModelGetInitializerLocationDelegate(ReuseExternalInitializers);
+ compileOptions.CompileModel();
+
+ Assert.True(File.Exists(outputModelFilePath));
+ Assert.True(reusedExternalInitializers);
+ }
+ }
+ finally
+ {
+ if (File.Exists(outputModelFilePath))
{
- File.Delete(output_model_file);
+ // This file is created by ORT, so we delete it manually in finally block.
+ File.Delete(outputModelFilePath);
}
}
}
diff --git a/csharp/testdata/conv_qdq_external_ini.bin b/csharp/testdata/conv_qdq_external_ini.bin
new file mode 100644
index 0000000000000..89eea0dba1fa4
Binary files /dev/null and b/csharp/testdata/conv_qdq_external_ini.bin differ
diff --git a/csharp/testdata/conv_qdq_external_ini.onnx b/csharp/testdata/conv_qdq_external_ini.onnx
new file mode 100644
index 0000000000000..c53e1f3ad4d9b
Binary files /dev/null and b/csharp/testdata/conv_qdq_external_ini.onnx differ
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f3dcde1abe37a..cbfc38068ac2a 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -3079,6 +3079,17 @@ This version of the operator has been available since version 1 of the 'com.micr
Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1,
GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral).
+
+ The SwiGLU (Swish-Gated Linear Unit) activation function is like:
+ g = xW + b
+ l = xV + c
+ G = clamp(g, max=limit)
+ L = clamp(l, min=-limit, max=limit)
+ swiglu = G * sigmoid(alpha * G) * (L + beta)
+ where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters.
+ When swiglu_fusion=0, two GEMMs are not fused, and they are FC1 and FC3 in the inputs.
+ When swiglu_fusion=1, two GEMMs are fused so that g and l are computed in a single GEMM (FC1), and g and l are interleaved on each row of size 2 * inter_size.
+ When swiglu_fusion=2, two GEMMs are fused, and g and l are concatenated on each row.
#### Version
@@ -3088,12 +3099,20 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
+- activation_alpha : float
+- Alpha parameter used in activation function.
+- activation_beta : float
+- Beta parameter used in activation function.
- activation_type : string
-- Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+- Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
- k : int
- Number of top experts to select from expert pool
- normalize_routing_weights : int
- Whether to normalize routing weights
+- swiglu_fusion : int
+ - 0: not fused, 1: fused and interleaved, 2: fused and not interleaved.
+- swiglu_limit : float
+- The limit used to clamp in SwiGLU. No clamp when limit is not provided.
- use_sparse_mixer : int
- Whether to use sparse mixer
@@ -3106,15 +3125,15 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T
-3D input tensor with shape (num_experts, hidden_size, inter_size)
+3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu
fc1_experts_bias (optional) : T
-2D optional input tensor with shape (num_experts, inter_size)
+2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T
-3D input tensor with shape (num_experts, inter_size, hidden_size)
+3D input tensor with shape (num_experts, hidden_size, inter_size)
fc2_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, hidden_size)
fc3_experts_weights (optional) : T
-3D optional input tensor with shape (num_experts, hidden_size, inter_size)
+3D optional input tensor with shape (num_experts, inter_size, hidden_size)
fc3_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size)
@@ -3129,8 +3148,8 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float), tensor(float16)
-- Constrain input and output types to float or float16 tensors.
+- T : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain input and output types to float tensors.
@@ -4522,14 +4541,22 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
+- activation_alpha : float
+- Alpha parameter used in activation function.
+- activation_beta : float
+- Beta parameter used in activation function.
- activation_type : string
-- Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+- Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
- expert_weight_bits : int
- Number of bits used in quantized weights. Default is 4 bits
- k : int
- Number of top experts to select from expert pool
- normalize_routing_weights : int
- Whether to normalize routing weights
+- swiglu_fusion : int
+ - 0: not fused, 1: fused and interleaved, 2: fused and not interleaved.
+- swiglu_limit : float
+- The limit used to clamp inputs in SwiGLU. It is infinite when limit is not provided.
- use_sparse_mixer : int
- Whether to use sparse mixer
@@ -4542,20 +4569,20 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T1
-3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
-fc1_scales : T
-2D input tensor with shape (num_experts, inter_size)
+3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, inter_size, hidden_size / 2) for 4 bits. For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.
+fc1_scales : T2
+2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
-2D optional input tensor with shape (num_experts, inter_size)
+2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T1
-3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
-fc2_scales : T
+3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2) for 4 bits
+fc2_scales : T2
2D input tensor with shape (num_experts, hidden_size)
fc2_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, hidden_size)
fc3_experts_weights (optional) : T1
-3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
-fc3_scales (optional) : T
+3D optional input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
+fc3_scales (optional) : T2
2D optional input tensor with shape (num_experts, inter_size)
fc3_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size)
@@ -4571,10 +4598,12 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float16)
-- Constrain input and output types to float or float16 tensors.
+- T : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain input and output types to float tensors.
- T1 : tensor(uint8)
- Constrain weights type to uint8 tensors.
+- T2 : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain scales type to float tensors.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 3b70e5da8b3e4..660c63d056335 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -562,6 +562,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**
*in* X:**T**
*in* x_scale:**TF**
*in* x_zero_point:**T**
*in* Y:**T**
*in* y_scale:**TF**
*in* y_zero_point:**T**
*in* z_scale:**TF**
*in* z_zero_point:**T**
*out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
+|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(float), tensor(float16)|
|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
@@ -937,6 +938,7 @@ Do not modify directly.*
|FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|GatherBlockQuantized|*in* data:**T1**
*in* indices:**Tind**
*in* scales:**T2**
*in* zero_points:**T1**
*out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)
**T2** = tensor(bfloat16), tensor(float), tensor(float16)
**Tind** = tensor(int32), tensor(int64)|
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)|
|GemmaRotaryEmbedding|*in* emb:**U**
*in* q:**T**
*in* q_rot:**T**
*in* k:**T**
*in* k_rot:**T**
*out* output1:**T**
*out* output2:**T**|1+|**T** = tensor(float16)
**U** = tensor(float)|
@@ -949,7 +951,7 @@ Do not modify directly.*
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)|
-|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -957,7 +959,7 @@ Do not modify directly.*
|PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PagedAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* key_cache:**T**
*in* value_cache:**T**
*in* cumulative_sequence_length:**S**
*in* past_seqlens:**S**
*in* block_table:**S**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* key_cache_out:**T**
*out* value_cache_out:**T**|1+|**S** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)|
|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)|
-|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)|
+|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(bfloat16), tensor(float16)|
|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)|
diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h
index 375f0a4dc8dd2..e59a803d97629 100644
--- a/include/onnxruntime/core/framework/op_kernel.h
+++ b/include/onnxruntime/core/framework/op_kernel.h
@@ -305,6 +305,24 @@ using BuildKernelCreateInfoFn = KernelCreateInfo (*)();
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
}
+#define ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name) \
+ provider##_##name##_##domain##_ver##ver##_##type1##_##type2##_##type3
+
+#define ONNX_OPERATOR_THREE_TYPED_KERNEL_EX(name, domain, ver, type1, type2, type3, provider, builder, ...) \
+ class ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name); \
+ template <> \
+ KernelCreateInfo \
+ BuildKernelCreateInfo<ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name)>() { \
+ return KernelCreateInfo( \
+ builder.SetName(#name) \
+ .SetDomain(domain) \
+ .SinceVersion(ver) \
+ .Provider(provider) \
+ .Build(), \
+ static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { \
+ out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
+ }
+
#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \
provider##_##name##_##domain##_ver##startver##_##endver##_##type
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index 866892979b749..9a0708d72b4f8 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -1247,6 +1247,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
const std::filesystem::path& model_file_path,
const ModelSavingOptions& model_saving_options) const;
+ /// <summary>
+ /// Serialize the Graph to a onnx::GraphProto. Caller provides a function that determines where each initializer
+ /// is stored (i.e., either in an external file or within the model).
+ /// </summary>
+ /// <param name="handle_initializer_func">Function called for every initializer.</param>
+ /// <param name="state">Opaque user state passed to the handle_initializer_func.</param>
+ /// <param name="graph_proto">Output parameter set to the serialized onnx::GraphProto.</param>
+ /// <returns>A status indicating success or an error.</returns>
+ common::Status ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func,
+ void* state,
+ /*out*/ ONNX_NAMESPACE::GraphProto& graph_proto) const;
+
/** Gets the ISchemaRegistry instances being used with this Graph. */
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const;
@@ -1664,6 +1676,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
std::ostream& external_stream,
int64_t& external_offset) const;
+ Status ToGraphProtoWithCustomInitializerHandlingImpl(OrtGetInitializerLocationFunc handle_initializer_func,
+ void* state,
+ /*out*/ ONNX_NAMESPACE::GraphProto& output_graph_proto) const;
#endif
Version IrVersion() const noexcept {
diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h
index a32f465e44adf..026fc3b2dc0a0 100644
--- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h
+++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h
@@ -34,6 +34,7 @@ constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes";
constexpr const char* kCudaGraphEnable = "enable_cuda_graph";
constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable";
constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer";
+constexpr const char* kRuntimeCacheFile = "nv_runtime_cache_path";
} // namespace provider_option_names
namespace run_option_names {
diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h
index 28ce4439fdc7e..e2b2aff2011fe 100644
--- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h
+++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h
@@ -203,415 +203,331 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
#define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \
do { \
- OrtStatus* _status = (fn); \
- if (_status != nullptr) { \
- return Ort::Status{_status}; \
+ Ort::Status _status{(fn)}; \
+ if (!_status.IsOK()) { \
+ return _status; \
} \
} while (0)
#define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \
- do { \
- Ort::Status _status = (fn); \
- if (!_status.IsOK()) { \
- return _status; \
- } \
- } while (0)
+ ORT_EP_UTILS_C_RETURN_IF_ERROR(fn)
-#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \
- do { \
- if ((cond)) { \
- return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \
- } \
+#define ORT_EP_UTILS_C_RETURN_IF(cond, msg) \
+ do { \
+ if ((cond)) { \
+ return Ort::Status{msg, ORT_FAIL}; \
+ } \
} while (0)
namespace OrtEpUtils {
-static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info,
+static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi,
bool get_symbolic_dims,
/*out*/ ONNXTensorElementDataType& elem_type,
/*out*/ std::vector<int64_t>& dims,
/*out*/ std::vector<const char*>& symbolic_dims);
-static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto);
-static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);
+static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto);
+static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr ort_attr, onnx::AttributeProto& attr_proto);
-Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
+Ort::Status OrtGraphToProto(const OrtGraph& graph,
onnx::GraphProto& graph_proto,
HandleInitializerDataFunc handle_initializer_data_func) {
- const OrtApi& ort_api = Ort::GetApi();
-
- //
- // Set GraphProto metadata
- //
- const char* graph_name = nullptr;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name));
- graph_proto.set_name(graph_name);
- graph_proto.set_doc_string("Serialized from OrtGraph");
-
- //
- // Set GraphProto inputs and outputs
- //
- size_t num_graph_inputs = 0;
- size_t num_graph_outputs = 0;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs));
-
- std::vector<const OrtValueInfo*> graph_inputs(num_graph_inputs);
- std::vector<const OrtValueInfo*> graph_outputs(num_graph_outputs);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size()));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size()));
-
- for (const OrtValueInfo* ort_value_info : graph_inputs) {
- onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add();
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto));
- }
-
- for (const OrtValueInfo* ort_value_info : graph_outputs) {
- onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add();
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto));
- }
-
- //
- // Set GraphProto nodes, value_infos, and initializers.
- //
-
- // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer.
- // A std::map maintains its elements in a stable ordering.
- std::map<std::string_view, const OrtValueInfo*> value_infos; // For GraphProto.value_info
- std::map<std::string_view, const OrtValueInfo*> initializer_value_infos; // For GraphProto.initializer
-
- // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`.
- // Optionally returns the OrtValueInfo name to the caller.
- auto collect_value_info = [&ort_api, &value_infos,
- &initializer_value_infos](const OrtValueInfo& ort_value_info,
- /*out*/ const char** value_name_out = nullptr) -> Ort::Status {
- const char* value_name = nullptr;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name));
-
- if (value_name_out != nullptr) {
- *value_name_out = value_name;
+ try {
+ Ort::ConstGraph ort_graph{&graph};
+ //
+ // Set GraphProto metadata
+ //
+ auto graph_name = ort_graph.GetName();
+ graph_proto.set_name(graph_name);
+ graph_proto.set_doc_string("Serialized from OrtGraph");
+
+ //
+ // Set GraphProto inputs and outputs
+ //
+ std::vector<Ort::ConstValueInfo> graph_inputs = ort_graph.GetInputs();
+ std::vector<Ort::ConstValueInfo> graph_outputs = ort_graph.GetOutputs();
+
+ for (const auto& ort_value_info : graph_inputs) {
+ onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add();
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto));
}
- if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) {
- return Ort::Status{nullptr}; // Already processed this OrtValueInfo.
+ for (const auto& ort_value_info : graph_outputs) {
+ onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add();
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto));
}
- bool is_required_graph_input = false;
- bool is_optional_graph_input = false;
- bool is_graph_output = false;
- bool is_constant_initializer = false;
- bool is_from_outer_scope = false;
-
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope));
-
- // Don't add graph inputs or graph outputs to GraphProto's list of value_infos.
- // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors.
- // For values defined in an outer scope, just add the value info but not the initializer.
- if (is_from_outer_scope) {
- value_infos.emplace(value_name, &ort_value_info);
- } else if (is_optional_graph_input) {
- initializer_value_infos.emplace(value_name, &ort_value_info);
- } else if (is_constant_initializer) {
- value_infos.emplace(value_name, &ort_value_info);
- initializer_value_infos.emplace(value_name, &ort_value_info);
- } else if (!is_required_graph_input && !is_graph_output) {
- value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo.
- }
+ //
+ // Set GraphProto nodes, value_infos, and initializers.
+ //
+
+ // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer.
+ // A std::map maintains its elements in a stable ordering.
+ std::map<std::string, Ort::ConstValueInfo> value_infos; // For GraphProto.value_info
+ std::map<std::string, Ort::ConstValueInfo> initializer_value_infos; // For GraphProto.initializer
+
+ // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`.
+ // Optionally returns the OrtValueInfo name to the caller.
+ auto collect_value_info = [&value_infos,
+ &initializer_value_infos](Ort::ConstValueInfo ort_value_info,
/*out*/ std::optional<std::string>& value_name_out) {
+ auto value_name = ort_value_info.GetName();
+
+ if (value_name_out) {
+ *value_name_out = value_name;
+ }
+
+ if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) {
+ return; // Already processed this OrtValueInfo.
+ }
- return Ort::Status{nullptr};
- };
-
- size_t num_nodes = 0;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes));
-
- std::vector<const OrtNode*> nodes(num_nodes);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size()));
-
- // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos
- // that will be stored in GraphProto.value_info and GraphProto.initializer.
- for (size_t i = 0; i < num_nodes; i++) {
- const OrtNode* ort_node = nodes[i];
- onnx::NodeProto* node_proto = graph_proto.add_node();
-
- const char* node_name = nullptr;
- const char* node_domain = nullptr;
- const char* node_op_type = nullptr;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type));
-
- node_proto->set_name(node_name);
- node_proto->set_domain(node_domain);
- node_proto->set_op_type(node_op_type);
-
- size_t num_inputs = 0;
- size_t num_implicit_inputs = 0;
- size_t num_outputs = 0;
- size_t num_attrs = 0;
- size_t num_subgraphs = 0;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs));
-
- // Handle node attributes
- if (num_attrs > 0) {
- std::vector<const OrtOpAttr*> ort_attrs(num_attrs);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size()));
-
- for (const OrtOpAttr* ort_attr : ort_attrs) {
- OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED;
-
- Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)};
+ bool is_required_graph_input = ort_value_info.IsRequiredGraphInput();
+ bool is_optional_graph_input = ort_value_info.IsOptionalGraphInput();
+ bool is_graph_output = ort_value_info.IsGraphOutput();
+ bool is_constant_initializer = ort_value_info.IsConstantInitializer();
+ bool is_from_outer_scope = ort_value_info.IsFromOuterScope();
+
+ // Don't add graph inputs or graph outputs to GraphProto's list of value_infos.
+ // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors.
+ // For values defined in an outer scope, just add the value info but not the initializer.
+ if (is_from_outer_scope) {
+ value_infos.emplace(value_name, ort_value_info);
+ } else if (is_optional_graph_input) {
+ initializer_value_infos.emplace(value_name, ort_value_info);
+ } else if (is_constant_initializer) {
+ value_infos.emplace(value_name, ort_value_info);
+ initializer_value_infos.emplace(value_name, ort_value_info);
+ } else if (!is_required_graph_input && !is_graph_output) {
+ value_infos.emplace(value_name, ort_value_info); // This is an internal OrtValueInfo.
+ }
+ };
+
+ std::vector<Ort::ConstNode> nodes = ort_graph.GetNodes();
+ // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos
+ // that will be stored in GraphProto.value_info and GraphProto.initializer.
+ for (const auto& ort_node : nodes) {
+ onnx::NodeProto* node_proto = graph_proto.add_node();
+
+ std::string node_name = ort_node.GetName();
+ std::string node_domain = ort_node.GetDomain();
+ std::string node_op_type = ort_node.GetOperatorType();
+
+ node_proto->set_name(node_name);
+ node_proto->set_domain(node_domain);
+ node_proto->set_op_type(node_op_type);
+
+ // Handle node attributes
+ std::vector<Ort::ConstOpAttr> ort_attrs = ort_node.GetAttributes();
+ for (const auto& attr : ort_attrs) {
+ OrtOpAttrType attr_type = attr.GetType();
if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) {
// ORT does not support reading subgraphs via ReadOpAttr(), so skip it.
// Can use Node_GetSubgraphs to get subgraphs.
continue;
}
- if (!attr_type_status.IsOK()) {
- // Unsupported attribute type.
- return attr_type_status;
- }
-
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto));
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(attr, *attr_proto));
}
- }
-
- // Handle node subgraphs
- if (num_subgraphs > 0) {
- std::vector<const OrtGraph*> ort_subgraphs(num_subgraphs);
- std::vector<const char*> subgraph_attr_names(num_subgraphs);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(),
- subgraph_attr_names.data()));
-
- for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) {
- const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx];
- const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx];
+ // Handle node subgraphs
+ std::vector ort_subgraphs = ort_node.GetSubgraphs();
+ for (const auto& [subgraph_attr_name, ort_subgraph] : ort_subgraphs) {
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
onnx::GraphProto* subgraph_proto = attr_proto->mutable_g();
-
attr_proto->set_name(subgraph_attr_name);
attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH);
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto));
}
- }
-
- // Handle node inputs
- if (num_inputs > 0) {
- std::vector<const OrtValueInfo*> ort_inputs(num_inputs);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size()));
- for (const OrtValueInfo* ort_value_info : ort_inputs) {
- if (ort_value_info == nullptr) {
+ // Handle node inputs
+ std::vector<Ort::ConstValueInfo> ort_inputs = ort_node.GetInputs();
+ for (const auto& vi : ort_inputs) {
+ if (vi == nullptr) {
// missing optional input.
node_proto->add_input("");
continue;
}
- const char* value_name = nullptr;
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name));
-
- node_proto->add_input(value_name);
+ std::optional<std::string> value_name;
+ value_name.emplace();
+ collect_value_info(vi, value_name);
+ node_proto->add_input(*value_name);
}
- }
-
- // Handle implicit inputs to this node.
- if (num_implicit_inputs > 0) {
- std::vector<const OrtValueInfo*> ort_implicit_inputs(num_implicit_inputs);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(),
- ort_implicit_inputs.size()));
- for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) {
- assert(ort_value_info != nullptr);
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr));
+ // Handle implicit inputs to this node.
+ std::vector<Ort::ConstValueInfo> ort_implicit_inputs = ort_node.GetImplicitInputs();
+ for (const auto& vi : ort_implicit_inputs) {
+ assert(vi != nullptr);
+ std::optional<std::string> value_name;
+ collect_value_info(vi, value_name);
}
- }
-
- // Handle node outputs
- if (num_outputs > 0) {
- std::vector<const OrtValueInfo*> ort_outputs(num_outputs);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size()));
- for (const OrtValueInfo* ort_value_info : ort_outputs) {
- if (ort_value_info == nullptr) {
+ // Handle node outputs
+ std::vector<Ort::ConstValueInfo> ort_outputs = ort_node.GetOutputs();
+ for (const auto& vi : ort_outputs) {
+ if (vi == nullptr) {
// missing optional output.
node_proto->add_output("");
continue;
}
- const char* value_name = nullptr;
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name));
-
- node_proto->add_output(value_name);
+ std::optional<std::string> value_name;
+ value_name.emplace();
+ collect_value_info(vi, value_name);
+ node_proto->add_output(*value_name);
}
}
- }
-
- // Add value_infos to GraphProto as ValueInfoProto objects.
- for (const std::pair<const std::string_view, const OrtValueInfo*>& entry : value_infos) {
- onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add();
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto));
- }
- // Add initializers to GraphProto as TensorProto objects.
- for (const std::pair<const std::string_view, const OrtValueInfo*>& entry : initializer_value_infos) {
- const OrtValueInfo* initializer_value_info = entry.second;
- std::string initializer_name = std::string{entry.first}; // Need a null-terminated string.
- std::vector<int64_t> initializer_dims;
- std::vector<const char*> initializer_sym_dims;
- ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false,
- initializer_elem_type, initializer_dims,
- initializer_sym_dims));
-
- onnx::TensorProto* tensor_proto = graph_proto.add_initializer();
- tensor_proto->set_name(initializer_name);
- tensor_proto->set_data_type(initializer_elem_type);
-
- auto* tensor_proto_dims = tensor_proto->mutable_dims();
- for (int64_t dim : initializer_dims) {
- tensor_proto_dims->Add(dim);
+ // Add value_infos to GraphProto as ValueInfoProto objects.
+ for (const auto& [value_name, value_info] : value_infos) {
+ onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add();
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(value_info, *value_info_proto));
}
- const OrtValue* ort_value = nullptr;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value));
+ // Add initializers to GraphProto as TensorProto objects.
+ for (const auto& [initializer_name, initializer_value_info] : initializer_value_infos) {
+ std::vector<int64_t> initializer_dims;
+ std::vector<const char*> initializer_sym_dims;
+ ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(initializer_value_info, /*get_sym_dims*/ false,
+ initializer_elem_type, initializer_dims,
+ initializer_sym_dims));
+
+ onnx::TensorProto* tensor_proto = graph_proto.add_initializer();
+ tensor_proto->set_name(initializer_name);
+ tensor_proto->set_data_type(initializer_elem_type);
+
+ auto* tensor_proto_dims = tensor_proto->mutable_dims();
+ for (int64_t dim : initializer_dims) {
+ tensor_proto_dims->Add(dim);
+ }
- const void* data = nullptr;
- size_t data_bytes = 0;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes));
+ Ort::ConstValue ort_value{nullptr};
+ ORT_EP_UTILS_C_RETURN_IF_ERROR(initializer_value_info.GetInitializer(ort_value));
- std::string ext_location;
- int64_t ext_offset = 0;
- bool is_external = false;
+ assert(ort_value.IsTensor());
+ const void* data = ort_value.GetTensorRawData();
+ const size_t data_bytes = ort_value.GetTensorSizeInBytes();
- if (handle_initializer_data_func != nullptr) {
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes,
- is_external, ext_location, ext_offset));
- }
+ std::string ext_location;
+ int64_t ext_offset = 0;
+ bool is_external = false;
+
+ if (handle_initializer_data_func != nullptr) {
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes,
+ is_external, ext_location, ext_offset));
+ }
- if (is_external) {
- tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL);
- auto* ext_data_entries = tensor_proto->mutable_external_data();
- onnx::StringStringEntryProto* location_entry = ext_data_entries->Add();
- onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add();
- onnx::StringStringEntryProto* length_entry = ext_data_entries->Add();
-
- location_entry->set_key("location");
- location_entry->set_value(ext_location);
- offset_entry->set_key("offset");
- offset_entry->set_value(std::to_string(ext_offset));
- length_entry->set_key("length");
- length_entry->set_value(std::to_string(data_bytes));
- } else {
- // User wants to store data inline the TensorProto's raw_data
- tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT);
- tensor_proto->set_raw_data(data, data_bytes);
+ if (is_external) {
+ tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL);
+ auto* ext_data_entries = tensor_proto->mutable_external_data();
+ onnx::StringStringEntryProto* location_entry = ext_data_entries->Add();
+ onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add();
+ onnx::StringStringEntryProto* length_entry = ext_data_entries->Add();
+
+ location_entry->set_key("location");
+ location_entry->set_value(ext_location);
+ offset_entry->set_key("offset");
+ offset_entry->set_value(std::to_string(ext_offset));
+ length_entry->set_key("length");
+ length_entry->set_value(std::to_string(data_bytes));
+ } else {
+ // User wants to store data inline the TensorProto's raw_data
+ tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT);
+ tensor_proto->set_raw_data(data, data_bytes);
+ }
}
+ } catch (const Ort::Exception& ex) {
+ return Ort::Status{ex};
+ } catch (const std::exception& ex) {
+ return Ort::Status{ex.what(), ORT_FAIL};
}
return Ort::Status{nullptr};
}
-Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
+Ort::Status OrtGraphToProto(const OrtGraph& graph,
onnx::ModelProto& model_proto,
HandleInitializerDataFunc handle_initializer_data_func) {
- const OrtApi& ort_api = Ort::GetApi();
-
- // Check that OrtGraph is a top-level graph (no parent node).
- const OrtNode* parent_node = nullptr;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node));
- ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto");
-
- // Set model description.
- model_proto.set_doc_string("Serialized from OrtGraph");
- model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto");
-
- // Set ir version.
- int64_t ir_version = 0;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version));
- model_proto.set_ir_version(ir_version);
-
- // Set operator sets.
- size_t num_operator_sets = 0;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets));
- ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set.");
-
- std::vector<const char*> domains(num_operator_sets, nullptr);
- std::vector<int64_t> opset_versions(num_operator_sets);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(),
- num_operator_sets));
-
- auto* operator_sets = model_proto.mutable_opset_import();
-
- for (size_t i = 0; i < num_operator_sets; ++i) {
- onnx::OperatorSetIdProto* operator_set = operator_sets->Add();
- operator_set->set_domain(domains[i]);
- operator_set->set_version(opset_versions[i]);
- }
+ try {
+ // Check that OrtGraph is a top-level graph (no parent node).
+ Ort::ConstGraph ort_graph{&graph};
+ Ort::ConstNode parent_node = ort_graph.GetParentNode();
+ ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, "Cannot serialize nested OrtGraph into a ModelProto");
+
+ // Set model description.
+ model_proto.set_doc_string("Serialized from OrtGraph");
+ model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto");
+
+ // Set ir version.
+ int64_t ir_version = ort_graph.GetOnnxIRVersion();
+ model_proto.set_ir_version(ir_version);
+
+ // Set operator sets.
+ std::vector op_sets = ort_graph.GetOperatorSets();
+ ORT_EP_UTILS_C_RETURN_IF(op_sets.empty(), "OrtGraph should have at least one operator set.");
+
+ auto* operator_sets = model_proto.mutable_opset_import();
+
+ for (const auto& op_set : op_sets) {
+ onnx::OperatorSetIdProto* operator_set = operator_sets->Add();
+ operator_set->set_domain(op_set.domain);
+ operator_set->set_version(op_set.version);
+ }
- model_proto.clear_graph();
- onnx::GraphProto* graph_proto = model_proto.mutable_graph();
+ model_proto.clear_graph();
+ onnx::GraphProto* graph_proto = model_proto.mutable_graph();
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_graph, *graph_proto, handle_initializer_data_func));
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func));
+ } catch (const Ort::Exception& ex) {
+ return Ort::Status(ex);
+ } catch (const std::exception& ex) {
+ return Ort::Status(ex.what(), ORT_EP_FAIL);
+ }
return Ort::Status{nullptr};
}
-static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info,
+static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi,
bool get_symbolic_dims,
/*out*/ ONNXTensorElementDataType& elem_type,
/*out*/ std::vector& dims,
/*out*/ std::vector& symbolic_dims) {
- const OrtApi& ort_api = Ort::GetApi();
-
- const OrtTypeInfo* ort_type_info = nullptr;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info));
-
- ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type));
- ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor");
+ try {
+ Ort::ConstTypeInfo ort_type_info = vi.TypeInfo();
+ ONNXType ort_onnx_type = ort_type_info.GetONNXType();
+ ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, "Expected OrtValueInfo to represent a Tensor");
- const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr;
- ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type));
-
- size_t num_dims = 0;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims));
+ Ort::ConstTensorTypeAndShapeInfo ort_type_shape = ort_type_info.GetTensorTypeAndShapeInfo();
+ ONNXTensorElementDataType ort_elem_type = ort_type_shape.GetElementType();
- std::vector ort_dims(num_dims, 0);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size()));
+ size_t num_dims = ort_type_shape.GetDimensionsCount();
+ std::vector ort_dims = ort_type_shape.GetShape();
- elem_type = ort_elem_type;
- dims = std::move(ort_dims);
+ elem_type = ort_elem_type;
+ dims = std::move(ort_dims);
- if (get_symbolic_dims) {
- std::vector ort_dim_syms(num_dims, nullptr);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(),
- ort_dim_syms.size()));
+ if (get_symbolic_dims) {
+ std::vector ort_dim_syms(num_dims, nullptr);
+ ort_type_shape.GetSymbolicDimensions(ort_dim_syms.data(), ort_dim_syms.size());
- symbolic_dims.reserve(num_dims);
- for (const char* sym_dim : ort_dim_syms) {
- symbolic_dims.push_back(sym_dim);
+ symbolic_dims.reserve(num_dims);
+ for (const char* sym_dim : ort_dim_syms) {
+ symbolic_dims.push_back(sym_dim);
+ }
}
+ } catch (const Ort::Exception& ex) {
+ return Ort::Status{ex};
+ } catch (const std::exception& ex) {
+ return Ort::Status{ex.what(), ORT_EP_FAIL};
}
-
return Ort::Status{nullptr};
}
// Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape).
-static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info,
+static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info,
onnx::ValueInfoProto& value_info_proto) {
- const OrtApi& ort_api = Ort::GetApi();
-
std::vector ort_dims;
std::vector ort_dim_syms;
ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
@@ -620,9 +536,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info,
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true,
ort_elem_type, ort_dims, ort_dim_syms));
- const char* value_name = nullptr;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name));
- value_info_proto.set_name(value_name);
+ value_info_proto.set_name(ort_value_info.GetName());
onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type();
type_proto_tensor->set_elem_type(ort_elem_type);
@@ -652,213 +566,149 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info,
return Ort::Status{nullptr};
}
-static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
- const OrtApi& ort_api = Ort::GetApi();
-
- const char* attr_name = nullptr;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name));
- attr_proto.set_name(attr_name);
+static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr attr, onnx::AttributeProto& attr_proto) {
+ try {
+ std::string attr_name = attr.GetName();
+ attr_proto.set_name(attr_name);
- size_t total_attr_bytes = 0;
- OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type));
+ OrtOpAttrType attr_type = attr.GetType();
- switch (attr_type) {
- case OrtOpAttrType::ORT_OP_ATTR_INT: {
- attr_proto.set_type(onnx::AttributeProto_AttributeType_INT);
-
- int64_t i_val = 0;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes));
- attr_proto.set_i(i_val);
- break;
- }
- case OrtOpAttrType::ORT_OP_ATTR_INTS: {
- attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS);
-
- // First call to ReadOpAttr gets the total byte size. Second call reads the data.
- Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)};
- std::vector i_vals(total_attr_bytes / sizeof(int64_t));
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes,
- &total_attr_bytes));
-
- auto* ints = attr_proto.mutable_ints();
- for (int64_t val : i_vals) {
- ints->Add(val);
+ switch (attr_type) {
+ case OrtOpAttrType::ORT_OP_ATTR_INT: {
+ int64_t i_val = 0;
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(i_val));
+ attr_proto.set_type(onnx::AttributeProto_AttributeType_INT);
+ attr_proto.set_i(i_val);
+ break;
}
- break;
- }
- case OrtOpAttrType::ORT_OP_ATTR_FLOAT: {
- attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT);
-
- float f_val = 0.0f;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes));
- attr_proto.set_f(f_val);
- break;
- }
- case OrtOpAttrType::ORT_OP_ATTR_FLOATS: {
- attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS);
-
- // First call to ReadOpAttr gets the total byte size. Second call reads the data.
- Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)};
- std::vector f_vals(total_attr_bytes / sizeof(float));
-
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes,
- &total_attr_bytes));
-
- auto* floats = attr_proto.mutable_floats();
- for (float val : f_vals) {
- floats->Add(val);
+ case OrtOpAttrType::ORT_OP_ATTR_INTS: {
+ std::vector i_vals;
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(i_vals));
+ auto* ints = attr_proto.mutable_ints();
+ ints->Assign(i_vals.begin(), i_vals.end());
+ attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS);
+ break;
}
- break;
- }
- case OrtOpAttrType::ORT_OP_ATTR_STRING: {
- attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING);
-
- // First call to ReadOpAttr gets the total byte size. Second call reads the data.
- Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)};
- std::string* str = attr_proto.mutable_s();
-
- str->resize(total_attr_bytes);
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes,
- &total_attr_bytes));
-
- str->resize(total_attr_bytes);
- break;
- }
- case OrtOpAttrType::ORT_OP_ATTR_STRINGS: {
- attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS);
-
- // First call to ReadOpAttr gets the total byte size. Second call reads the data.
- Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)};
- std::vector chars(total_attr_bytes, '\0');
-
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes,
- &total_attr_bytes));
-
- auto* strs = attr_proto.mutable_strings();
-
- // Strings are all in a single buffer, each separated with a '\0'.
- // Extract each string and add it to the STRINGS attribute array.
- char* at = chars.data();
- char* end = at + chars.size();
-
- while (at < end) {
- char* str_begin = at;
-
- while (*at && at < end) {
- at++;
- }
-
- strs->Add()->assign(str_begin, at - str_begin);
- if (at < end) {
- assert(*at == '\0');
- at++; // Skip '\0' to get to the beginning of the next string.
- }
+ case OrtOpAttrType::ORT_OP_ATTR_FLOAT: {
+ float f_val = 0.0f;
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(f_val));
+ attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT);
+ attr_proto.set_f(f_val);
+ break;
}
+ case OrtOpAttrType::ORT_OP_ATTR_FLOATS: {
+ std::vector f_vals;
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(f_vals));
+ auto* floats = attr_proto.mutable_floats();
+ floats->Assign(f_vals.begin(), f_vals.end());
+ attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS);
+ break;
+ }
+ case OrtOpAttrType::ORT_OP_ATTR_STRING: {
+ std::string* str = attr_proto.mutable_s();
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(*str));
+ attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING);
+ break;
+ }
+ case OrtOpAttrType::ORT_OP_ATTR_STRINGS: {
+ std::vector result;
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(result));
+ auto* strs = attr_proto.mutable_strings();
+ strs->Assign(result.begin(), result.end());
+ attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS);
+ break;
+ }
+ case OrtOpAttrType::ORT_OP_ATTR_TENSOR: {
+ attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR);
+
+ onnx::TensorProto tensor_proto;
+
+ // TensorProto as an attribute value doesn't require a name.
+
+ Ort::Value tensor;
+ ORT_EP_UTILS_C_RETURN_IF_ERROR(attr.GetTensorAttributeAsOrtValue(tensor));
+
+ // Get tensor type and shape info
+ Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo();
+
+ // Get tensor type
+ ONNXTensorElementDataType element_type = type_shape_info.GetElementType();
+
+ switch (element_type) {
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT);
+ break;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8);
+ break;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8);
+ break;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16);
+ break;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16);
+ break;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32);
+ break;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64);
+ break;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL);
+ break;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE);
+ break;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32);
+ break;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: {
+ tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64);
+ break;
+ }
+ default: {
+ std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type));
+ return Ort::Status(err_msg.c_str(), ORT_FAIL);
+ }
+ }
- break;
- }
- case OrtOpAttrType::ORT_OP_ATTR_TENSOR: {
- attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR);
-
- onnx::TensorProto tensor_proto;
-
- // TensorProto as an attribute value doesn't require a name.
-
- OrtValue* ort_value = nullptr;
- ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value));
+ auto shape = type_shape_info.GetShape();
- Ort::Value tensor(ort_value);
+ for (auto& dim : shape) {
+ tensor_proto.add_dims(dim);
+ }
- // Get tensor type and shape info
- Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo();
+ const void* data = tensor.GetTensorRawData();
+ const size_t data_bytes = tensor.GetTensorSizeInBytes();
- // Get tensor type
- ONNXTensorElementDataType element_type = type_shape_info.GetElementType();
+ // Copy the OrtValue to TensorProto as raw data
+ tensor_proto.set_raw_data(data, data_bytes);
- size_t element_size = 0;
- switch (element_type) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT);
- element_size = sizeof(float);
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8);
- element_size = sizeof(uint8_t);
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8);
- element_size = sizeof(int8_t);
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16);
- element_size = sizeof(uint16_t);
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16);
- element_size = sizeof(int16_t);
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32);
- element_size = sizeof(int32_t);
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64);
- element_size = sizeof(int64_t);
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL);
- element_size = sizeof(bool);
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE);
- element_size = sizeof(double);
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32);
- element_size = sizeof(uint32_t);
- break;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: {
- tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64);
- element_size = sizeof(uint64_t);
- break;
- }
- default: {
- std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type));
- return Ort::Status(err_msg.c_str(), ORT_FAIL);
- }
+ *(attr_proto.mutable_t()) = std::move(tensor_proto);
+ break;
}
-
- auto shape = type_shape_info.GetShape();
-
- for (auto& dim : shape) {
- tensor_proto.add_dims(dim);
+ default: {
+ std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type));
+ return Ort::Status(err_msg.c_str(), ORT_FAIL);
}
-
- size_t element_count = type_shape_info.GetElementCount();
- size_t data_bytes = element_count * element_size;
- const void* data = tensor.GetTensorData();
-
- // Copy the Ortvalue to TensorProto as raw data
- tensor_proto.set_raw_data(data, data_bytes);
-
- *(attr_proto.mutable_t()) = std::move(tensor_proto);
- break;
- }
- default: {
- std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type));
- return Ort::Status(err_msg.c_str(), ORT_FAIL);
}
+ } catch (const Ort::Exception& ex) {
+ return Ort::Status{ex};
+ } catch (const std::exception& ex) {
+ return Ort::Status{ex.what(), ORT_FAIL};
}
return Ort::Status{nullptr};
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 72e8a3ca1103c..8561de9c8c3b9 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -533,6 +533,57 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e
_Out_ size_t* num_selected,
_In_ void* state);
+/** \brief Function called by ORT to write a buffer to a custom destination (e.g., file, stream, etc.).
+ *
+ * \param state Opaque pointer holding the user's state.
+ * \param buffer The buffer to write.
+ * \param buffer_num_bytes The size of the buffer in bytes.
+ *
+ * \return OrtStatus* Write status. Return nullptr on success.
+ * Use CreateStatus to provide error info. Use ORT_FAIL as the error code.
+ * ORT will release the OrtStatus* if not null.
+ */
+typedef OrtStatus*(ORT_API_CALL* OrtWriteBufferFunc)(_In_ void* state,
+ _In_ const void* buffer,
+ _In_ size_t buffer_num_bytes);
+
+/** \brief Function called by ORT to allow user to specify how an initializer should be saved, that is, either
+ * written to an external file or stored within the model. ORT calls this function for every initializer when
+ * generating a model.
+ *
+ * If the function implementation sets the `new_external_info` output parameter to NULL, ORT stores the initializer data
+ * within the generated model.
+ *
+ * Otherwise, if the function implementation sets `new_external_info` to a valid OrtExternalInitializerInfo instance,
+ * ORT assumes that this function stores the initializer data in a file. In this case, ORT configures the model's
+ * initializer to point to the location specified by the `new_external_info` output parameter.
+ *
+ * \param[in] state Opaque pointer holding the user's state.
+ * \param[in] initializer_name The initializer's name as a null-terminated string.
+ * \param[in] initializer_value OrtValue containing the initializer's data, type, and shape.
+ * \param[in] external_info If the initializer is originally stored in an external file, `external_info` contains
+ * the file path, file offset, and the data's byte size within the file. Otherwise,
+ * `external_info` is NULL if the initializer is not originally stored in a file.
+ * \param[out] new_external_info Output parameter set to a new OrtExternalInitializerInfo instance indicating the
+ * location where the function implementation stored the initializer data.
+ * The function implementation must use `OrtApi::CreateExternalInitializerInfo()` to
+ * create the instance.
+ * If the function implementation sets `new_external_info` to NULL,
+ * ORT stores the initializers within the model.
+ *
+ * \note ORT takes ownership of the `new_external_info` output parameter.
+ *
+ * \return OrtStatus* Write status. Return nullptr on success.
+ * Use CreateStatus to provide error info. Use ORT_FAIL as the error code.
+ * ORT will release the OrtStatus* if not null.
+ */
+typedef OrtStatus*(ORT_API_CALL* OrtGetInitializerLocationFunc)(
+ _In_ void* state,
+ _In_ const char* initializer_name,
+ _In_ const OrtValue* initializer_value,
+ _In_opt_ const OrtExternalInitializerInfo* external_info,
+ _Outptr_result_maybenull_ OrtExternalInitializerInfo** new_external_info);
+
/** \brief Algorithm to use for cuDNN Convolution Op
*/
typedef enum OrtCudnnConvAlgoSearch {
@@ -6507,6 +6558,26 @@ struct OrtApi {
_In_ size_t num_ep_devices,
_In_ const char* compatibility_info,
_Out_ OrtCompiledModelCompatibility* out_status);
+
+ /// \name OrtExternalInitializerInfo
+ /// @{
+
+ /** \brief Creates an OrtExternalInitializerInfo instance.
+ *
+ * \param[in] filepath The relative path to the file that stores the initializer's data. ORT copies this path string.
+ * \param[in] file_offset The byte offset where the initializer's data is stored within the file.
+ * \param[in] byte_size The size in bytes of the initializer's data within the file.
+ * \param[out] out Output parameter set to the new OrtExternalInitializerInfo instance.
+ * Must be released by calling ReleaseExternalInitializerInfo().
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */
+ ORT_API2_STATUS(CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, _In_ int64_t file_offset,
+ _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out);
+
+ /// @}
};
/*
@@ -7074,6 +7145,9 @@ struct OrtCompileApi {
* ReleaseOrtModelCompilationsOptions must be called to free the OrtModelCompilationOptions after calling
* CompileModel.
*
+ * \note By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use
+ * ModelCompilationOptions_SetGraphOptimizationLevel to enable graph optimizations.
+ *
* \param[in] env OrtEnv object.
* \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions.
* \param[out] out The created OrtModelCompilationOptions instance.
@@ -7230,7 +7304,7 @@ struct OrtCompileApi {
* \since Version 1.23.
*/
ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options,
- size_t flags);
+ uint32_t flags);
/** Sets information related to EP context binary file.
*
@@ -7249,6 +7323,56 @@ struct OrtCompileApi {
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ const ORTCHAR_T* output_directory,
_In_ const ORTCHAR_T* model_name);
+
+ /** Set the graph optimization level.
+ *
+ * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+ * \param[in] graph_optimization_level The graph optimization level.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */
+ ORT_API2_STATUS(ModelCompilationOptions_SetGraphOptimizationLevel,
+ _In_ OrtModelCompilationOptions* model_compile_options,
+ _In_ GraphOptimizationLevel graph_optimization_level);
+
+ /** \brief Sets a OrtWriteBufferFunc function that is called by ORT to write out the output model's serialized
+ * ONNX bytes.
+ *
+ * The provided write function may be called repeatedly until the entire output model has been written out. Each call
+ * to the write function is expected to consume the entire input buffer.
+ *
+ * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions
+ * that begin with ModelCompilationOptions_SetOutputModel____.
+ *
+ * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+ * \param[in] write_func The OrtWriteBufferFunc function called by ORT when writing out the model.
+ * \param[in] state Opaque state passed as the first argument to OrtWriteBufferFunc. Can be NULL.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */
+ ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelWriteFunc,
+ _In_ OrtModelCompilationOptions* model_compile_options,
+ _In_ OrtWriteBufferFunc write_func, _In_ void* state);
+
+ /** \brief Sets a OrtGetInitializerLocationFunc function that is called by ORT for every initializer in the generated
+ * model. Allows implementer to specify whether initializers should be stored within the model or externally.
+ *
+ * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+ * \param[in] get_initializer_location_func The OrtGetInitializerLocationFunc function called by ORT
+ * to determine the location of the initializer.
+ * \param[in] state Opaque state passed as the first argument to OrtGetInitializerLocationFunc. Can be NULL.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */
+ ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
+ _In_ OrtModelCompilationOptions* model_compile_options,
+ _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state);
};
/*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 981c70ab8b954..9fa7915679f62 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -52,6 +52,7 @@ namespace Ort {
* If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort()
*/
struct Exception : std::exception {
+ Exception(const std::string& string, OrtErrorCode code) : message_{string}, code_{code} {}
Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
OrtErrorCode GetOrtErrorCode() const { return code_; }
@@ -549,33 +550,35 @@ namespace detail {
inline void OrtRelease(Ort##NAME* ptr) { API_GETTER().Release##NAME(ptr); }
ORT_DEFINE_RELEASE(Allocator);
-ORT_DEFINE_RELEASE(MemoryInfo);
+ORT_DEFINE_RELEASE(ArenaCfg);
ORT_DEFINE_RELEASE(CustomOpDomain);
-ORT_DEFINE_RELEASE(ThreadingOptions);
ORT_DEFINE_RELEASE(Env);
-ORT_DEFINE_RELEASE(RunOptions);
+ORT_DEFINE_RELEASE(ExternalInitializerInfo);
+ORT_DEFINE_RELEASE(Graph);
+ORT_DEFINE_RELEASE(IoBinding);
+ORT_DEFINE_RELEASE(KernelInfo);
+ORT_DEFINE_RELEASE(KeyValuePairs);
ORT_DEFINE_RELEASE(LoraAdapter);
+ORT_DEFINE_RELEASE(MemoryInfo);
+ORT_DEFINE_RELEASE(MapTypeInfo);
+ORT_DEFINE_RELEASE(Model);
+ORT_DEFINE_RELEASE(ModelMetadata);
+ORT_DEFINE_RELEASE(Node);
+ORT_DEFINE_RELEASE(Op);
+ORT_DEFINE_RELEASE(OpAttr);
+ORT_DEFINE_RELEASE(PrepackedWeightsContainer);
+ORT_DEFINE_RELEASE(RunOptions);
ORT_DEFINE_RELEASE(Session);
ORT_DEFINE_RELEASE(SessionOptions);
-ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
ORT_DEFINE_RELEASE(SequenceTypeInfo);
-ORT_DEFINE_RELEASE(MapTypeInfo);
-ORT_DEFINE_RELEASE(TypeInfo);
-ORT_DEFINE_RELEASE(Value);
-ORT_DEFINE_RELEASE(ModelMetadata);
-ORT_DEFINE_RELEASE(IoBinding);
-ORT_DEFINE_RELEASE(ArenaCfg);
ORT_DEFINE_RELEASE(Status);
ORT_DEFINE_RELEASE(SyncStream);
-ORT_DEFINE_RELEASE(OpAttr);
-ORT_DEFINE_RELEASE(Op);
-ORT_DEFINE_RELEASE(KernelInfo);
+ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
+ORT_DEFINE_RELEASE(ThreadingOptions);
+ORT_DEFINE_RELEASE(TypeInfo);
+ORT_DEFINE_RELEASE(Value);
ORT_DEFINE_RELEASE(ValueInfo);
-ORT_DEFINE_RELEASE(Node);
-ORT_DEFINE_RELEASE(Graph);
-ORT_DEFINE_RELEASE(Model);
-ORT_DEFINE_RELEASE(KeyValuePairs);
-ORT_DEFINE_RELEASE(PrepackedWeightsContainer);
+
ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi);
@@ -701,6 +704,7 @@ struct AllocatedFree {
struct AllocatorWithDefaultOptions;
struct Env;
struct EpDevice;
+struct ExternalInitializerInfo;
struct Graph;
struct Model;
struct Node;
@@ -818,6 +822,46 @@ struct PrepackedWeightsContainer : detail::Base {
PrepackedWeightsContainer();
};
+namespace detail {
+template
+struct ConstExternalInitializerInfoImpl : Base {
+ using B = Base;
+ using B::B;
+
+ // Wraps OrtApi::ExternalInitializerInfo_GetFilePath
+ const std::basic_string GetFilePath() const;
+ // Wraps OrtApi::ExternalInitializerInfo_GetFileOffset
+ int64_t GetFileOffset() const;
+ // Wraps OrtApi::ExternalInitializerInfo_GetByteSize
+ size_t GetByteSize() const;
+};
+} // namespace detail
+
+// Const object holder that does not own the underlying object
+using ConstExternalInitializerInfo =
+ detail::ConstExternalInitializerInfoImpl>;
+
+/** \brief Wrapper around ::OrtExternalInitializerInfo
+ *
+ */
+struct ExternalInitializerInfo : detail::ConstExternalInitializerInfoImpl {
+ using Base = detail::ConstExternalInitializerInfoImpl;
+ using Base::Base;
+
+ explicit ExternalInitializerInfo(std::nullptr_t) {}
+ explicit ExternalInitializerInfo(OrtExternalInitializerInfo* p)
+ : detail::ConstExternalInitializerInfoImpl{p} {}
+
+ ConstExternalInitializerInfo GetConst() const { return ConstExternalInitializerInfo{this->p_}; }
+
+ ///< Wraps OrtApi::CreateExternalInitializerInfo
+ ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size);
+
+ ///< Wrapper around CreateExternalInitializerInfo that does not throw an exception.
+ static Status Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size,
+ /*out*/ ExternalInitializerInfo& out);
+};
+
namespace detail {
template
struct KeyValuePairsImpl : Ort::detail::Base {
@@ -1357,11 +1401,23 @@ struct ModelCompilationOptions : detail::Base {
ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath
ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path,
size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile
+
+ ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc
+ ModelCompilationOptions& SetOutputModelGetInitializerLocationFunc(
+ OrtGetInitializerLocationFunc get_initializer_location_func,
+ void* state);
+
ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr,
size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
+
+ ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelWriteFunc
+ ModelCompilationOptions& SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state);
+
ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory,
const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation
- ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
+ ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
+
+ ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel
};
/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels.
@@ -2366,15 +2422,46 @@ struct ArenaCfg : detail::Base {
// Custom OPs (only needed to implement custom OPs)
//
+namespace detail {
+// Need to define a templated ConstOpAttr with const members
+template
+struct ConstOpAttrImpl : Base {
+ using B = detail::Base;
+ using B::B;
+
+ // Wraps OrtApi::OpAttr_GetName
+ std::string GetName() const;
+ // Wraps OrtApi::OpAttr_GetType
+ OrtOpAttrType GetType() const;
+
+ // Wraps OrtApi::ReadAttr for a single value
+ // This does not support Tensor Attribute
+ // Call GetTensorAttributeAsOrtValue() instead.
+ template
+ Status GetValue(R& out) const;
+
+ // Wraps OrtApi::ReadAttr for an array of values
+ template
+ Status GetValueArray(std::vector& out) const;
+ // Wraps OrtApi::OpAttr_GetTensorAttributeAsOrtValue
+ Status GetTensorAttributeAsOrtValue(Value&) const;
+};
+} // namespace detail
+
+using ConstOpAttr = detail::ConstOpAttrImpl>;
+
///
/// This struct provides life time management for custom op attribute
///
-struct OpAttr : detail::Base {
- using Base = detail::Base;
+struct OpAttr : detail::ConstOpAttrImpl {
+ using Base = detail::ConstOpAttrImpl;
using Base::Base;
+ OpAttr() = default; // Enable storing it in the container for resize()
explicit OpAttr(std::nullptr_t) {}
OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
+
+ ConstOpAttr GetConst() const { return ConstOpAttr{this->p_}; }
};
/**
@@ -2720,7 +2807,7 @@ struct ShapeInferContext {
Strings GetAttrStrings(const char* attr_name);
private:
- const OrtOpAttr* GetAttrHdl(const char* attr_name) const;
+ ConstOpAttr GetAttrHdl(const char* attr_name) const;
const OrtApi* ort_api_;
OrtShapeInferContext* ctx_;
std::vector input_shapes_;
@@ -2871,48 +2958,114 @@ struct CustomOpBase : OrtCustomOp {
int end_ver_ = MAX_CUSTOM_OP_END_VER;
};
+// Forward declaration to resolve circular dependency
+// on ConstNode
+struct ValueInfoConsumerProducerInfo;
+
namespace detail {
template
-struct ValueInfoImpl : Ort::detail::Base {
- using B = Ort::detail::Base;
+struct ConstValueInfoImpl : Base {
+ using B = Base;
using B::B;
- std::string Name() const;
+ /// < A wrapper around OrtApi::GetValueInfoName
+ std::string GetName() const;
+ /// < A wrapper around OrtApi::GetValueInfoTypeInfo
ConstTypeInfo TypeInfo() const;
+ ///< Wraps OrtApi::ValueInfo_GetProducerNode
+ ValueInfoConsumerProducerInfo GetProducerNode() const;
+ /// < A wrapper around OrtApi::ValueInfo_GetValueConsumers
+ std::vector GetConsumers() const;
+ /// < A wrapper around OrtApi::ValueInfo_GetInitializerValue
+ Status GetInitializer(ConstValue& value) const;
+ /// < A wrapper around OrtApi::ValueInfo_GetExternalInitializerInfo
+ Status GetExternalInitializerInfo(ExternalInitializerInfo& info) const;
+ /// < A wrapper around OrtApi::ValueInfo_IsRequiredGraphInput
+ bool IsRequiredGraphInput() const;
+ /// < A wrapper around OrtApi::ValueInfo_IsOptionalGraphInput
+ bool IsOptionalGraphInput() const;
+ /// < A wrapper around OrtApi::ValueInfo_IsGraphOutput
+ bool IsGraphOutput() const;
+ /// < A wrapper around OrtApi::ValueInfo_IsConstantInitializer
+ bool IsConstantInitializer() const;
+ /// < A wrapper around OrtApi::ValueInfo_IsFromOuterScope
+ bool IsFromOuterScope() const;
};
} // namespace detail
// Const object holder that does not own the underlying object
-using ConstValueInfo = detail::ValueInfoImpl>;
+using ConstValueInfo = detail::ConstValueInfoImpl>;
/** \brief Wrapper around ::OrtValueInfo
*
*/
-struct ValueInfo : detail::ValueInfoImpl {
+struct ValueInfo : detail::ConstValueInfoImpl {
+ ValueInfo() = default; // Same thing as with nullptr
explicit ValueInfo(std::nullptr_t) {} ///< No instance is created
/// Take ownership of a pointer created by C API
- explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl{p} {}
+ explicit ValueInfo(OrtValueInfo* p) : ConstValueInfoImpl{p} {}
+#if !defined(ORT_MINIMAL_BUILD)
// Create ValueInfo for a tensor
explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info);
-
+#endif
ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; }
};
+// Forward declaration
+struct AttrNameSubgraph;
+
namespace detail {
+// Forward decl
template
-struct NodeImpl : Ort::detail::Base {
- using B = Ort::detail::Base;
+struct ConstGraphImpl;
+
+template
+struct ConstNodeImpl : Base {
+ using B = Base;
using B::B;
+
+ // GetInputs() const;
+ // GetOutputs() const;
+ // GetImplicitInputs() const;
+ // GetAttributes() const;
+ // GetSubgraphs() const;
+ // > GetGraph() const;
+ // >;
+
/** \brief Wrapper around ::OrtNode
*
*/
-struct Node : detail::NodeImpl {
- explicit Node(std::nullptr_t) {} ///< No instance is created
- explicit Node(OrtNode* p) : NodeImpl{p} {} ///< Take ownership of a pointer created by C API
+struct Node : detail::ConstNodeImpl {
+ Node() = default; // Same thing as with nullptr
+ explicit Node(std::nullptr_t) {} ///< No instance is created
+ explicit Node(OrtNode* p) : ConstNodeImpl{p} {} ///< Take ownership of a pointer created by C API
#if !defined(ORT_MINIMAL_BUILD)
Node(const std::string& operator_name, const std::string& operator_domain,
@@ -2939,22 +3092,78 @@ struct Node : detail::NodeImpl {
#endif // !defined(ORT_MINIMAL_BUILD)
};
+// Return struct for some of ValueInfo APIs.
+// Must be declared after ConstNode is available.
+struct ValueInfoConsumerProducerInfo {
+ ConstNode node;
+ // either producer output or consumer output index
+ // producer is unsigned only, output can be -1
+ int64_t index;
+};
+
+// Represents a return value for Graph::GetOperatorSets()
+struct OperatorSet {
+ std::string domain;
+ int64_t version;
+};
+
namespace detail {
template
-struct GraphImpl : Ort::detail::Base {
- using B = Ort::detail::Base;
+struct ConstGraphImpl : Base {
+ using B = Base;
+ using B::B;
+
+ // GetModelPath() const;
+ // GetOperatorSets() const;
+ // GetInputs() const;
+ // GetOutputs() const;
+ // GetInitializers() const;
+ // GetNodes() const;
+ // & nodes) const;
+ //
+struct GraphImpl : ConstGraphImpl {
+ using B = ConstGraphImpl;
using B::B;
#if !defined(ORT_MINIMAL_BUILD)
+ // & inputs);
+ // & outputs);
+ // >;
+
+// Return value for Node API
+// Must be declared after ConstGraph
+struct AttrNameSubgraph {
+ std::string attr_name;
+ ConstGraph sub_graph;
+};
+
/** \brief Wrapper around ::OrtGraph
*
*/
@@ -2962,25 +3171,26 @@ struct Graph : detail::GraphImpl {
explicit Graph(std::nullptr_t) {} ///< No instance is created
explicit Graph(OrtGraph* p) : GraphImpl{p} {} ///< Take ownership of a pointer created by C API
#if !defined(ORT_MINIMAL_BUILD)
+ // >;
namespace detail {
template
-struct ModelImpl : Ort::detail::Base {
+struct ModelImpl : detail::Base {
using B = Ort::detail::Base;
using B::B;
#if !defined(ORT_MINIMAL_BUILD)
+ // >;
+using UnownedModel = detail::ModelImpl>;
/** \brief Wrapper around ::OrtModel
*
@@ -2992,10 +3202,9 @@ struct Model : detail::ModelImpl {
explicit Model(OrtModel* p) : ModelImpl{p} {} ///< Take ownership of a pointer created by C API
#if !defined(ORT_MINIMAL_BUILD)
+ ///< Wraps GetModelEditorApi().CreateModel()
explicit Model(const std::vector& opsets);
#endif
-
- ConstModel GetConst() const { return ConstModel{this->p_}; }
};
} // namespace Ort
#include "onnxruntime_cxx_inline.h"
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 05c86ae4e0c58..59979189eed0f 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -571,6 +571,42 @@ inline PrepackedWeightsContainer::PrepackedWeightsContainer() {
ThrowOnError(GetApi().CreatePrepackedWeightsContainer(&this->p_));
}
+namespace detail {
+
+template
+inline const std::basic_string ConstExternalInitializerInfoImpl::GetFilePath() const {
+ return GetApi().ExternalInitializerInfo_GetFilePath(this->p_);
+}
+
+template
+inline int64_t ConstExternalInitializerInfoImpl::GetFileOffset() const {
+ return GetApi().ExternalInitializerInfo_GetFileOffset(this->p_);
+}
+
+template
+inline size_t ConstExternalInitializerInfoImpl::GetByteSize() const {
+ return GetApi().ExternalInitializerInfo_GetByteSize(this->p_);
+}
+} // namespace detail
+
+inline ExternalInitializerInfo::ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset,
+ size_t byte_size) {
+ ThrowOnError(GetApi().CreateExternalInitializerInfo(filepath, file_offset, byte_size, &this->p_));
+}
+
+inline Status ExternalInitializerInfo::Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size,
+ /*out*/ ExternalInitializerInfo& out) {
+ OrtExternalInitializerInfo* info = nullptr;
+ OrtStatus* status = GetApi().CreateExternalInitializerInfo(filepath, file_offset, byte_size, &info);
+ if (status != nullptr) {
+ return Status{status};
+ }
+
+ out = ExternalInitializerInfo(info);
+
+ return Status{nullptr};
+}
+
namespace detail {
template
inline const char* KeyValuePairsImpl::GetValue(const char* key) const {
@@ -1003,6 +1039,16 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalI
return *this;
}
+inline ModelCompilationOptions&
+ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc(
+ OrtGetInitializerLocationFunc get_initializer_location_func, void* state) {
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc(
+ this->p_,
+ get_initializer_location_func,
+ state));
+ return *this;
+}
+
inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer(
OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelBuffer(this->p_, allocator,
@@ -1011,6 +1057,12 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer(
return *this;
}
+inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func,
+ void* state) {
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelWriteFunc(this->p_, write_func, state));
+ return *this;
+}
+
inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
bool embed_ep_context_in_model) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode(
@@ -1019,11 +1071,18 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
return *this;
}
-inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(size_t flags) {
+inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(uint32_t flags) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetFlags(this->p_, flags));
return *this;
}
+inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLevel(
+ GraphOptimizationLevel graph_optimization_level) {
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetGraphOptimizationLevel(this->p_,
+ graph_optimization_level));
+ return *this;
+}
+
namespace detail {
template
@@ -1759,7 +1818,7 @@ inline Session::Session(const Env& env, const void* model_data, size_t model_dat
#if !defined(ORT_MINIMAL_BUILD)
inline Session::Session(const Env& env, const Model& model, const SessionOptions& options) {
- ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_));
+ ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model, options, &this->p_));
}
// static
@@ -2475,6 +2534,172 @@ inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total,
ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data));
}
+namespace detail {
+
+template
+constexpr OrtOpAttrType TypeToAttrType();
+
+template <>
+inline constexpr OrtOpAttrType TypeToAttrType() {
+ return OrtOpAttrType::ORT_OP_ATTR_INT;
+}
+
+template <>
+inline constexpr OrtOpAttrType TypeToAttrType() {
+ return OrtOpAttrType::ORT_OP_ATTR_FLOAT;
+}
+
+template
+inline constexpr OrtOpAttrType TypeToAttrsType();
+
+template <>
+inline constexpr OrtOpAttrType TypeToAttrsType() {
+ return OrtOpAttrType::ORT_OP_ATTR_INTS;
+}
+
+template <>
+inline constexpr OrtOpAttrType TypeToAttrsType() {
+ return OrtOpAttrType::ORT_OP_ATTR_FLOATS;
+}
+
+inline Status CheckAttrType(const OrtOpAttr* attr, OrtOpAttrType requested_type) {
+ OrtOpAttrType type;
+ Ort::Status status(GetApi().OpAttr_GetType(attr, &type));
+ if (!status.IsOK()) return status;
+ if (requested_type != type) {
+ std::string msg = "Attribute type mismatch: expected " + std::to_string(requested_type) +
+ ", but got " + std::to_string(type);
+ return Ort::Status(msg.c_str(), OrtErrorCode::ORT_INVALID_ARGUMENT);
+ }
+ return Ort::Status{};
+}
+
+inline size_t GetDataSize(const OrtOpAttr* attr, OrtOpAttrType attr_type) {
+ size_t result{};
+ // Ignore the status here because we check the data type so the error should only be about
+ // the size
+ [[maybe_unused]] Status status{GetApi().ReadOpAttr(attr, attr_type, nullptr, 0, &result)};
+ return result;
+}
+
+template
+Ort::Status GetNumericValue(const OrtOpAttr* attr, T& out) {
+ static_assert(std::is_arithmetic::value, "T must be an arithmetic type");
+ size_t size{};
+ return Ort::Status{GetApi().ReadOpAttr(attr, TypeToAttrType(), &out, sizeof(out), &size)};
+}
+
+template
+struct GetValueImpl {
+ static Status GetValue(const OrtOpAttr* attr, T& out) {
+ return GetNumericValue(attr, out);
+ }
+ static Status GetValues(const OrtOpAttr* attr, std::vector& out) {
+ // Api deficiency when it comes to value arrays. It is not possible
+ // to tell if the error is due to the type mismatch or the size
+ // so we check the type first, and then ignore the status of the size check
+ constexpr auto deduced_type = TypeToAttrsType();
+ auto status = CheckAttrType(attr, deduced_type);
+ if (!status.IsOK()) return status;
+ auto size = GetDataSize(attr, deduced_type);
+ std::vector result;
+ if (size > 0) {
+ result.resize(size / sizeof(T));
+ status = Status{GetApi().ReadOpAttr(
+ attr, deduced_type, result.data(), size, &size)};
+ if (!status.IsOK()) return status;
+ }
+ out.swap(result);
+ return status;
+ }
+};
+
+// Create GetValueImpl specializations for std::string
+template <>
+struct GetValueImpl {
+ static Status GetValue(const OrtOpAttr* attr, std::string& out) {
+ // Api deficiency when it comes to value arrays. It is not possible
+ // to tell if the error is due to the type mismatch or the size
+ // so we check the type first, and then ignore the status of the size check
+ auto status = CheckAttrType(attr, OrtOpAttrType::ORT_OP_ATTR_STRING);
+ if (!status.IsOK()) return status;
+ auto size = GetDataSize(attr, OrtOpAttrType::ORT_OP_ATTR_STRING);
+ std::string result;
+ if (size > 0) {
+ result.resize(size);
+ // some compilers in use do not support std::string::data() non-const
+ auto* buffer = &result[0];
+ status = Status{GetApi().ReadOpAttr(
+ attr, OrtOpAttrType::ORT_OP_ATTR_STRING, buffer, size, &size)};
+ if (!status.IsOK()) return status;
+ }
+ out.swap(result);
+ return status;
+ }
+ static Status GetValues(const OrtOpAttr* attr, std::vector& out) {
+ auto status = CheckAttrType(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS);
+ if (!status.IsOK()) return status;
+
+ std::vector result;
+ size_t total_buffer_size = GetDataSize(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS);
+ if (total_buffer_size > 0) {
+ // Create a temporary buffer to hold the string data
+ std::vector buffer(total_buffer_size);
+ status = Status{GetApi().ReadOpAttr(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS, buffer.data(),
+ total_buffer_size, &total_buffer_size)};
+ if (!status.IsOK()) return status;
+
+ const char* data = buffer.data();
+ const char* end = data + total_buffer_size;
+ while (data < end) {
+ result.emplace_back(data);
+ data += result.back().size() + 1; // Move past the null terminator
+ }
+ }
+ out.swap(result);
+ return status;
+ }
+};
+
+template
+template
+inline Status ConstOpAttrImpl::GetValue(R& out) const {
+ return GetValueImpl::GetValue(this->p_, out);
+}
+
+template
+template
+inline Status ConstOpAttrImpl::GetValueArray(std::vector& out) const {
+ return GetValueImpl::GetValues(this->p_, out);
+}
+
+template
+inline Status ConstOpAttrImpl::GetTensorAttributeAsOrtValue(Value& out) const {
+ OrtValue* tensor_value = nullptr;
+ auto status = Status(GetApi().OpAttr_GetTensorAttributeAsOrtValue(this->p_, &tensor_value));
+ if (!status.IsOK()) return status;
+ out = Value{tensor_value};
+ return status;
+}
+
+template
+inline std::string ConstOpAttrImpl::GetName() const {
+ const char* name = nullptr;
+ ThrowOnError(GetApi().OpAttr_GetName(this->p_, &name));
+ if (name != nullptr) {
+ return name;
+ }
+ return {};
+}
+
+template
+inline OrtOpAttrType ConstOpAttrImpl::GetType() const {
+ OrtOpAttrType type;
+ ThrowOnError(GetApi().OpAttr_GetType(this->p_, &type));
+ return type;
+}
+} // namespace detail
+
inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
}
@@ -2775,115 +3000,69 @@ inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shap
}
inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) {
- const auto* attr = GetAttrHdl(attr_name);
- int64_t i = {};
- size_t out = {};
- Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out));
- return i;
+ auto attr = GetAttrHdl(attr_name);
+ int64_t value;
+ Status status = attr.GetValue(value);
+ if (!status.IsOK()) {
+ ORT_CXX_API_THROW("Getting int attribute failed: " + status.GetErrorMessage(), status.GetErrorCode());
+ }
+ return value;
}
inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) {
- const auto* attr = GetAttrHdl(attr_name);
- int64_t i = {};
- size_t out = {};
- // first call to get the bytes needed
- // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure.
- // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success).
- // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}.
- auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out);
- if (status) {
- size_t num_i = out / sizeof(int64_t);
- ShapeInferContext::Ints ints(num_i, 0);
- Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out));
- return ints;
- } else {
- if (out == 0u) {
- return {};
- }
- return {i};
+ auto attr = GetAttrHdl(attr_name);
+ ShapeInferContext::Ints result;
+ auto status = attr.GetValueArray(result);
+ if (!status.IsOK()) {
+ ORT_CXX_API_THROW("Getting ints attribute failed: " + status.GetErrorMessage(), status.GetErrorCode());
}
+ return result;
}
inline float ShapeInferContext::GetAttrFloat(const char* attr_name) {
- const auto* attr = GetAttrHdl(attr_name);
- float f = {};
- size_t out = {};
- Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out));
- return f;
+ auto attr = GetAttrHdl(attr_name);
+ float value;
+ Status status = attr.GetValue(value);
+ if (!status.IsOK()) {
+ ORT_CXX_API_THROW("Getting float attribute failed: " + status.GetErrorMessage(), status.GetErrorCode());
+ }
+ return value;
}
inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) {
- const auto* attr = GetAttrHdl(attr_name);
- float f = {};
- size_t out = {};
- // first call to get the bytes needed
- // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure.
- // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success).
- // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}.
- auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out);
- if (status) {
- size_t num_f = out / sizeof(float);
- ShapeInferContext::Floats floats(num_f, 0);
- Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out));
- return floats;
- } else {
- if (out == 0u) {
- return {};
- }
- return {f};
+ auto attr = GetAttrHdl(attr_name);
+ ShapeInferContext::Floats result;
+ auto status = attr.GetValueArray(result);
+ if (!status.IsOK()) {
+ ORT_CXX_API_THROW("Getting floats attribute failed: " + status.GetErrorMessage(), status.GetErrorCode());
}
+ return result;
}
inline std::string ShapeInferContext::GetAttrString(const char* attr_name) {
- const auto* attr = GetAttrHdl(attr_name);
- char c = {};
- size_t out = {};
- // first call to get the bytes needed
- auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out);
- if (status) {
- std::vector chars(out, '\0');
- Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out));
- return std::string{chars.data(), out};
- } else {
- return {c};
+ auto attr = GetAttrHdl(attr_name);
+ std::string value;
+ Status status = attr.GetValue(value);
+ if (!status.IsOK()) {
+ ORT_CXX_API_THROW("Getting string attribute failed: " + status.GetErrorMessage(), status.GetErrorCode());
}
+ return value;
}
inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) {
- const auto* attr = GetAttrHdl(attr_name);
- char c = {};
- size_t out = {};
- // first call to get the bytes needed
- // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure.
- // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success).
- // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}.
- auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out);
- if (status) {
- std::vector chars(out, '\0');
- Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out));
- ShapeInferContext::Strings strings;
- char* char_st = chars.data();
- char* char_ed = char_st + out;
- while (char_st < char_ed) {
- strings.emplace_back(char_st);
- while (*char_st != '\0') {
- char_st++;
- }
- char_st++;
- }
- return strings;
- } else {
- if (out == 0u) {
- return {};
- }
- return {std::string{c}};
+ auto attr = GetAttrHdl(attr_name);
+ ShapeInferContext::Strings result;
+ auto status = attr.GetValueArray(result);
+ if (!status.IsOK()) {
+ ORT_CXX_API_THROW("Getting strings attribute failed: " + status.GetErrorMessage(), status.GetErrorCode());
}
+ return result;
}
-inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const {
+inline ConstOpAttr ShapeInferContext::GetAttrHdl(const char* attr_name) const {
const OrtOpAttr* attr_hdl = {};
Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl));
- return attr_hdl;
+ return ConstOpAttr{attr_hdl};
}
namespace detail {
@@ -2897,6 +3076,136 @@ inline std::vector StringsToCharPtrs(const std::vector
}
} // namespace detail
+namespace detail {
+template
+inline size_t ConstNodeImpl::GetId() const {
+ size_t id;
+ ThrowOnError(GetApi().Node_GetId(this->p_, &id));
+ return id;
+}
+
+template
+inline std::string ConstNodeImpl::GetName() const {
+ const char* name;
+ ThrowOnError(GetApi().Node_GetName(this->p_, &name));
+ return std::string(name);
+}
+
+template
+inline std::string ConstNodeImpl::GetOperatorType() const {
+ const char* type;
+ ThrowOnError(GetApi().Node_GetOperatorType(this->p_, &type));
+ return std::string(type);
+}
+
+template
+inline std::string ConstNodeImpl::GetDomain() const {
+ const char* domain;
+ ThrowOnError(GetApi().Node_GetDomain(this->p_, &domain));
+ return std::string(domain);
+}
+
+template
+inline int ConstNodeImpl::GetSinceVersion() const {
+ int since_version;
+ ThrowOnError(GetApi().Node_GetSinceVersion(this->p_, &since_version));
+ return since_version;
+}
+
+template
+inline std::vector ConstNodeImpl::GetInputs() const {
+ static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo));
+ size_t num_vi;
+ ThrowOnError(GetApi().Node_GetNumInputs(this->p_, &num_vi));
+ std::vector result;
+ if (num_vi > 0) {
+ result.resize(num_vi);
+ ThrowOnError(GetApi().Node_GetInputs(this->p_, reinterpret_cast(result.data()), num_vi));
+ }
+ return result;
+}
+
+template
+inline std::vector ConstNodeImpl::GetOutputs() const {
+ static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo));
+ size_t num_vi;
+ ThrowOnError(GetApi().Node_GetNumOutputs(this->p_, &num_vi));
+ std::vector result;
+ if (num_vi > 0) {
+ result.resize(num_vi);
+ ThrowOnError(GetApi().Node_GetOutputs(this->p_, reinterpret_cast(result.data()), num_vi));
+ }
+ return result;
+}
+
+template
+inline std::vector ConstNodeImpl::GetImplicitInputs() const {
+ static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo));
+ size_t num_vi;
+ ThrowOnError(GetApi().Node_GetNumImplicitInputs(this->p_, &num_vi));
+ std::vector result;
+ if (num_vi > 0) {
+ result.resize(num_vi);
+ ThrowOnError(GetApi().Node_GetImplicitInputs(this->p_, reinterpret_cast(result.data()),
+ num_vi));
+ }
+ return result;
+}
+
+template
+inline std::vector ConstNodeImpl::GetAttributes() const {
+ static_assert(sizeof(const OrtOpAttr*) == sizeof(ConstOpAttr), "Must be the same size");
+ size_t num_attrs;
+ ThrowOnError(GetApi().Node_GetNumAttributes(this->p_, &num_attrs));
+ std::vector attrs;
+ if (num_attrs > 0) {
+ attrs.resize(num_attrs);
+ ThrowOnError(GetApi().Node_GetAttributes(this->p_, reinterpret_cast(attrs.data()), num_attrs));
+ }
+ return attrs;
+}
+
+template
+inline Status ConstNodeImpl::GetAttributeByName(const std::string& name, ConstOpAttr& out) const {
+ const OrtOpAttr* attr = nullptr;
+ auto status = Status(GetApi().Node_GetAttributeByName(this->p_, name.c_str(), &attr));
+ out = ConstOpAttr{attr};
+ return status;
+}
+
+template
+inline std::vector