diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 2f1532d0643ae..3189b64349898 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -603,10 +603,6 @@ if(NOT (onnx_FOUND OR ONNX_FOUND)) # building ONNX from source endif() endif() -if (onnxruntime_RUN_ONNX_TESTS) - add_definitions(-DORT_RUN_EXTERNAL_ONNX_TESTS) -endif() - if(onnxruntime_ENABLE_DLPACK) message(STATUS "dlpack is enabled.") diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 24cecf07e8e36..07e61fb210036 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -108,6 +108,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/eltwise_kernel_neon.h ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp ) set(mlas_platform_preprocess_srcs @@ -429,12 +430,16 @@ else() ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon.h ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp ) if (onnxruntime_USE_KLEIDIAI) setup_kleidiai() endif() set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + if (NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs index c348184658e7e..bde39d9c6e6cc 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs @@ -4,6 +4,7 @@ namespace Microsoft.ML.OnnxRuntime { using System; + using System.Diagnostics; using System.Runtime.InteropServices; /// @@ -22,18 +23,19 @@ public enum OrtCompileApiFlags : uint /// This class is used to set options for model compilation, and to produce a compiled model using those options. /// See https://onnxruntime.ai/docs/api/c/ for further details of various options. /// - public class OrtModelCompilationOptions : SafeHandle + public class OrtModelCompilationOptions : IDisposable { /// /// Create a new OrtModelCompilationOptions object from SessionOptions. /// + /// By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use SetGraphOptimizationLevel() + /// to enable graph optimizations. /// SessionOptions instance to read settings from. 
public OrtModelCompilationOptions(SessionOptions sessionOptions) - : base(IntPtr.Zero, true) { NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtCreateModelCompilationOptionsFromSessionOptions( - OrtEnv.Instance().Handle, sessionOptions.Handle, out handle)); + OrtEnv.Instance().Handle, sessionOptions.Handle, out _handle)); } /// @@ -41,7 +43,7 @@ public OrtModelCompilationOptions(SessionOptions sessionOptions) /// public void CompileModel() { - NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, handle)); + NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, _handle)); } @@ -53,7 +55,7 @@ public void SetInputModelPath(string path) { var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path); NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(handle, platformPath)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(_handle, platformPath)); } /// @@ -65,7 +67,7 @@ public void SetInputModelFromBuffer(byte[] buffer) { NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelFromBuffer( - handle, buffer, (UIntPtr)buffer.Length)); + _handle, buffer, (UIntPtr)buffer.Length)); } /// @@ -76,7 +78,7 @@ public void SetOutputModelPath(string path) { var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path); NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(handle, platformPath)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(_handle, platformPath)); } @@ -91,7 +93,7 @@ public void SetOutputModelExternalInitializersFile(string filePath, ulong thresh var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath); NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelExternalInitializersFile( - handle, platformPath, new UIntPtr(threshold))); + _handle, platformPath, new UIntPtr(threshold))); } // TODO: In order to use this to create an InferenceSession without copying bytes we need more infrastructure. @@ -106,7 +108,7 @@ internal void SetOutputModelBuffer(OrtAllocator allocator, { NativeApiStatus.VerifySuccess( NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelBuffer( - handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr)); + _handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr)); } /// @@ -117,7 +119,7 @@ internal void SetOutputModelBuffer(OrtAllocator allocator, public void SetEpContextEmbedMode(bool embed) { NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(handle, embed)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(_handle, embed)); } /// @@ -127,26 +129,379 @@ public void SetEpContextEmbedMode(bool embed) public void SetFlags(OrtCompileApiFlags flags) { NativeApiStatus.VerifySuccess( - NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags)); + NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(_handle, (uint)flags)); } - internal IntPtr Handle => handle; + /// + /// Sets information related to EP context binary file. The Ep uses this information to decide the + /// location and context binary file name when compiling with both the input and output models + /// stored in buffers. + /// + /// Path to the model directory. 
+ /// The name of the model. + public void SetEpContextBinaryInformation(string outputDirectory, string modelName) + { + var platformOutputDirectory = NativeOnnxValueHelper.GetPlatformSerializedString(outputDirectory); + var platformModelName = NativeOnnxValueHelper.GetPlatformSerializedString(modelName); + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextBinaryInformation( + _handle, platformOutputDirectory, platformModelName)); + } + + /// + /// Sets the graph optimization level. Defaults to ORT_DISABLE_ALL if not specified. + /// + /// The graph optimization level to set. + public void SetGraphOptimizationLevel(GraphOptimizationLevel graphOptimizationLevel) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetGraphOptimizationLevel( + _handle, graphOptimizationLevel)); + } + + /// + /// Delegate to write/save a buffer containing ONNX model bytes to a custom destination. The delegate + /// may be called repeatedly until the entire output model has been written out. Each call to the delegate + /// is expected to consume the entire buffer. + /// + /// The buffer to write out. + /// + public delegate void WriteBufferToDestinationDelegate(ReadOnlySpan buffer); + + /// + /// Sets a delegate that is called by ORT to write out the output model's serialized ONNX bytes. + /// The provided delegate may be called repeatedly until the entire output model has been written out. + /// Each call to the delegate is expected to consume/handle the entire input buffer. + /// + /// The delegate called by ORT to write out the model. + public void SetOutputModelWriteDelegate(WriteBufferToDestinationDelegate writeBufferDelegate) + { + _writeBufferToDestinationDelegateState?.Dispose(); + _writeBufferToDestinationDelegateState = + new DelegateResources( + new WriteBufferToDestinationConnector(writeBufferDelegate), + new NativeMethods.DOrtWriteBufferToDestinationDelegate( + WriteBufferToDestinationConnector.WriteBufferToDestinationDelegateWrapper)); + + IntPtr funcPtr = _writeBufferToDestinationDelegateState.GetFunctionPointerForDelegate(); + + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelWriteFunc( + _handle, + funcPtr, + _writeBufferToDestinationDelegateState.GetConnectorHandleAsPointer())); + } + + /// + /// Delegate called by ORT for every initializer when generating the compiled model. + /// The delegate allows the user to determine whether the initializer should be stored within the compiled + /// model or externally in a file. If the delegate chooses to store an initializer externally, the delegate + /// implementation is responsible for writing the initializer data to a file. + /// + /// The initializer's name. + /// The readonly OrtValue instance containing the data, type, and + /// shape of the initializer. + /// May be null. If the initializer is originally stored externally, + /// this contains the file path, file offset, and data size. Otherwise, this is null. + /// A new OrtExternalInitializerInfo indicating the new location of the initializer. + /// Returns null if the initializer should be stored within the generated compiled model. + /// The return value may be null. 
+ /// + public delegate OrtExternalInitializerInfo GetInitializerLocationDelegate( + string initializerName, + IReadOnlyOrtValue initializerValue, + IReadOnlyExternalInitializerInfo originalInitializerLocation); + + /// + /// Sets a delegate that is called by ORT for every initializer when generating the compiled model. + /// The delegate allows the user to determine whether the initializer should be stored within the compiled + /// model or externally in a file. If the delegate chooses to store an initializer externally, the delegate + /// implementation is responsible for writing the initializer data to a file. + /// + /// The delegate called by ORT for every initializer. + public void SetOutputModelGetInitializerLocationDelegate( + GetInitializerLocationDelegate getInitializerLocationDelegate) + { + _getInitializerLocationDelegateState?.Dispose(); + _getInitializerLocationDelegateState = + new DelegateResources( + new GetInitializerLocationConnector(getInitializerLocationDelegate), + new NativeMethods.DOrtGetInitializerLocationDelegate( + GetInitializerLocationConnector.GetInitializerLocationDelegateWrapper)); + + IntPtr funcPtr = _getInitializerLocationDelegateState.GetFunctionPointerForDelegate(); + + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc( + _handle, + funcPtr, + _getInitializerLocationDelegateState.GetConnectorHandleAsPointer())); + } + + #region Delegate helpers + /// + /// Class to bridge the C# and native worlds for the "write buffer to destination" delegate + /// + private class WriteBufferToDestinationConnector + { + private readonly WriteBufferToDestinationDelegate _userDelegate; + + internal WriteBufferToDestinationConnector(WriteBufferToDestinationDelegate writeBufferDelegate) + { + _userDelegate = writeBufferDelegate; + } + + public static IntPtr WriteBufferToDestinationDelegateWrapper(IntPtr /* void* */ state, + IntPtr /* const void* */ buffer, + UIntPtr /* size_t */ bufferNumBytes) + { + try + { + + WriteBufferToDestinationConnector connector = (WriteBufferToDestinationConnector) + GCHandle.FromIntPtr(state).Target; + ReadOnlySpan bufferSpan; + + unsafe + { + // NOTE: A Span can only view 2GB of data. This is fine because ORT does not write out + // chunks that large. However, if we ever need to, the solution is to just write a loop here + // that repeatedly calls the delegate with smaller chunks of data. 
+ bufferSpan = new ReadOnlySpan(buffer.ToPointer(), checked((int)bufferNumBytes)); + } + + connector._userDelegate(bufferSpan); + } + catch (Exception ex) + { + var error = $"The C# WriteBufferToDestination delegate threw an exception: {ex.Message}"; + IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail, + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error)); + return status; + } + + return IntPtr.Zero; + } + } + + /// + /// Class to bridge the C# and native worlds for the "get initializer location" delegate + /// + private class GetInitializerLocationConnector + { + private readonly GetInitializerLocationDelegate _userDelegate; + + internal GetInitializerLocationConnector(GetInitializerLocationDelegate getInitializerLocationDelegate) + { + _userDelegate = getInitializerLocationDelegate; + } + + public static IntPtr GetInitializerLocationDelegateWrapper( + IntPtr /* void* */ state, + IntPtr /* const char* */ initializerName, + IntPtr /* const OrtValue* */ initializerValue, + IntPtr /* const OrtExternalInitializerInfo* */ originalInitializerLocation, + out IntPtr /* OrtExternalInitializerInfo** */ newInitializerLocationOutput) + { + newInitializerLocationOutput = IntPtr.Zero; + + try + { + + GetInitializerLocationConnector connector = (GetInitializerLocationConnector)GCHandle. + FromIntPtr(state).Target; + string utf8InitializerName = NativeOnnxValueHelper.StringFromNativeUtf8(initializerName); + IReadOnlyOrtValue readOnlyInitializerValue = new OrtValue(initializerValue, owned: false); + IReadOnlyExternalInitializerInfo readOnlyOriginalInitializerLocation = null; + + if (originalInitializerLocation != IntPtr.Zero) + { + readOnlyOriginalInitializerLocation = new OrtExternalInitializerInfo( + originalInitializerLocation, ownsHandle: false); + } + // Call user's delegate, which may return the new location of the initializer. + OrtExternalInitializerInfo newInitializerLocation = connector._userDelegate( + utf8InitializerName, readOnlyInitializerValue, readOnlyOriginalInitializerLocation); + + if (newInitializerLocation != null) + { + // Delegate returned info about a new location for the initializer. + // Can't guarantee that the new external info returned by user's delegate is not referenced + // by other C# code. ORT expects to own the new external info, so create a copy here and + // give it to ORT. + string newFilePath = newInitializerLocation.GetFilePath(); + byte[] newFilePathBytes = NativeOnnxValueHelper.GetPlatformSerializedString(newFilePath); + + IntPtr status = NativeMethods.OrtCreateExternalInitializerInfo( + newFilePathBytes, + newInitializerLocation.GetFileOffset(), + (UIntPtr)newInitializerLocation.GetByteSize(), + out newInitializerLocationOutput); + + if (status != IntPtr.Zero) + { + return status; + } + } + else + { + // User's delegate did not return a new location for the initializer. ORT will store initializer + // within the generated compiled model. + newInitializerLocationOutput = IntPtr.Zero; + } + } + catch (Exception ex) + { + var error = $"The C# GetInitializerLocation delegate threw an exception: {ex.Message}"; + IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail, + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error)); + return status; + } + + return IntPtr.Zero; + } + } /// - /// Indicates whether the native handle is invalid. + /// Disposable class that stores resources for a delegate provided by the user. 
/// - public override bool IsInvalid => handle == IntPtr.Zero; + /// The type of the connector class + /// (e.g., WriteBufferToDestinationConnector) + /// The type of the native delegate. + private class DelegateResources : IDisposable + where Connector : class + where Delegate : class + { + public DelegateResources(Connector connector, Delegate @delegate) + { + _connector = connector; + _delegate = @delegate; + _connectorHandle = GCHandle.Alloc(_connector); + _delegateHandle = GCHandle.Alloc(_delegate); + } + internal IntPtr GetFunctionPointerForDelegate() + { + return Marshal.GetFunctionPointerForDelegate(_delegate); + } + + internal IntPtr GetConnectorHandleAsPointer() + { + return GCHandle.ToIntPtr(_connectorHandle); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (disposing) + { + // Dispose other children disposables. We have none. + } + + if (_connectorHandle.IsAllocated) + { + _connectorHandle.Free(); + _connector = null; + } + + if (_delegateHandle.IsAllocated) + { + _delegateHandle.Free(); + _delegate = null; + } + + _disposed = true; + } + + ~DelegateResources() + { + Dispose(false); + } + + private Connector _connector = null; + private Delegate _delegate = null; + private GCHandle _connectorHandle = default; + private GCHandle _delegateHandle = default; + private bool _disposed = false; + } + #endregion + + #region IDispose implementation /// - /// Release the native instance of OrtModelCompilationOptions. + /// IDispose implementation. /// - /// true - protected override bool ReleaseHandle() + public void Dispose() { - NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(handle); - handle = IntPtr.Zero; - return true; + Dispose(true); + GC.SuppressFinalize(this); } + + /// + /// IDispose implementation + /// + /// True if Dispose() has been called by the user-side code. False if + /// called by the runtime from inside the finalizer. + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + + if (disposing) + { + _writeBufferToDestinationDelegateState?.Dispose(); + _getInitializerLocationDelegateState?.Dispose(); + } + + Debug.Assert(_handle != IntPtr.Zero); + NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(_handle); + _handle = IntPtr.Zero; + _disposed = true; + } + + /// + /// Finalizer that releases the native handle if not already released by Dispose(). + /// + ~OrtModelCompilationOptions() + { + Dispose(false); + } + #endregion + + /// + /// Handle to the native OrtModelCompilationOptions object. + /// + private IntPtr _handle; + + /// + /// True if this OrtModelCompilationOptions instance has already been disposed. + /// + private bool _disposed = false; + + /// + /// Stores delegate state for the "write buffer to destination" delegate. + /// + private DelegateResources + _writeBufferToDestinationDelegateState = null; + + /// + /// Stores delegate state for the "get initializer location" delegate. 
+ /// + private DelegateResources + _getInitializerLocationDelegateState = null; } -} \ No newline at end of file +} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs index 3edc25b307a21..84020d84c9e73 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs @@ -21,6 +21,10 @@ public struct OrtCompileApi public IntPtr ModelCompilationOptions_SetEpContextEmbedMode; public IntPtr CompileModel; public IntPtr ModelCompilationOptions_SetFlags; + public IntPtr ModelCompilationOptions_SetEpContextBinaryInformation; + public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel; + public IntPtr ModelCompilationOptions_SetOutputModelWriteFunc; + public IntPtr ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc; } internal class NativeMethods @@ -101,6 +105,37 @@ public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile uint flags); public DOrtModelCompilationOptions_SetFlags OrtModelCompilationOptions_SetFlags; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextBinaryInformation( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ outputDirectory, + byte[] /* const ORTCHAR_T* */ modelName); + public DOrtModelCompilationOptions_SetEpContextBinaryInformation + OrtModelCompilationOptions_SetEpContextBinaryInformation; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetGraphOptimizationLevel( + IntPtr /* OrtModelCompilationOptions* */ options, + GraphOptimizationLevel graphOptimizationLevel); + public DOrtModelCompilationOptions_SetGraphOptimizationLevel + OrtModelCompilationOptions_SetGraphOptimizationLevel; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelWriteFunc( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* DOrtWriteBufferDelegate */ writeFunc, + IntPtr /* void* */ state); + public DOrtModelCompilationOptions_SetOutputModelWriteFunc + OrtModelCompilationOptions_SetOutputModelWriteFunc; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* DOrtHandleInitializerDataDelegate */ handleInitializerFunc, + IntPtr /* void* */ state); + public DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc + OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc; + internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) { @@ -161,6 +196,27 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi _compileApi.ModelCompilationOptions_SetFlags, typeof(DOrtModelCompilationOptions_SetFlags)); + OrtModelCompilationOptions_SetEpContextBinaryInformation = + (DOrtModelCompilationOptions_SetEpContextBinaryInformation)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetEpContextBinaryInformation, + typeof(DOrtModelCompilationOptions_SetEpContextBinaryInformation)); + + OrtModelCompilationOptions_SetGraphOptimizationLevel = + 
(DOrtModelCompilationOptions_SetGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetGraphOptimizationLevel, + typeof(DOrtModelCompilationOptions_SetGraphOptimizationLevel)); + + OrtModelCompilationOptions_SetOutputModelWriteFunc = + (DOrtModelCompilationOptions_SetOutputModelWriteFunc)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelWriteFunc, + typeof(DOrtModelCompilationOptions_SetOutputModelWriteFunc)); + + OrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc = + (DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc)Marshal. + GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + typeof(DOrtModelCompilationOptions_SetOutputModelGetInitializerLocationFunc)); + } } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 3c92400715740..53880308da261 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -450,6 +450,7 @@ public struct OrtApi public IntPtr Graph_GetModelMetadata; public IntPtr GetModelCompatibilityForEpDevices; + public IntPtr CreateExternalInitializerInfo; } internal static class NativeMethods @@ -787,9 +788,35 @@ static NativeMethods() api_.SessionOptionsSetEpSelectionPolicyDelegate, typeof(DSessionOptionsSetEpSelectionPolicyDelegate)); + OrtReleaseExternalInitializerInfo = + (DOrtReleaseExternalInitializerInfo)Marshal.GetDelegateForFunctionPointer( + api_.ReleaseExternalInitializerInfo, + typeof(DOrtReleaseExternalInitializerInfo)); + + OrtExternalInitializerInfo_GetFilePath = + (DOrtExternalInitializerInfo_GetFilePath)Marshal.GetDelegateForFunctionPointer( + api_.ExternalInitializerInfo_GetFilePath, + typeof(DOrtExternalInitializerInfo_GetFilePath)); + + OrtExternalInitializerInfo_GetFileOffset = + (DOrtExternalInitializerInfo_GetFileOffset)Marshal.GetDelegateForFunctionPointer( + api_.ExternalInitializerInfo_GetFileOffset, + typeof(DOrtExternalInitializerInfo_GetFileOffset)); + + OrtExternalInitializerInfo_GetByteSize = + (DOrtExternalInitializerInfo_GetByteSize)Marshal.GetDelegateForFunctionPointer( + api_.ExternalInitializerInfo_GetByteSize, + typeof(DOrtExternalInitializerInfo_GetByteSize)); + OrtGetModelCompatibilityForEpDevices = (DOrtGetModelCompatibilityForEpDevices)Marshal.GetDelegateForFunctionPointer( api_.GetModelCompatibilityForEpDevices, typeof(DOrtGetModelCompatibilityForEpDevices)); + + OrtCreateExternalInitializerInfo = + (DOrtCreateExternalInitializerInfo)Marshal.GetDelegateForFunctionPointer( + api_.CreateExternalInitializerInfo, + typeof(DOrtCreateExternalInitializerInfo)); + } internal class NativeLib @@ -2382,6 +2409,70 @@ out IntPtr lora_adapter public delegate ref CompileApi.OrtCompileApi DOrtGetCompileApi(); #endif public static DOrtGetCompileApi OrtGetCompileApi; + + /// + /// Delegate called by ORT to write a buffer (ONNX model bytes) to a custom destination (e.g., file or stream). + /// + /// State that was provided in when the delegate was registered. + /// The buffer to write. + /// The size of the buffer in bytes. 
+ /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtWriteBufferToDestinationDelegate( + IntPtr /* void* */ state, + IntPtr /* const void* */ buffer, + UIntPtr /* size_t */ bufferNumBytes + ); + + /// + /// Function called by ORT to allow user to specify how an initializer should be saved while compiling + /// a model, that is, either written to an external file or stored within the model. ORT calls this function + /// for every initializer. + /// + /// State that was provided when the delegate was registered. + /// The initializer's name. + /// The OrtValue containing the initializer's data, type, and shape + /// The original initializer's location in an external file, or NULL. + /// Output parameter set to a new OrtExternalInitializerInfo instance + /// indicating the location where the function implementation stored the initializer data. If the function + /// implementation sets `newExternalInfo` to NULL, ORT stores the initializer within the generated model. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetInitializerLocationDelegate( + IntPtr /* void* */ state, + IntPtr /* const char* */ initializerName, + IntPtr /* const OrtValue* */ initializerValue, + IntPtr /* const OrtExternalInitializerInfo* */ externalInfo, + out IntPtr /* OrtExternalInitializerInfo** */ newExternalInfo + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseExternalInitializerInfo(IntPtr /* OrtExternalInitializerInfo* */ info); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateExternalInitializerInfo( + byte[] /* const ORTCHAR_T* */ filePath, + long /* int64_t */ fileOffset, + UIntPtr /* size_t */ byteSize, + out IntPtr /* OrtExternalInitializerInfo** */ outInfo); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const ORTCHAR_T* */ DOrtExternalInitializerInfo_GetFilePath( + IntPtr /* const OrtExternalInitializerInfo* */ info); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate long /* int64_t */ DOrtExternalInitializerInfo_GetFileOffset( + IntPtr /* const OrtExternalInitializerInfo* */ info); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate UIntPtr /* size_t */ DOrtExternalInitializerInfo_GetByteSize( + IntPtr /* const OrtExternalInitializerInfo* */ info); + + public static DOrtReleaseExternalInitializerInfo OrtReleaseExternalInitializerInfo; + public static DOrtCreateExternalInitializerInfo OrtCreateExternalInitializerInfo; + public static DOrtExternalInitializerInfo_GetFilePath OrtExternalInitializerInfo_GetFilePath; + public static DOrtExternalInitializerInfo_GetFileOffset OrtExternalInitializerInfo_GetFileOffset; + public static DOrtExternalInitializerInfo_GetByteSize OrtExternalInitializerInfo_GetByteSize; #endregion #region Auto EP API related diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs index fc14be00ee47b..4611428ea12ef 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs @@ -150,6 +150,45 @@ internal static byte[] GetPlatformSerializedString(string str) else return StringToZeroTerminatedUtf8(str); } + + /// + /// Converts a null-terminated path string that is pointed to by the given IntPtr 
handle into + /// a C# UTF-16 string. + /// + /// A path string on Windows is utf-16, but utf-8 on other operating systems. + /// + /// + internal static string StringFromNativePathString(IntPtr strPtr) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + if (strPtr == IntPtr.Zero) + { + return string.Empty; + } + + // Get length of utf16 string by checking for two 0 bytes in a row. + int length = 0; + while (Marshal.ReadInt16(strPtr, length * 2) != 0) + { + length += 1; + } + + if (length == 0) + { + return string.Empty; + } + + unsafe + { + return System.Text.Encoding.Unicode.GetString((byte*)strPtr, length * 2); + } + } + else + { + return StringFromNativeUtf8(strPtr); + } + } } // Guards an array of disposable objects on stack and disposes them in reverse order diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs new file mode 100644 index 0000000000000..aca16e939ce21 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtExternalInitializerInfo.shared.cs @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Diagnostics; + using System.Runtime.InteropServices; + + /// + /// Class to that stores information about the file location where an "external" initializer is stored. + /// + /// + public class OrtExternalInitializerInfo : SafeHandle, IReadOnlyExternalInitializerInfo + { + // Set to false when constructed with an externally managed constant handle owned by ORT. + private readonly bool _ownsHandle = true; + + /// + /// Create a new OrtExternalInitializerInfo instance. + /// + /// The path to the file that stores the initializer data. + /// The byte offset in the file where the data is stored. + /// The size of the data (in bytes) within the file. + public OrtExternalInitializerInfo(string filePath, long fileOffset, long byteSize) + : base(IntPtr.Zero, ownsHandle: true) + { + var platformFilePath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath); + NativeApiStatus.VerifySuccess( + NativeMethods.OrtCreateExternalInitializerInfo(platformFilePath, fileOffset, (UIntPtr)byteSize, out handle)); + _ownsHandle = true; + } + + /// + /// Create a new OrtExternalInitializerInfo instance from an existing native OrtExternalInitializerInfo handle. + /// + /// Native OrtExternalInitializerInfo handle. + /// True if the OrtExternalInitializerInfo instance owns the native handle. + /// Defaults to false. + internal OrtExternalInitializerInfo(IntPtr constHandle, bool ownsHandle = false) + : base(IntPtr.Zero, ownsHandle) + { + Debug.Assert(constHandle != IntPtr.Zero); + SetHandle(constHandle); + _ownsHandle = ownsHandle; + } + + /// + /// Get the file path to the file that store's the initializer's data. + /// + /// + /// The path is relative to the filesystem directory where the ONNX model was stored. + /// + /// The file path. + public string GetFilePath() + { + IntPtr filePathPtr = NativeMethods.OrtExternalInitializerInfo_GetFilePath(handle); + if (filePathPtr == IntPtr.Zero) + { + return string.Empty; + } + + return NativeOnnxValueHelper.StringFromNativePathString(filePathPtr); + } + + /// + /// Get the byte offset within the file where the initializer's data is stored. + /// + /// The file offset location. 
+ public long GetFileOffset() + { + return NativeMethods.OrtExternalInitializerInfo_GetFileOffset(handle); + } + + /// + /// Get the size in bytes of the initializer's data within the file. + /// + /// The size in bytes of the initializer data. + public long GetByteSize() + { + UIntPtr byteSize = NativeMethods.OrtExternalInitializerInfo_GetByteSize(handle); + return checked((long)byteSize); + } + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + /// + /// Release the native instance of OrtExternalInitializerInfo if we own it. + /// + /// true on success and false on error. + protected override bool ReleaseHandle() + { + if (!_ownsHandle) + { + // Return false to indicate an error. + // ReleaseHandle() should not be called on a const handle that this class does not own. + return false; + } + + NativeMethods.OrtReleaseExternalInitializerInfo(handle); + handle = IntPtr.Zero; + return true; + } + } + + /// + /// Interface for all readonly methods implemented by OrtExternalInitializerInfo. + /// + public interface IReadOnlyExternalInitializerInfo + { + /// + /// Get the file path to the file that store's the initializer's data. + /// + /// + /// The path is relative to the filesystem directory where the ONNX model was stored. + /// + /// The file path. + string GetFilePath(); + + /// + /// Get the byte offset within the file where the initializer's data is stored. + /// + /// The file offset location. + long GetFileOffset(); + + /// + /// Get the size in bytes of the initializer's data within the file. + /// + /// The size in bytes of the initializer data. + long GetByteSize(); + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 01ee3aa5ae753..d848c63450ec1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -33,6 +33,147 @@ public enum OnnxValueType ONNX_TYPE_OPTIONAL = 6, // It's an optional type that designates anything above (except UNKNOWN) } + /// + /// Interface for all readonly methods implemented by OrtValue. + /// + public interface IReadOnlyOrtValue + { + /// + /// Get the ONNX value type for the OrtValue (e.g., OnnxValueType.ONNX_TYPE_TENSOR). + /// + /// OnnxValueType + OnnxValueType OnnxType { get; } + + /// + /// Returns true if OrtValue contains a tensor + /// + /// true if tensor + bool IsTensor { get; } + + /// + /// Returns true if OrtValue contains a sparse tensor + /// + /// true if sparse tensor + bool IsSparseTensor { get; } + + /// + /// Returns type information about the contained OnnxValue. + /// + /// a disposable instance of OrtTypeInfo + OrtTypeInfo GetTypeInfo(); + + /// + /// Obtains Tensor And Type Information from the OrtValue iff it contains a tensor. + /// Valid only for OrtValues that contain a tensor. + /// + /// A disposable instance of OrtTensorTypeAndShapeInfo + OrtTensorTypeAndShapeInfo GetTensorTypeAndShape(); + + /// + /// Returns the size of the tensor data in bytes. + /// + /// size of the tensor data in bytes + long GetTensorSizeInBytes(); + + /// + /// Returns OrtMemoryInfo iff this OrtValue contains a tensor or a sparse tensor. + /// + /// OrtMemoryInfo that describes the underlying memory allocation + /// + OrtMemoryInfo GetTensorMemoryInfo(); + + /// + /// Returns a ReadOnlySpan over tensor native buffer that + /// provides a read-only view. 
+ /// + /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU. + /// To get memory descriptor use GetTensorMemoryInfo(). + /// + /// OrtValue must contain a non-string tensor. + /// The span is valid as long as the OrtValue instance is alive (not disposed). + /// + /// + /// ReadOnlySpan + /// + ReadOnlySpan GetTensorDataAsSpan() where T : unmanaged; + +#if NET8_0_OR_GREATER + /// + /// Returns a ReadOnlyTensorSpan over tensor native buffer that + /// provides a read-only view. + /// + /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU. + /// To get memory descriptor use GetTensorMemoryInfo(). + /// + /// OrtValue must contain a non-string tensor. + /// The span is valid as long as the OrtValue instance is alive (not disposed). + /// + /// + /// ReadOnlySpan + /// + [Experimental("SYSLIB5001")] + SystemNumericsTensors.ReadOnlyTensorSpan GetTensorDataAsTensorSpan() where T : unmanaged; +#endif + + /// + /// Valid for composite ML types like map, sequence. + /// Returns 2 for map (keys, values) and N for sequence, where N is the number of elements + /// in the sequence. + /// + /// Element count + int GetValueCount(); + + /// + /// For non tensors return OrtValue element at the specified index. + /// For maps only indices 0 and 1 are valid. For sequences, [0..N) are valid. + /// See GetValueCount() to determine the valid range. + /// + /// + /// allocator to use + /// OrtValue disposable instance that points to the corresponding element of the composite type + OrtValue GetValue(int index, OrtAllocator allocator); + + /// + /// Fetch string tensor element buffer pointer at the specified index, + /// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance. + /// + /// Obtain TensorTypeAndShape to get shape and element count. + /// + /// flat string tensor element index + /// ReadOnlyMemory{char} backed by a managed char[]. Its lifespan is not + /// tied to the native buffer of OrtValue. + ReadOnlyMemory GetStringElementAsMemory(int index); + + /// + /// Fetch string tensor element buffer pointer at the specified index, + /// copy/convert UTF-8 into a UTF-16 string and return it. + /// + /// Obtain TensorTypeAndShape to get shape and element count. + /// + /// flat string tensor element index + /// UTF-16 string instance + string GetStringElement(int index); + + /// + /// Get a span over the native memory of the string tensor element. + /// The span is valid as long as the OrtValue is valid. + /// + /// This is useful if you want to perform your own UTF-8 decoding or + /// you do not care about decoding. + /// Obtain TensorTypeAndShape to get shape and element count. + /// + /// flat element index + /// ReadOnlySpan over UTF-8 bytes of the string tensor element + ReadOnlySpan GetStringElementAsSpan(int index); + + /// + /// Convenience method to obtain all string tensor elements as a string array. + /// + /// string[] + /// + string[] GetStringTensorAsArray(); + } + /// /// Represents a disposable OrtValue. /// This class exposes a native instance of OrtValue. @@ -44,7 +185,7 @@ public enum OnnxValueType /// disposed properly, the pinned memory will continue to be pinned and interfere /// with GC operation. /// - public class OrtValue : IOrtValueOwner, IDisposable + public class OrtValue : IOrtValueOwner, IDisposable, IReadOnlyOrtValue { // OrtValues that are members of Sequences or Maps that map. They potentially map managed memory and we need to keep them around. 
// this exists only when we deal with compose ML types. @@ -52,11 +193,20 @@ public class OrtValue : IOrtValueOwner, IDisposable private IntPtr _handle; private MemoryHandle? _memHandle; // Present when the OrtValue is created on top of managed memory private bool _disposed; + private bool _owned = true; - internal OrtValue(IntPtr handle) + /// + /// Constructs OrtValue from a native handle. If `owned` is true, the OrtValue instance takes + /// ownership of the native handle and disposes it when the OrtValue instance is disposed. + /// + /// The native OrtValue handle. + /// True if this class instance owns the handle. If false, the handle + /// will not be released. Defaults to true. + internal OrtValue(IntPtr handle, bool owned = true) { _handle = handle; InitOnnxType(); + _owned = owned; } /// @@ -1464,7 +1614,10 @@ protected virtual void Dispose(bool disposing) } Debug.Assert(_handle != IntPtr.Zero); - NativeMethods.OrtReleaseValue(_handle); + if (_owned) + { + NativeMethods.OrtReleaseValue(_handle); + } _handle = IntPtr.Zero; _disposed = true; } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs index bf576b54d8b45..fe2cab57658c8 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs @@ -21,102 +21,249 @@ public class CompileApiTests [Fact] public void BasicUsage() { - var so = new SessionOptions(); - using (var compileOptions = new OrtModelCompilationOptions(so)) + using (var sessionOptions = new SessionOptions()) { - // mainly checking these don't throw which ensures all the plumbing for the binding works. - compileOptions.SetInputModelPath("model.onnx"); - compileOptions.SetOutputModelPath("compiled_model.onnx"); + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + // mainly checking these don't throw which ensures all the plumbing for the binding works. + compileOptions.SetInputModelPath("model.onnx"); + compileOptions.SetOutputModelPath("compiled_model.onnx"); - compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512); - compileOptions.SetEpContextEmbedMode(true); + compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512); + compileOptions.SetEpContextEmbedMode(true); + compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC); - } + } - // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer - using (var compileOptions = new OrtModelCompilationOptions(so)) - { - var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); - compileOptions.SetInputModelFromBuffer(model); + // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + compileOptions.SetInputModelFromBuffer(model); - // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile. - // Due to that we need to allocate an IntPtr and UIntPtr here. 
- IntPtr bytePtr = new IntPtr(); - UIntPtr bytesSize = new UIntPtr(); - var allocator = OrtAllocator.DefaultInstance; - compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize); + // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile. + // Due to that we need to allocate an IntPtr and UIntPtr here. + IntPtr bytePtr = new IntPtr(); + UIntPtr bytesSize = new UIntPtr(); + var allocator = OrtAllocator.DefaultInstance; + compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize); + compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx"); - compileOptions.CompileModel(); + compileOptions.CompileModel(); - Assert.NotEqual(IntPtr.Zero, bytePtr); - Assert.NotEqual(UIntPtr.Zero, bytesSize); + Assert.NotEqual(IntPtr.Zero, bytePtr); + Assert.NotEqual(UIntPtr.Zero, bytesSize); - byte[] compiledBytes = new byte[bytesSize.ToUInt64()]; - Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32()); + byte[] compiledBytes = new byte[bytesSize.ToUInt64()]; + Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32()); - // Check the compiled model is valid - using (var session = new InferenceSession(compiledBytes, so)) - { - Assert.NotNull(session); + // Check the compiled model is valid + using (var session = new InferenceSession(compiledBytes, sessionOptions)) + { + Assert.NotNull(session); + } + + allocator.FreeMemory(bytePtr); } - allocator.FreeMemory(bytePtr); - } + // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate + // any compiled EPContext nodes, so expect an ORT_FAIL error. + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var output_model_file = "should_not_generate.onnx"; + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(output_model_file); + compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED); - // Test using OrtCompileApiFlags.ERROR_NO_NODES_COMPILED. A model compiled with CPU EP will not generate - // any compiled EPContext nodes, so expect an ORT_FAIL error. - using (var compileOptions = new OrtModelCompilationOptions(so)) - { - var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); - var output_model_file = "should_not_generate.onnx"; - compileOptions.SetInputModelFromBuffer(model); - compileOptions.SetOutputModelPath(output_model_file); - compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_NO_NODES_COMPILED); + // compile should fail + try + { + compileOptions.CompileModel(); + Assert.Fail("CompileModel() should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + Assert.Contains("Unable to compile any nodes", ex.Message); + } - // compile should fail + Assert.False(File.Exists(output_model_file)); // Output file should not be generated. + } + + // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS. + var outputModelFile = "squeezenet_ctx.onnx"; try { - compileOptions.CompileModel(); - Assert.Fail("CompileModel() should have thrown an exception"); + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // Compile and generate an output model. 
+ compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(outputModelFile); + compileOptions.CompileModel(); + Assert.True(File.Exists(outputModelFile)); + + // Try to compile again with flag that prevents replacing an existing file. + // Expect failure. + compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS); + + // compile should fail + try + { + compileOptions.CompileModel(); + Assert.Fail("CompileModel() should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + Assert.Contains("exists already", ex.Message); + } + } } - catch (OnnxRuntimeException ex) + finally { - Assert.Contains("Unable to compile any nodes", ex.Message); + if (File.Exists(outputModelFile)) + { + // This file is created by ORT, so we delete it manually in finally block. + File.Delete(outputModelFile); + } } - - Assert.False(File.Exists(output_model_file)); // Output file should not be generated. } + } + + [Fact] + public void WriteOutModelWithDelegate() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var outputModelFilePath = "squeezenet_write_delegate_ctx.onnx"; - // Test using OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS. - using (var compileOptions = new OrtModelCompilationOptions(so)) + using (FileStream fs = new FileStream(outputModelFilePath, FileMode.Create, FileAccess.Write, FileShare.None, + 4096, FileOptions.DeleteOnClose)) + using (var sessionOptions = new SessionOptions()) + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) { - var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); - var output_model_file = "squeezenet_ctx.onnx"; + void BasicWriteBufferDelegate(ReadOnlySpan buffer) + { + Assert.True(buffer.Length > 0); + fs.Write(buffer.ToArray(), 0, buffer.Length); // Write it out to a file + } // Compile and generate an output model. compileOptions.SetInputModelFromBuffer(model); - compileOptions.SetOutputModelPath(output_model_file); + compileOptions.SetOutputModelWriteDelegate(BasicWriteBufferDelegate); compileOptions.CompileModel(); - Assert.True(File.Exists(output_model_file)); + Assert.True(File.Exists(outputModelFilePath)); + } + } - // Try to compile again with flag that prevents replacing an existing file. - // Expect failure. - compileOptions.SetFlags(OrtCompileApiFlags.ERROR_IF_OUTPUT_FILE_EXISTS); + [Fact] + public void BasicGetInitializerLocationDelegate() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + var outputModelFilePath = "squeezenet_handle_initializer_delegate_ctx.onnx"; + var initializersFilePath = "squeezenet_handle_initializer_delegate_ctx.bin"; - // compile should fail - try + try + { + using (FileStream fs = new FileStream(initializersFilePath, FileMode.Create, FileAccess.Write, + FileShare.None, 4096, FileOptions.DeleteOnClose)) + using (var sessionOptions = new SessionOptions()) + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) { + // Custom delegate that stores large initializers in a new file. + OrtExternalInitializerInfo BasicHandleInitializer( + string initializerName, IReadOnlyOrtValue initializerValue, + IReadOnlyExternalInitializerInfo originalInitializerLocation) + { + Assert.True(initializerName.Length > 0); + + var byteSize = initializerValue.GetTensorSizeInBytes(); + if (byteSize <= 64) + { + // Keep small initializers stored within model. 
+ return null; + } + + long byteOffset = fs.Position; + ReadOnlySpan dataSpan = initializerValue.GetTensorDataAsSpan(); + fs.Write(dataSpan.ToArray(), 0, dataSpan.Length); // Write it out to a file + + // Return the data's new location. + return new OrtExternalInitializerInfo(initializersFilePath, byteOffset, byteSize); + } + + // Compile and generate an output model. + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(outputModelFilePath); + compileOptions.SetOutputModelGetInitializerLocationDelegate(BasicHandleInitializer); compileOptions.CompileModel(); - Assert.Fail("CompileModel() should have thrown an exception"); + Assert.True(File.Exists(outputModelFilePath)); } - catch (OnnxRuntimeException ex) + } + finally + { + if (File.Exists(outputModelFilePath)) { - Assert.Contains("exists already", ex.Message); + // This file is created by ORT, so we delete it manually in finally block. + File.Delete(outputModelFilePath); } + } + } + + [Fact] + public void GetInitializerLocationDelegateThatReusesExternalInitializers() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("conv_qdq_external_ini.onnx"); + var outputModelFilePath = "conv_qdq_external_ini.reuse.ctx.onnx"; + bool reusedExternalInitializers = false; + + try + { + using (var sessionOptions = new SessionOptions()) + using (var compileOptions = new OrtModelCompilationOptions(sessionOptions)) + { + // Custom delegate that reuses the original external initializer file. + OrtExternalInitializerInfo ReuseExternalInitializers( + string initializerName, IReadOnlyOrtValue initializerValue, + IReadOnlyExternalInitializerInfo originalInitializerLocation) + { + Assert.True(initializerName.Length > 0); + + if (originalInitializerLocation != null) + { + reusedExternalInitializers = true; // For test assertion only + string originalFilePath = originalInitializerLocation.GetFilePath(); + long originalFileOffset = originalInitializerLocation.GetFileOffset(); + long originalByteSize = originalInitializerLocation.GetByteSize(); + + Assert.True(originalFilePath.Length > 0); + Assert.True(originalFileOffset >= 0); + Assert.True(originalByteSize > 0); - if (File.Exists(output_model_file)) + // This initializer comes from an external file. Reuse it for compiled model. + return new OrtExternalInitializerInfo(originalFilePath, originalFileOffset, originalByteSize); + } + + // Otherwise, embed initializers that were not originally external. + return null; + } + + // Compile and generate an output model. + compileOptions.SetInputModelFromBuffer(model); + compileOptions.SetOutputModelPath(outputModelFilePath); + compileOptions.SetOutputModelGetInitializerLocationDelegate(ReuseExternalInitializers); + compileOptions.CompileModel(); + + Assert.True(File.Exists(outputModelFilePath)); + Assert.True(reusedExternalInitializers); + } + } + finally + { + if (File.Exists(outputModelFilePath)) { - File.Delete(output_model_file); + // This file is created by ORT, so we delete it manually in finally block. 
+ File.Delete(outputModelFilePath); } } } diff --git a/csharp/testdata/conv_qdq_external_ini.bin b/csharp/testdata/conv_qdq_external_ini.bin new file mode 100644 index 0000000000000..89eea0dba1fa4 Binary files /dev/null and b/csharp/testdata/conv_qdq_external_ini.bin differ diff --git a/csharp/testdata/conv_qdq_external_ini.onnx b/csharp/testdata/conv_qdq_external_ini.onnx new file mode 100644 index 0000000000000..c53e1f3ad4d9b Binary files /dev/null and b/csharp/testdata/conv_qdq_external_ini.onnx differ diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f3dcde1abe37a..cbfc38068ac2a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3079,6 +3079,17 @@ This version of the operator has been available since version 1 of the 'com.micr Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral). + + The SwiGLU (Swish-Gated Linear Unit) activation function is like: + g = xW + b + l = xV + c + G = clamp(g, max=limit) + L = clamp(l, min=-limit, max=limit) + swiglu = G * sigmoid(alpha * G) * (L + beta) + where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters. + When swiglu_fusion=0, two GEMMs are not fused, and they are FC1 and FC3 in the inputs. + When swiglu_fusion=1, two GEMMs are fused so that g and l are computed in a single GEMM (FC1), and g and l are interleaved on each row of size 2 * inter_size. + When swiglu_fusion=2, two GEMMs are fused, and g and l are concatenated on each row. #### Version @@ -3088,12 +3099,20 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+<dt><tt>activation_alpha</tt> : float</dt>
+<dd>Alpha parameter used in activation function.</dd>
+<dt><tt>activation_beta</tt> : float</dt>
+<dd>Beta parameter used in activation function.</dd>
 <dt><tt>activation_type</tt> : string</dt>
-<dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
+<dd>Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu</dd>
 <dt><tt>k</tt> : int</dt>
 <dd>Number of top experts to select from expert pool</dd>
 <dt><tt>normalize_routing_weights</tt> : int</dt>
 <dd>Whether to normalize routing weights</dd>
+<dt><tt>swiglu_fusion</tt> : int</dt>
+<dd>0: not fused, 1: fused and interleaved. 2: fused and not interleaved.</dd>
+<dt><tt>swiglu_limit</tt> : float</dt>
+<dd>The limit used to clamp in SwiGLU. No clamp when limit is not provided.</dd>
 <dt><tt>use_sparse_mixer</tt> : int</dt>
 <dd>Whether to use sparse mixer</dd>
@@ -3106,15 +3125,15 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>router_probs</tt> : T</dt>
 <dd>2D input tensor with shape (num_rows, num_experts)</dd>
 <dt><tt>fc1_experts_weights</tt> : T</dt>
-<dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
+<dd>3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu</dd>
 <dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
-<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
+<dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
 <dt><tt>fc2_experts_weights</tt> : T</dt>
-<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
+<dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
 <dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
 <dt><tt>fc3_experts_weights</tt> (optional) : T</dt>
-<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size)</dd>
+<dd>3D optional input tensor with shape (num_experts, inter_size, hidden_size)</dd>
 <dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
@@ -3129,8 +3148,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Type Constraints
-<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
-<dd>Constrain input and output types to float or float16 tensors.</dd>
+<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
+<dd>Constrain input and output types to float tensors.</dd>
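To make the SwiGLU description and the swiglu_fusion layouts above concrete, here is a minimal C# sketch (illustration only, not part of this change). It applies the formula from the operator description to one row of the fused FC1 output; the helper name, the use of plain float arrays, and the g-before-l ordering within each interleaved pair are assumptions.

```csharp
using System;

internal static class SwigluReference
{
    // Applies the SwiGLU activation described above to one row of the fused FC1 output.
    // fc1Row has 2 * interSize elements. With swiglu_fusion=1 the gate value g and the
    // linear value l are assumed to be interleaved (g0, l0, g1, l1, ...); with
    // swiglu_fusion=2 they are concatenated (g0..g_{n-1}, l0..l_{n-1}).
    internal static float[] Swiglu(float[] fc1Row, int interSize, bool interleaved,
                                   float alpha, float beta, float limit)
    {
        var result = new float[interSize];
        for (int i = 0; i < interSize; i++)
        {
            float g = interleaved ? fc1Row[2 * i] : fc1Row[i];
            float l = interleaved ? fc1Row[2 * i + 1] : fc1Row[interSize + i];

            // G = clamp(g, max=limit), L = clamp(l, min=-limit, max=limit)
            float G = Math.Min(g, limit);
            float L = Math.Max(Math.Min(l, limit), -limit);

            // swiglu = G * sigmoid(alpha * G) * (L + beta)
            float sigmoid = 1.0f / (1.0f + (float)Math.Exp(-alpha * G));
            result[i] = G * sigmoid * (L + beta);
        }

        return result;
    }
}
```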
@@ -4522,14 +4541,22 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
activation_alpha : float
+
Alpha parameter used in activation function.
+
activation_beta : float
+
Beta parameter used in activation function.
activation_type : string
-
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
expert_weight_bits : int
Number of bits used in quantized weights. Default is 4 bits
k : int
Number of top experts to select from expert pool
normalize_routing_weights : int
Whether to normalize routing weights
+
swiglu_fusion : int
+
0: not fused, 1: fused and interleaved, 2: fused and not interleaved.
+
swiglu_limit : float
+
The limit used to clamp inputs in SwiGLU. The limit is treated as infinite when it is not provided.
use_sparse_mixer : int
Whether to use sparse mixer
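
The k and normalize_routing_weights attributes describe how router_probs are turned into per-expert weights. The sketch below shows one plausible top-k selection with optional renormalization; the kernel's actual softmax / sparse-mixer handling is not spelled out here, so treat this purely as an illustration.

```cpp
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

// Selects the top-k experts for one row of router_probs and, when
// normalize_routing_weights is set, rescales the selected weights to sum to 1.
std::vector<std::pair<int, float>> TopKExperts(const std::vector<float>& row_probs,
                                               int k, bool normalize) {
  k = std::min<int>(k, static_cast<int>(row_probs.size()));
  std::vector<int> idx(row_probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return row_probs[a] > row_probs[b]; });

  std::vector<std::pair<int, float>> selected;  // (expert index, routing weight)
  float sum = 0.0f;
  for (int i = 0; i < k; ++i) {
    selected.emplace_back(idx[i], row_probs[idx[i]]);
    sum += row_probs[idx[i]];
  }
  if (normalize && sum > 0.0f) {
    for (auto& p : selected) p.second /= sum;
  }
  return selected;
}
```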
@@ -4542,20 +4569,20 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T1
-
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
-
fc1_scales : T
-
2D input tensor with shape (num_experts, inter_size)
+
3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, inter_size, hidden_size / 2) for 4 bits. For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.
+
fc1_scales : T2
+
2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
-
2D optional input tensor with shape (num_experts, inter_size)
+
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T1
-
3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
-
fc2_scales : T
+
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2) for 4 bits
+
fc2_scales : T2
2D input tensor with shape (num_experts, hidden_size)
fc2_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, hidden_size)
fc3_experts_weights (optional) : T1
-
3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
-
fc3_scales (optional) : T
+
3D optional input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
+
fc3_scales (optional) : T2
2D optional input tensor with shape (num_experts, inter_size)
fc3_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size)
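
The uint8 fc*_experts_weights above halve their last dimension when expert_weight_bits is 4, because two 4-bit values are packed into each byte. A minimal dequantization sketch follows; the nibble order and the implicit zero point of 8 are assumptions made for illustration, not a description of the ORT kernel.

```cpp
#include <cstdint>
#include <vector>

// Expands one packed 4-bit weight row back to floats. `packed` holds row_len / 2
// bytes (row_len assumed even), which is why the documented last dimension is
// halved (e.g. hidden_size / 2). Low nibble first and a zero point of 8 are
// illustrative assumptions.
std::vector<float> DequantizeInt4Row(const std::vector<uint8_t>& packed,
                                     float scale, int row_len) {
  std::vector<float> out(row_len);
  for (int i = 0; i < row_len; ++i) {
    const uint8_t byte = packed[i / 2];
    const uint8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    out[i] = scale * (static_cast<int>(nibble) - 8);  // symmetric dequantization
  }
  return out;
}
```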
@@ -4571,10 +4598,12 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float16)
-
Constrain input and output types to float or float16 tensors.
+
T : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
T1 : tensor(uint8)
Constrain weights type to uint8 tensors.
+
T2 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain scales type to float tensors.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 3b70e5da8b3e4..660c63d056335 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -562,6 +562,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearWhere|*in* condition:**B**
*in* X:**T**
*in* x_scale:**TF**
*in* x_zero_point:**T**
*in* Y:**T**
*in* y_scale:**TF**
*in* y_zero_point:**T**
*in* z_scale:**TF**
*in* z_zero_point:**T**
*out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)| +|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(float), tensor(float16)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| @@ -937,6 +938,7 @@ Do not modify directly.* |FusedConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*in* Z:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|GatherBlockQuantized|*in* data:**T1**
*in* indices:**Tind**
*in* scales:**T2**
*in* zero_points:**T1**
*out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)
**T2** = tensor(bfloat16), tensor(float), tensor(float16)
**Tind** = tensor(int32), tensor(int64)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| |GemmaRotaryEmbedding|*in* emb:**U**
*in* q:**T**
*in* q_rot:**T**
*in* k:**T**
*in* k_rot:**T**
*out* output1:**T**
*out* output2:**T**|1+|**T** = tensor(float16)
**U** = tensor(float)| @@ -949,7 +951,7 @@ Do not modify directly.* |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)| -|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -957,7 +959,7 @@ Do not modify directly.* |PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |PagedAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* key_cache:**T**
*in* value_cache:**T**
*in* cumulative_sequence_length:**S**
*in* past_seqlens:**S**
*in* block_table:**S**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* key_cache_out:**T**
*out* value_cache_out:**T**|1+|**S** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| -|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)| +|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(bfloat16), tensor(float16)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 375f0a4dc8dd2..e59a803d97629 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -305,6 +305,24 @@ using BuildKernelCreateInfoFn = KernelCreateInfo (*)(); static_cast([](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \ } +#define ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name) \ + provider##_##name##_##domain##_ver##ver##_##type1##_##type2##_##type3 + +#define ONNX_OPERATOR_THREE_TYPED_KERNEL_EX(name, domain, ver, type1, type2, type3, provider, builder, ...) \ + class ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name) \ + .SetDomain(domain) \ + .SinceVersion(ver) \ + .Provider(provider) \ + .Build(), \ + static_cast([](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \ + } + #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \ provider##_##name##_##domain##_ver##startver##_##endver##_##type diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 866892979b749..9a0708d72b4f8 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1247,6 +1247,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const std::filesystem::path& model_file_path, const ModelSavingOptions& model_saving_options) const; + /// + /// Serialize the Graph to a onnx::GraphProto. Caller provides a function that determines where each initializer + /// is stored (i.e., either in an external file or within the model). + /// + /// Function called for every initializer. + /// Opaque user state passed to the handle_initializer_func. + /// Output parameter set to the serialized onnx::GraphProto. + /// A status indicating success or an error. + common::Status ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& graph_proto) const; + /** Gets the ISchemaRegistry instances being used with this Graph. 
*/ IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const; @@ -1664,6 +1676,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::ostream& external_stream, int64_t& external_offset) const; + Status ToGraphProtoWithCustomInitializerHandlingImpl(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& output_graph_proto) const; #endif Version IrVersion() const noexcept { diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h index a32f465e44adf..026fc3b2dc0a0 100644 --- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h +++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h @@ -34,6 +34,7 @@ constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes"; constexpr const char* kCudaGraphEnable = "enable_cuda_graph"; constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable"; constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer"; +constexpr const char* kRuntimeCacheFile = "nv_runtime_cache_path"; } // namespace provider_option_names namespace run_option_names { diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 28ce4439fdc7e..e2b2aff2011fe 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -203,415 +203,331 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, #define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ do { \ - OrtStatus* _status = (fn); \ - if (_status != nullptr) { \ - return Ort::Status{_status}; \ + Ort::Status _status{(fn)}; \ + if (!_status.IsOK()) { \ + return _status; \ } \ } while (0) #define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ - do { \ - Ort::Status _status = (fn); \ - if (!_status.IsOK()) { \ - return _status; \ - } \ - } while (0) + ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) -#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ - do { \ - if ((cond)) { \ - return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ - } \ +#define ORT_EP_UTILS_C_RETURN_IF(cond, msg) \ + do { \ + if ((cond)) { \ + return Ort::Status{msg, ORT_FAIL}; \ + } \ } while (0) namespace OrtEpUtils { -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, +static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto); +static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr ort_attr, onnx::AttributeProto& attr_proto); -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, +Ort::Status OrtGraphToProto(const OrtGraph& graph, onnx::GraphProto& graph_proto, HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // - // Set GraphProto metadata - // - const char* graph_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); - 
graph_proto.set_name(graph_name); - graph_proto.set_doc_string("Serialized from OrtGraph"); - - // - // Set GraphProto inputs and outputs - // - size_t num_graph_inputs = 0; - size_t num_graph_outputs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); - - std::vector graph_inputs(num_graph_inputs); - std::vector graph_outputs(num_graph_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); - - for (const OrtValueInfo* ort_value_info : graph_inputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - for (const OrtValueInfo* ort_value_info : graph_outputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - // - // Set GraphProto nodes, value_infos, and initializers. - // - - // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. - // A std::map maintains its elements in a stable ordering. - std::map value_infos; // For GraphProto.value_info - std::map initializer_value_infos; // For GraphProto.initializer - - // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. - // Optionally returns the OrtValueInfo name to the caller. - auto collect_value_info = [&ort_api, &value_infos, - &initializer_value_infos](const OrtValueInfo& ort_value_info, - /*out*/ const char** value_name_out = nullptr) -> Ort::Status { - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - - if (value_name_out != nullptr) { - *value_name_out = value_name; + try { + Ort::ConstGraph ort_graph{&graph}; + // + // Set GraphProto metadata + // + auto graph_name = ort_graph.GetName(); + graph_proto.set_name(graph_name); + graph_proto.set_doc_string("Serialized from OrtGraph"); + + // + // Set GraphProto inputs and outputs + // + std::vector graph_inputs = ort_graph.GetInputs(); + std::vector graph_outputs = ort_graph.GetOutputs(); + + for (const auto& ort_value_info : graph_inputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto)); } - if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { - return Ort::Status{nullptr}; // Already processed this OrtValueInfo. 
+ for (const auto& ort_value_info : graph_outputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto)); } - bool is_required_graph_input = false; - bool is_optional_graph_input = false; - bool is_graph_output = false; - bool is_constant_initializer = false; - bool is_from_outer_scope = false; - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); - - // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. - // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. - // For values defined in an outer scope, just add the value info but not the initializer. - if (is_from_outer_scope) { - value_infos.emplace(value_name, &ort_value_info); - } else if (is_optional_graph_input) { - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (is_constant_initializer) { - value_infos.emplace(value_name, &ort_value_info); - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (!is_required_graph_input && !is_graph_output) { - value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. - } + // + // Set GraphProto nodes, value_infos, and initializers. + // + + // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. + // A std::map maintains its elements in a stable ordering. + std::map value_infos; // For GraphProto.value_info + std::map initializer_value_infos; // For GraphProto.initializer + + // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. + // Optionally returns the OrtValueInfo name to the caller. + auto collect_value_info = [&value_infos, + &initializer_value_infos](Ort::ConstValueInfo ort_value_info, + /*out*/ std::optional& value_name_out) { + auto value_name = ort_value_info.GetName(); + + if (value_name_out) { + *value_name_out = value_name; + } + + if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { + return; // Already processed this OrtValueInfo. + } - return Ort::Status{nullptr}; - }; - - size_t num_nodes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); - - std::vector nodes(num_nodes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); - - // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos - // that will be stored in GraphProto.value_info and GraphProto.initializer. 
- for (size_t i = 0; i < num_nodes; i++) { - const OrtNode* ort_node = nodes[i]; - onnx::NodeProto* node_proto = graph_proto.add_node(); - - const char* node_name = nullptr; - const char* node_domain = nullptr; - const char* node_op_type = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); - - node_proto->set_name(node_name); - node_proto->set_domain(node_domain); - node_proto->set_op_type(node_op_type); - - size_t num_inputs = 0; - size_t num_implicit_inputs = 0; - size_t num_outputs = 0; - size_t num_attrs = 0; - size_t num_subgraphs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); - - // Handle node attributes - if (num_attrs > 0) { - std::vector ort_attrs(num_attrs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); - - for (const OrtOpAttr* ort_attr : ort_attrs) { - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - - Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + bool is_required_graph_input = ort_value_info.IsRequiredGraphInput(); + bool is_optional_graph_input = ort_value_info.IsOptionalGraphInput(); + bool is_graph_output = ort_value_info.IsGraphOutput(); + bool is_constant_initializer = ort_value_info.IsConstantInitializer(); + bool is_from_outer_scope = ort_value_info.IsFromOuterScope(); + + // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. + // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. + // For values defined in an outer scope, just add the value info but not the initializer. + if (is_from_outer_scope) { + value_infos.emplace(value_name, ort_value_info); + } else if (is_optional_graph_input) { + initializer_value_infos.emplace(value_name, ort_value_info); + } else if (is_constant_initializer) { + value_infos.emplace(value_name, ort_value_info); + initializer_value_infos.emplace(value_name, ort_value_info); + } else if (!is_required_graph_input && !is_graph_output) { + value_infos.emplace(value_name, ort_value_info); // This is an internal OrtValueInfo. + } + }; + + std::vector nodes = ort_graph.GetNodes(); + // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos + // that will be stored in GraphProto.value_info and GraphProto.initializer. 
+ for (const auto& ort_node : nodes) { + onnx::NodeProto* node_proto = graph_proto.add_node(); + + std::string node_name = ort_node.GetName(); + std::string node_domain = ort_node.GetDomain(); + std::string node_op_type = ort_node.GetOperatorType(); + + node_proto->set_name(node_name); + node_proto->set_domain(node_domain); + node_proto->set_op_type(node_op_type); + + // Handle node attributes + std::vector ort_attrs = ort_node.GetAttributes(); + for (const auto& attr : ort_attrs) { + OrtOpAttrType attr_type = attr.GetType(); if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. // Can use Node_GetSubgraphs to get subgraphs. continue; } - if (!attr_type_status.IsOK()) { - // Unsupported attribute type. - return attr_type_status; - } - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(attr, *attr_proto)); } - } - - // Handle node subgraphs - if (num_subgraphs > 0) { - std::vector ort_subgraphs(num_subgraphs); - std::vector subgraph_attr_names(num_subgraphs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), - subgraph_attr_names.data())); - - for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { - const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; - const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; + // Handle node subgraphs + std::vector ort_subgraphs = ort_node.GetSubgraphs(); + for (const auto& [subgraph_attr_name, ort_subgraph] : ort_subgraphs) { onnx::AttributeProto* attr_proto = node_proto->add_attribute(); onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); - attr_proto->set_name(subgraph_attr_name); attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); } - } - - // Handle node inputs - if (num_inputs > 0) { - std::vector ort_inputs(num_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); - for (const OrtValueInfo* ort_value_info : ort_inputs) { - if (ort_value_info == nullptr) { + // Handle node inputs + std::vector ort_inputs = ort_node.GetInputs(); + for (const auto& vi : ort_inputs) { + if (vi == nullptr) { // missing optional input. node_proto->add_input(""); continue; } - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_input(value_name); + std::optional value_name; + value_name.emplace(); + collect_value_info(vi, value_name); + node_proto->add_input(*value_name); } - } - - // Handle implicit inputs to this node. - if (num_implicit_inputs > 0) { - std::vector ort_implicit_inputs(num_implicit_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), - ort_implicit_inputs.size())); - for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { - assert(ort_value_info != nullptr); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); + // Handle implicit inputs to this node. 
+ std::vector ort_implicit_inputs = ort_node.GetImplicitInputs(); + for (const auto& vi : ort_implicit_inputs) { + assert(vi != nullptr); + std::optional value_name; + collect_value_info(vi, value_name); } - } - - // Handle node outputs - if (num_outputs > 0) { - std::vector ort_outputs(num_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); - for (const OrtValueInfo* ort_value_info : ort_outputs) { - if (ort_value_info == nullptr) { + // Handle node outputs + std::vector ort_outputs = ort_node.GetOutputs(); + for (const auto& vi : ort_outputs) { + if (vi == nullptr) { // missing optional output. node_proto->add_output(""); continue; } - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_output(value_name); + std::optional value_name; + value_name.emplace(); + collect_value_info(vi, value_name); + node_proto->add_output(*value_name); } } - } - - // Add value_infos to GraphProto as ValueInfoProto objects. - for (const std::pair& entry : value_infos) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); - } - // Add initializers to GraphProto as TensorProto objects. - for (const std::pair& entry : initializer_value_infos) { - const OrtValueInfo* initializer_value_info = entry.second; - std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. - std::vector initializer_dims; - std::vector initializer_sym_dims; - ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, - initializer_elem_type, initializer_dims, - initializer_sym_dims)); - - onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); - tensor_proto->set_name(initializer_name); - tensor_proto->set_data_type(initializer_elem_type); - - auto* tensor_proto_dims = tensor_proto->mutable_dims(); - for (int64_t dim : initializer_dims) { - tensor_proto_dims->Add(dim); + // Add value_infos to GraphProto as ValueInfoProto objects. + for (const auto& [value_name, value_info] : value_infos) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(value_info, *value_info_proto)); } - const OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); + // Add initializers to GraphProto as TensorProto objects. 
+ for (const auto& [initializer_name, initializer_value_info] : initializer_value_infos) { + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(initializer_value_info, /*get_sym_dims*/ false, + initializer_elem_type, initializer_dims, + initializer_sym_dims)); + + onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); + tensor_proto->set_name(initializer_name); + tensor_proto->set_data_type(initializer_elem_type); + + auto* tensor_proto_dims = tensor_proto->mutable_dims(); + for (int64_t dim : initializer_dims) { + tensor_proto_dims->Add(dim); + } - const void* data = nullptr; - size_t data_bytes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + Ort::ConstValue ort_value{nullptr}; + ORT_EP_UTILS_C_RETURN_IF_ERROR(initializer_value_info.GetInitializer(ort_value)); - std::string ext_location; - int64_t ext_offset = 0; - bool is_external = false; + assert(ort_value.IsTensor()); + const void* data = ort_value.GetTensorRawData(); + const size_t data_bytes = ort_value.GetTensorSizeInBytes(); - if (handle_initializer_data_func != nullptr) { - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, - is_external, ext_location, ext_offset)); - } + std::string ext_location; + int64_t ext_offset = 0; + bool is_external = false; + + if (handle_initializer_data_func != nullptr) { + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, + is_external, ext_location, ext_offset)); + } - if (is_external) { - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); - auto* ext_data_entries = tensor_proto->mutable_external_data(); - onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); - - location_entry->set_key("location"); - location_entry->set_value(ext_location); - offset_entry->set_key("offset"); - offset_entry->set_value(std::to_string(ext_offset)); - length_entry->set_key("length"); - length_entry->set_value(std::to_string(data_bytes)); - } else { - // User wants to store data inline the TensorProto's raw_data - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); - tensor_proto->set_raw_data(data, data_bytes); + if (is_external) { + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + auto* ext_data_entries = tensor_proto->mutable_external_data(); + onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); + + location_entry->set_key("location"); + location_entry->set_value(ext_location); + offset_entry->set_key("offset"); + offset_entry->set_value(std::to_string(ext_offset)); + length_entry->set_key("length"); + length_entry->set_value(std::to_string(data_bytes)); + } else { + // User wants to store data inline the TensorProto's raw_data + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); + tensor_proto->set_raw_data(data, data_bytes); + } } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } 
catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_FAIL}; } return Ort::Status{nullptr}; } -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, +Ort::Status OrtGraphToProto(const OrtGraph& graph, onnx::ModelProto& model_proto, HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // Check that OrtGraph is a top-level graph (no parent node). - const OrtNode* parent_node = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); - ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); - - // Set model description. - model_proto.set_doc_string("Serialized from OrtGraph"); - model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); - - // Set ir version. - int64_t ir_version = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); - model_proto.set_ir_version(ir_version); - - // Set operator sets. - size_t num_operator_sets = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); - ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); - - std::vector domains(num_operator_sets, nullptr); - std::vector opset_versions(num_operator_sets); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), - num_operator_sets)); - - auto* operator_sets = model_proto.mutable_opset_import(); - - for (size_t i = 0; i < num_operator_sets; ++i) { - onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); - operator_set->set_domain(domains[i]); - operator_set->set_version(opset_versions[i]); - } + try { + // Check that OrtGraph is a top-level graph (no parent node). + Ort::ConstGraph ort_graph{&graph}; + Ort::ConstNode parent_node = ort_graph.GetParentNode(); + ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, "Cannot serialize nested OrtGraph into a ModelProto"); + + // Set model description. + model_proto.set_doc_string("Serialized from OrtGraph"); + model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); + + // Set ir version. + int64_t ir_version = ort_graph.GetOnnxIRVersion(); + model_proto.set_ir_version(ir_version); + + // Set operator sets. 
+ std::vector op_sets = ort_graph.GetOperatorSets(); + ORT_EP_UTILS_C_RETURN_IF(op_sets.empty(), "OrtGraph should have at least one operator set."); + + auto* operator_sets = model_proto.mutable_opset_import(); + + for (const auto& op_set : op_sets) { + onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); + operator_set->set_domain(op_set.domain); + operator_set->set_version(op_set.version); + } - model_proto.clear_graph(); - onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + model_proto.clear_graph(); + onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_graph, *graph_proto, handle_initializer_data_func)); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); + } catch (const Ort::Exception& ex) { + return Ort::Status(ex); + } catch (const std::exception& ex) { + return Ort::Status(ex.what(), ORT_EP_FAIL); + } return Ort::Status{nullptr}; } -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, +static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims) { - const OrtApi& ort_api = Ort::GetApi(); - - const OrtTypeInfo* ort_type_info = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); - - ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); - ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); + try { + Ort::ConstTypeInfo ort_type_info = vi.TypeInfo(); + ONNXType ort_onnx_type = ort_type_info.GetONNXType(); + ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, "Expected OrtValueInfo to represent a Tensor"); - const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; - ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); - - size_t num_dims = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); + Ort::ConstTensorTypeAndShapeInfo ort_type_shape = ort_type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType ort_elem_type = ort_type_shape.GetElementType(); - std::vector ort_dims(num_dims, 0); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); + size_t num_dims = ort_type_shape.GetDimensionsCount(); + std::vector ort_dims = ort_type_shape.GetShape(); - elem_type = ort_elem_type; - dims = std::move(ort_dims); + elem_type = ort_elem_type; + dims = std::move(ort_dims); - if (get_symbolic_dims) { - std::vector ort_dim_syms(num_dims, nullptr); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), - ort_dim_syms.size())); + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + ort_type_shape.GetSymbolicDimensions(ort_dim_syms.data(), ort_dim_syms.size()); - symbolic_dims.reserve(num_dims); - for (const char* sym_dim : ort_dim_syms) { - symbolic_dims.push_back(sym_dim); + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + symbolic_dims.push_back(sym_dim); 
+ } } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_EP_FAIL}; } - return Ort::Status{nullptr}; } // Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, +static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto) { - const OrtApi& ort_api = Ort::GetApi(); - std::vector ort_dims; std::vector ort_dim_syms; ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; @@ -620,9 +536,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, ort_elem_type, ort_dims, ort_dim_syms)); - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - value_info_proto.set_name(value_name); + value_info_proto.set_name(ort_value_info.GetName()); onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); type_proto_tensor->set_elem_type(ort_elem_type); @@ -652,213 +566,149 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { - const OrtApi& ort_api = Ort::GetApi(); - - const char* attr_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); - attr_proto.set_name(attr_name); +static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr attr, onnx::AttributeProto& attr_proto) { + try { + std::string attr_name = attr.GetName(); + attr_proto.set_name(attr_name); - size_t total_attr_bytes = 0; - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); + OrtOpAttrType attr_type = attr.GetType(); - switch (attr_type) { - case OrtOpAttrType::ORT_OP_ATTR_INT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); - - int64_t i_val = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); - attr_proto.set_i(i_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_INTS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. 
- Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector i_vals(total_attr_bytes / sizeof(int64_t)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* ints = attr_proto.mutable_ints(); - for (int64_t val : i_vals) { - ints->Add(val); + switch (attr_type) { + case OrtOpAttrType::ORT_OP_ATTR_INT: { + int64_t i_val = 0; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(i_val)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto.set_i(i_val); + break; } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); - - float f_val = 0.0f; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); - attr_proto.set_f(f_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector f_vals(total_attr_bytes / sizeof(float)); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* floats = attr_proto.mutable_floats(); - for (float val : f_vals) { - floats->Add(val); + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + std::vector i_vals; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(i_vals)); + auto* ints = attr_proto.mutable_ints(); + ints->Assign(i_vals.begin(), i_vals.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); + break; } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRING: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::string* str = attr_proto.mutable_s(); - - str->resize(total_attr_bytes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, - &total_attr_bytes)); - - str->resize(total_attr_bytes); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector chars(total_attr_bytes, '\0'); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* strs = attr_proto.mutable_strings(); - - // Strings are all in a single buffer, each separated with a '\0'. - // Extract each string and add it to the STRINGS attribute array. - char* at = chars.data(); - char* end = at + chars.size(); - - while (at < end) { - char* str_begin = at; - - while (*at && at < end) { - at++; - } - - strs->Add()->assign(str_begin, at - str_begin); - if (at < end) { - assert(*at == '\0'); - at++; // Skip '\0' to get to the beginning of the next string. 
- } + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + float f_val = 0.0f; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(f_val)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto.set_f(f_val); + break; } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + std::vector f_vals; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(f_vals)); + auto* floats = attr_proto.mutable_floats(); + floats->Assign(f_vals.begin(), f_vals.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + std::string* str = attr_proto.mutable_s(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(*str)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + std::vector result; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(result)); + auto* strs = attr_proto.mutable_strings(); + strs->Assign(result.begin(), result.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + + // TensorProto as an attribute value doesn't require a name. + + Ort::Value tensor; + ORT_EP_UTILS_C_RETURN_IF_ERROR(attr.GetTensorAttributeAsOrtValue(tensor)); + + // Get tensor type and shape info + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + + // Get tensor type + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + break; + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); - - onnx::TensorProto tensor_proto; - - // TensorProto as an attribute value doesn't require a name. 
- - OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value)); + auto shape = type_shape_info.GetShape(); - Ort::Value tensor(ort_value); + for (auto& dim : shape) { + tensor_proto.add_dims(dim); + } - // Get tensor type and shape info - Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + const void* data = tensor.GetTensorRawData(); + const size_t data_bytes = tensor.GetTensorSizeInBytes(); - // Get tensor type - ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + // Copy the Ortvalue to TensorProto as raw data + tensor_proto.set_raw_data(data, data_bytes); - size_t element_size = 0; - switch (element_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); - element_size = sizeof(float); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); - element_size = sizeof(uint8_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); - element_size = sizeof(int8_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); - element_size = sizeof(uint16_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); - element_size = sizeof(int16_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); - element_size = sizeof(int32_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); - element_size = sizeof(int64_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); - element_size = sizeof(bool); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); - element_size = sizeof(double); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); - element_size = sizeof(uint32_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); - element_size = sizeof(uint64_t); - break; - } - default: { - std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); - } + *(attr_proto.mutable_t()) = std::move(tensor_proto); + break; } - - auto shape = type_shape_info.GetShape(); - - for (auto& dim : shape) { - tensor_proto.add_dims(dim); + default: { + std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); } - - size_t element_count = type_shape_info.GetElementCount(); - size_t data_bytes = element_count * element_size; - const void* data = tensor.GetTensorData(); - - // Copy the Ortvalue to TensorProto as raw data - tensor_proto.set_raw_data(data, data_bytes); - - *(attr_proto.mutable_t()) = std::move(tensor_proto); - break; - } - default: { - std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); } + } catch (const Ort::Exception& ex) { + return 
Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_FAIL}; } return Ort::Status{nullptr}; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 72e8a3ca1103c..8561de9c8c3b9 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -533,6 +533,57 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e _Out_ size_t* num_selected, _In_ void* state); +/** \brief Function called by ORT to write a buffer to a custom destination (e.g., file, stream, etc.). + * + * \param state Opaque pointer holding the user's state. + * \param buffer The buffer to write. + * \param buffer_num_bytes The size of the buffer in bytes. + * + * \return OrtStatus* Write status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtWriteBufferFunc)(_In_ void* state, + _In_ const void* buffer, + _In_ size_t buffer_num_bytes); + +/** \brief Function called by ORT to allow user to specify how an initializer should be saved, that is, either + * written to an external file or stored within the model. ORT calls this function for every initializer when + * generating a model. + * + * If the function implementation sets the `new_external_info` output parameter to NULL, ORT stores the initializer data + * within the generated model. + * + * Otherwise, if the function implementation sets `new_external_info` to a valid OrtExternalInitializerInfo instance, + * ORT assumes that this function stores the initializer data in a file. In this case, ORT configures the model's + * initializer to point to the location specified by the `new_external_info` output parameter. + * + * \param[in] state Opaque pointer holding the user's state. + * \param[in] initializer_name The initializer's name as a null-terminated string. + * \param[in] initializer_value OrtValue containing the initializer's data, type, and shape. + * \param[in] external_info If the initializer is originally stored in an external file, `external_info` contains + * the file path, file offset, and the data's byte size within the file. Otherwise, + * `external_info` is NULL if the initializer is not originally stored in a file. + * \param[out] new_external_info Output parameter set to a new OrtExternalInitializerInfo instance indicating the + * location where the function implementation stored the initializer data. + * The function implementation must use `OrtApi::CreateExternalInitializerInfo()` to + * create the instance. + * If the function implementation sets `new_external_info` to NULL, + * ORT stores the initializers within the model. + * + * \note ORT takes ownership of the `new_external_info` output parameter. + * + * \return OrtStatus* Write status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. 
+ */
+typedef OrtStatus*(ORT_API_CALL* OrtGetInitializerLocationFunc)(
+    _In_ void* state,
+    _In_ const char* initializer_name,
+    _In_ const OrtValue* initializer_value,
+    _In_opt_ const OrtExternalInitializerInfo* external_info,
+    _Outptr_result_maybenull_ OrtExternalInitializerInfo** new_external_info);
+
 /** \brief Algorithm to use for cuDNN Convolution Op
  */
 typedef enum OrtCudnnConvAlgoSearch {
@@ -6507,6 +6558,26 @@ struct OrtApi {
                   _In_ size_t num_ep_devices, _In_ const char* compatibility_info,
                   _Out_ OrtCompiledModelCompatibility* out_status);
+
+  /// \name OrtExternalInitializerInfo
+  /// @{
+
+  /** \brief Creates an OrtExternalInitializerInfo instance.
+   *
+   * \param[in] filepath The relative path to the file that stores the initializer's data. ORT copies this path string.
+   * \param[in] file_offset The byte offset where the initializer's data is stored within the file.
+   * \param[in] byte_size The size in bytes of the initializer's data within the file.
+   * \param[out] out Output parameter set to the new OrtExternalInitializerInfo instance.
+   *             Must be released by calling ReleaseExternalInitializerInfo().
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, _In_ int64_t file_offset,
+                  _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out);
+
+  /// @}
 };

 /*
@@ -7074,6 +7145,9 @@ struct OrtCompileApi {
  * ReleaseOrtModelCompilationsOptions must be called to free the OrtModelCompilationOptions after calling
  * CompileModel.
  *
+ * \note By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use
+ * ModelCompilationOptions_SetGraphOptimizationLevel to enable graph optimizations.
+ *
  * \param[in] env OrtEnv object.
  * \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions.
  * \param[out] out The created OrtModelCompilationOptions instance.
@@ -7230,7 +7304,7 @@ struct OrtCompileApi {
   * \since Version 1.23.
   */
  ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options,
-                 size_t flags);
+                 uint32_t flags);

  /** Sets information related to EP context binary file.
   *
@@ -7249,6 +7323,56 @@ struct OrtCompileApi {
                  _In_ OrtModelCompilationOptions* model_compile_options,
                  _In_ const ORTCHAR_T* output_directory,
                  _In_ const ORTCHAR_T* model_name);
+
+  /** Set the graph optimization level.
+   *
+   * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+   * \param[in] graph_optimization_level The graph optimization level.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(ModelCompilationOptions_SetGraphOptimizationLevel,
+                  _In_ OrtModelCompilationOptions* model_compile_options,
+                  _In_ GraphOptimizationLevel graph_optimization_level);
+
+  /** \brief Sets an OrtWriteBufferFunc function that is called by ORT to write out the output model's serialized
+   * ONNX bytes.
+   *
+   * The provided write function may be called repeatedly until the entire output model has been written out. Each call
+   * to the write function is expected to consume the entire input buffer.
+   *
+   * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions
+   * that begin with ModelCompilationOptions_SetOutputModel____.
+   *
+   * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+   * \param[in] write_func The OrtWriteBufferFunc function called by ORT when writing out the model.
+   * \param[in] state Opaque state passed as the first argument to OrtWriteBufferFunc. Can be NULL.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelWriteFunc,
+                  _In_ OrtModelCompilationOptions* model_compile_options,
+                  _In_ OrtWriteBufferFunc write_func, _In_ void* state);
+
+  /** \brief Sets an OrtGetInitializerLocationFunc function that is called by ORT for every initializer in the generated
+   * model. Allows the implementer to specify whether initializers should be stored within the model or externally.
+   *
+   * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+   * \param[in] get_initializer_location_func The OrtGetInitializerLocationFunc function called by ORT
+   *            to determine the location of each initializer.
+   * \param[in] state Opaque state passed as the first argument to OrtGetInitializerLocationFunc. Can be NULL.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc,
+                  _In_ OrtModelCompilationOptions* model_compile_options,
+                  _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state);
 };

 /*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 981c70ab8b954..9fa7915679f62 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -52,6 +52,7 @@ namespace Ort {
 * If ORT_NO_EXCEPTIONS is defined, then any error will result in a call to abort()
 */
 struct Exception : std::exception {
+  Exception(const std::string& string, OrtErrorCode code) : message_{string}, code_{code} {}
   Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}

   OrtErrorCode GetOrtErrorCode() const { return code_; }
@@ -549,33 +550,35 @@ namespace detail {
   inline void OrtRelease(Ort##NAME* ptr) { API_GETTER().Release##NAME(ptr); }

 ORT_DEFINE_RELEASE(Allocator);
-ORT_DEFINE_RELEASE(MemoryInfo);
+ORT_DEFINE_RELEASE(ArenaCfg);
 ORT_DEFINE_RELEASE(CustomOpDomain);
-ORT_DEFINE_RELEASE(ThreadingOptions);
 ORT_DEFINE_RELEASE(Env);
-ORT_DEFINE_RELEASE(RunOptions);
+ORT_DEFINE_RELEASE(ExternalInitializerInfo);
+ORT_DEFINE_RELEASE(Graph);
+ORT_DEFINE_RELEASE(IoBinding);
+ORT_DEFINE_RELEASE(KernelInfo);
+ORT_DEFINE_RELEASE(KeyValuePairs);
 ORT_DEFINE_RELEASE(LoraAdapter);
+ORT_DEFINE_RELEASE(MemoryInfo);
+ORT_DEFINE_RELEASE(MapTypeInfo);
+ORT_DEFINE_RELEASE(Model);
+ORT_DEFINE_RELEASE(ModelMetadata);
+ORT_DEFINE_RELEASE(Node);
+ORT_DEFINE_RELEASE(Op);
+ORT_DEFINE_RELEASE(OpAttr);
+ORT_DEFINE_RELEASE(PrepackedWeightsContainer);
+ORT_DEFINE_RELEASE(RunOptions);
 ORT_DEFINE_RELEASE(Session);
 ORT_DEFINE_RELEASE(SessionOptions);
-ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
 ORT_DEFINE_RELEASE(SequenceTypeInfo);
-ORT_DEFINE_RELEASE(MapTypeInfo);
-ORT_DEFINE_RELEASE(TypeInfo);
-ORT_DEFINE_RELEASE(Value);
-ORT_DEFINE_RELEASE(ModelMetadata);
-ORT_DEFINE_RELEASE(IoBinding);
-ORT_DEFINE_RELEASE(ArenaCfg);
 ORT_DEFINE_RELEASE(Status);
 ORT_DEFINE_RELEASE(SyncStream);
-ORT_DEFINE_RELEASE(OpAttr);
-ORT_DEFINE_RELEASE(Op);
-ORT_DEFINE_RELEASE(KernelInfo);
+ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
+ORT_DEFINE_RELEASE(ThreadingOptions);
+ORT_DEFINE_RELEASE(TypeInfo);
+ORT_DEFINE_RELEASE(Value);
ORT_DEFINE_RELEASE(ValueInfo); -ORT_DEFINE_RELEASE(Node); -ORT_DEFINE_RELEASE(Graph); -ORT_DEFINE_RELEASE(Model); -ORT_DEFINE_RELEASE(KeyValuePairs); -ORT_DEFINE_RELEASE(PrepackedWeightsContainer); + ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); @@ -701,6 +704,7 @@ struct AllocatedFree { struct AllocatorWithDefaultOptions; struct Env; struct EpDevice; +struct ExternalInitializerInfo; struct Graph; struct Model; struct Node; @@ -818,6 +822,46 @@ struct PrepackedWeightsContainer : detail::Base { PrepackedWeightsContainer(); }; +namespace detail { +template +struct ConstExternalInitializerInfoImpl : Base { + using B = Base; + using B::B; + + // Wraps OrtApi::ExternalInitializerInfo_GetFilePath + const std::basic_string GetFilePath() const; + // Wraps OrtApi::ExternalInitializerInfo_GetFileOffset + int64_t GetFileOffset() const; + // Wraps OrtApi::ExternalInitializerInfo_GetByteSize + size_t GetByteSize() const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstExternalInitializerInfo = + detail::ConstExternalInitializerInfoImpl>; + +/** \brief Wrapper around ::OrtExternalInitializerInfo + * + */ +struct ExternalInitializerInfo : detail::ConstExternalInitializerInfoImpl { + using Base = detail::ConstExternalInitializerInfoImpl; + using Base::Base; + + explicit ExternalInitializerInfo(std::nullptr_t) {} + explicit ExternalInitializerInfo(OrtExternalInitializerInfo* p) + : detail::ConstExternalInitializerInfoImpl{p} {} + + ConstExternalInitializerInfo GetConst() const { return ConstExternalInitializerInfo{this->p_}; } + + ///< Wraps OrtApi::CreateExternalInitializerInfo + ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size); + + ///< Wrapper around CreateExternalInitializerInfo that does not throw an exception. 
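+  ///< Illustrative usage (the file name, offset, and size below are hypothetical):
+  ///<   ExternalInitializerInfo info{nullptr};
+  ///<   Status status = ExternalInitializerInfo::Create(ORT_TSTR("weights.bin"), /*file_offset*/ 0,
+  ///<                                                   /*byte_size*/ 4096, info);
+  ///<   if (!status.IsOK()) { /* handle the error */ }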
+ static Status Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size, + /*out*/ ExternalInitializerInfo& out); +}; + namespace detail { template struct KeyValuePairsImpl : Ort::detail::Base { @@ -1357,11 +1401,23 @@ struct ModelCompilationOptions : detail::Base { ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path, size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile + + ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc + ModelCompilationOptions& SetOutputModelGetInitializerLocationFunc( + OrtGetInitializerLocationFunc get_initializer_location_func, + void* state); + ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + + ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelWriteFunc + ModelCompilationOptions& SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state); + ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation - ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags + ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags + + ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel }; /** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. @@ -2366,15 +2422,46 @@ struct ArenaCfg : detail::Base { // Custom OPs (only needed to implement custom OPs) // +namespace detail { +// Need to define a templated ConstOpAttr with const members +template +struct ConstOpAttrImpl : Base { + using B = detail::Base; + using B::B; + + // Wraps OrtApi::OpAttr_GetName + std::string GetName() const; + // Wraps OrtApi::OpAttr_GetType + OrtOpAttrType GetType() const; + + // Wraps OrtApi::ReadAttr for a single value + // This does not support Tensor Attribute + // Call GetTensorAttributeAsOrtValue() instead. 
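+  //
+  // Illustrative usage (the attribute variable and value names are hypothetical):
+  //   int64_t axis = 0;
+  //   Status status = const_op_attr.GetValue(axis);
+  //   if (!status.IsOK()) { /* handle the error */ }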
+ template + Status GetValue(R& out) const; + + // Wraps OrtApi::ReadAttr for an array of values + template + Status GetValueArray(std::vector& out) const; + // Wraps OrtApi::OpAttr_GetTensorAttributeAsOrtValue + Status GetTensorAttributeAsOrtValue(Value&) const; +}; +} // namespace detail + +using ConstOpAttr = detail::ConstOpAttrImpl>; + /// /// This struct provides life time management for custom op attribute /// -struct OpAttr : detail::Base { - using Base = detail::Base; +struct OpAttr : detail::ConstOpAttrImpl { + using Base = detail::ConstOpAttrImpl; using Base::Base; + OpAttr() = default; // Enable storing it in the container for resize() explicit OpAttr(std::nullptr_t) {} OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); + + ConstOpAttr GetConst() const { return ConstOpAttr{this->p_}; } }; /** @@ -2720,7 +2807,7 @@ struct ShapeInferContext { Strings GetAttrStrings(const char* attr_name); private: - const OrtOpAttr* GetAttrHdl(const char* attr_name) const; + ConstOpAttr GetAttrHdl(const char* attr_name) const; const OrtApi* ort_api_; OrtShapeInferContext* ctx_; std::vector input_shapes_; @@ -2871,48 +2958,114 @@ struct CustomOpBase : OrtCustomOp { int end_ver_ = MAX_CUSTOM_OP_END_VER; }; +// Forward declaration to resolve circular dependency +// on ConstNode +struct ValueInfoConsumerProducerInfo; + namespace detail { template -struct ValueInfoImpl : Ort::detail::Base { - using B = Ort::detail::Base; +struct ConstValueInfoImpl : Base { + using B = Base; using B::B; - std::string Name() const; + /// < A wrapper around OrtApi::GetValueInfoName + std::string GetName() const; + /// < A wrapper around OrtApi::GetValueInfoTypeInfo ConstTypeInfo TypeInfo() const; + ///< Wraps OrtApi::ValueInfo_GetProducerNode + ValueInfoConsumerProducerInfo GetProducerNode() const; + /// < A wrapper around OrtApi::ValueInfo_GetValueConsumers + std::vector GetConsumers() const; + /// < A wrapper around OrtApi::ValueInfo_GetInitializerValue + Status GetInitializer(ConstValue& value) const; + /// < A wrapper around OrtApi::ValueInfo_GetExternalInitializerInfo + Status GetExternalInitializerInfo(ExternalInitializerInfo& info) const; + /// < A wrapper around OrtApi::ValueInfo_IsRequiredGraphInput + bool IsRequiredGraphInput() const; + /// < A wrapper around OrtApi::ValueInfo_IsOptionalGraphInput + bool IsOptionalGraphInput() const; + /// < A wrapper around OrtApi::ValueInfo_IsGraphOutput + bool IsGraphOutput() const; + /// < A wrapper around OrtApi::ValueInfo_IsConstantInitializer + bool IsConstantInitializer() const; + /// < A wrapper around OrtApi::ValueInfo_IsFromOuterScope + bool IsFromOuterScope() const; }; } // namespace detail // Const object holder that does not own the underlying object -using ConstValueInfo = detail::ValueInfoImpl>; +using ConstValueInfo = detail::ConstValueInfoImpl>; /** \brief Wrapper around ::OrtValueInfo * */ -struct ValueInfo : detail::ValueInfoImpl { +struct ValueInfo : detail::ConstValueInfoImpl { + ValueInfo() = default; // Same thing as with nullptr explicit ValueInfo(std::nullptr_t) {} ///< No instance is created /// Take ownership of a pointer created by C API - explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl{p} {} + explicit ValueInfo(OrtValueInfo* p) : ConstValueInfoImpl{p} {} +#if !defined(ORT_MINIMAL_BUILD) // Create ValueInfo for a tensor explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info); - +#endif ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; } }; +// Forward declaration +struct 
AttrNameSubgraph; + namespace detail { +// Forward decl template -struct NodeImpl : Ort::detail::Base { - using B = Ort::detail::Base; +struct ConstGraphImpl; + +template +struct ConstNodeImpl : Base { + using B = Base; using B::B; + + // GetInputs() const; + // GetOutputs() const; + // GetImplicitInputs() const; + // GetAttributes() const; + // GetSubgraphs() const; + // > GetGraph() const; + // >; + /** \brief Wrapper around ::OrtNode * */ -struct Node : detail::NodeImpl { - explicit Node(std::nullptr_t) {} ///< No instance is created - explicit Node(OrtNode* p) : NodeImpl{p} {} ///< Take ownership of a pointer created by C API +struct Node : detail::ConstNodeImpl { + Node() = default; // Same thing as with nullptr + explicit Node(std::nullptr_t) {} ///< No instance is created + explicit Node(OrtNode* p) : ConstNodeImpl{p} {} ///< Take ownership of a pointer created by C API #if !defined(ORT_MINIMAL_BUILD) Node(const std::string& operator_name, const std::string& operator_domain, @@ -2939,22 +3092,78 @@ struct Node : detail::NodeImpl { #endif // !defined(ORT_MINIMAL_BUILD) }; +// Return struct for some of ValueInfo APIs. +// Must be declared after ConstNode is available. +struct ValueInfoConsumerProducerInfo { + ConstNode node; + // either producer output or consumer output index + // producer is unsigned only, output can be -1 + int64_t index; +}; + +// Represents a return value for Graph::GetOperatorSets() +struct OperatorSet { + std::string domain; + int64_t version; +}; + namespace detail { template -struct GraphImpl : Ort::detail::Base { - using B = Ort::detail::Base; +struct ConstGraphImpl : Base { + using B = Base; + using B::B; + + // GetModelPath() const; + // GetOperatorSets() const; + // GetInputs() const; + // GetOutputs() const; + // GetInitializers() const; + // GetNodes() const; + // & nodes) const; + // +struct GraphImpl : ConstGraphImpl { + using B = ConstGraphImpl; using B::B; #if !defined(ORT_MINIMAL_BUILD) + // & inputs); + // & outputs); + // >; + +// Return value for Node API +// Must be declared after ConstGraph +struct AttrNameSubgraph { + std::string attr_name; + ConstGraph sub_graph; +}; + /** \brief Wrapper around ::OrtGraph * */ @@ -2962,25 +3171,26 @@ struct Graph : detail::GraphImpl { explicit Graph(std::nullptr_t) {} ///< No instance is created explicit Graph(OrtGraph* p) : GraphImpl{p} {} ///< Take ownership of a pointer created by C API #if !defined(ORT_MINIMAL_BUILD) + // >; namespace detail { template -struct ModelImpl : Ort::detail::Base { +struct ModelImpl : detail::Base { using B = Ort::detail::Base; using B::B; #if !defined(ORT_MINIMAL_BUILD) + // >; +using UnownedModel = detail::ModelImpl>; /** \brief Wrapper around ::OrtModel * @@ -2992,10 +3202,9 @@ struct Model : detail::ModelImpl { explicit Model(OrtModel* p) : ModelImpl{p} {} ///< Take ownership of a pointer created by C API #if !defined(ORT_MINIMAL_BUILD) + //< Wraps GetModelEditorApi().CreateModel() explicit Model(const std::vector& opsets); #endif - - ConstModel GetConst() const { return ConstModel{this->p_}; } }; } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 05c86ae4e0c58..59979189eed0f 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -571,6 +571,42 @@ inline PrepackedWeightsContainer::PrepackedWeightsContainer() { 
ThrowOnError(GetApi().CreatePrepackedWeightsContainer(&this->p_)); } +namespace detail { + +template +inline const std::basic_string ConstExternalInitializerInfoImpl::GetFilePath() const { + return GetApi().ExternalInitializerInfo_GetFilePath(this->p_); +} + +template +inline int64_t ConstExternalInitializerInfoImpl::GetFileOffset() const { + return GetApi().ExternalInitializerInfo_GetFileOffset(this->p_); +} + +template +inline size_t ConstExternalInitializerInfoImpl::GetByteSize() const { + return GetApi().ExternalInitializerInfo_GetByteSize(this->p_); +} +} // namespace detail + +inline ExternalInitializerInfo::ExternalInitializerInfo(const ORTCHAR_T* filepath, int64_t file_offset, + size_t byte_size) { + ThrowOnError(GetApi().CreateExternalInitializerInfo(filepath, file_offset, byte_size, &this->p_)); +} + +inline Status ExternalInitializerInfo::Create(const ORTCHAR_T* filepath, int64_t file_offset, size_t byte_size, + /*out*/ ExternalInitializerInfo& out) { + OrtExternalInitializerInfo* info = nullptr; + OrtStatus* status = GetApi().CreateExternalInitializerInfo(filepath, file_offset, byte_size, &info); + if (status != nullptr) { + return Status{status}; + } + + out = ExternalInitializerInfo(info); + + return Status{nullptr}; +} + namespace detail { template inline const char* KeyValuePairsImpl::GetValue(const char* key) const { @@ -1003,6 +1039,16 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalI return *this; } +inline ModelCompilationOptions& +ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc( + OrtGetInitializerLocationFunc get_initializer_location_func, void* state) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc( + this->p_, + get_initializer_location_func, + state)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer( OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelBuffer(this->p_, allocator, @@ -1011,6 +1057,12 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer( return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, + void* state) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelWriteFunc(this->p_, write_func, state)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( bool embed_ep_context_in_model) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode( @@ -1019,11 +1071,18 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( return *this; } -inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(size_t flags) { +inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(uint32_t flags) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetFlags(this->p_, flags)); return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLevel( + GraphOptimizationLevel graph_optimization_level) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetGraphOptimizationLevel(this->p_, + graph_optimization_level)); + return *this; +} + namespace detail { template @@ -1759,7 +1818,7 @@ inline Session::Session(const Env& env, const void* model_data, size_t model_dat #if !defined(ORT_MINIMAL_BUILD) inline 
Session::Session(const Env& env, const Model& model, const SessionOptions& options) { - ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_)); + ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model, options, &this->p_)); } // static @@ -2475,6 +2534,172 @@ inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data)); } +namespace detail { + +template +constexpr OrtOpAttrType TypeToAttrType(); + +template <> +inline constexpr OrtOpAttrType TypeToAttrType() { + return OrtOpAttrType::ORT_OP_ATTR_INT; +} + +template <> +inline constexpr OrtOpAttrType TypeToAttrType() { + return OrtOpAttrType::ORT_OP_ATTR_FLOAT; +} + +template +inline constexpr OrtOpAttrType TypeToAttrsType(); + +template <> +inline constexpr OrtOpAttrType TypeToAttrsType() { + return OrtOpAttrType::ORT_OP_ATTR_INTS; +} + +template <> +inline constexpr OrtOpAttrType TypeToAttrsType() { + return OrtOpAttrType::ORT_OP_ATTR_FLOATS; +} + +inline Status CheckAttrType(const OrtOpAttr* attr, OrtOpAttrType requested_type) { + OrtOpAttrType type; + Ort::Status status(GetApi().OpAttr_GetType(attr, &type)); + if (!status.IsOK()) return status; + if (requested_type != type) { + std::string msg = "Attribute type mismatch: expected " + std::to_string(requested_type) + + ", but got " + std::to_string(type); + return Ort::Status(msg.c_str(), OrtErrorCode::ORT_INVALID_ARGUMENT); + } + return Ort::Status{}; +} + +inline size_t GetDataSize(const OrtOpAttr* attr, OrtOpAttrType attr_type) { + size_t result{}; + // Ignore the status here because we check the data type so the error should only be about + // the size + [[maybe_unused]] Status status{GetApi().ReadOpAttr(attr, attr_type, nullptr, 0, &result)}; + return result; +} + +template +Ort::Status GetNumericValue(const OrtOpAttr* attr, T& out) { + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + size_t size{}; + return Ort::Status{GetApi().ReadOpAttr(attr, TypeToAttrType(), &out, sizeof(out), &size)}; +} + +template +struct GetValueImpl { + static Status GetValue(const OrtOpAttr* attr, T& out) { + return GetNumericValue(attr, out); + } + static Status GetValues(const OrtOpAttr* attr, std::vector& out) { + // Api deficiency when it comes to value arrays. It is not possible + // to tell if the error is due to the type mismatch or the size + // so we check the type first, and then ignore the status of the size check + constexpr auto deduced_type = TypeToAttrsType(); + auto status = CheckAttrType(attr, deduced_type); + if (!status.IsOK()) return status; + auto size = GetDataSize(attr, deduced_type); + std::vector result; + if (size > 0) { + result.resize(size / sizeof(T)); + status = Status{GetApi().ReadOpAttr( + attr, deduced_type, result.data(), size, &size)}; + if (!status.IsOK()) return status; + } + out.swap(result); + return status; + } +}; + +// Create GetValueImpl specializations for std::string +template <> +struct GetValueImpl { + static Status GetValue(const OrtOpAttr* attr, std::string& out) { + // Api deficiency when it comes to value arrays. 
It is not possible + // to tell if the error is due to the type mismatch or the size + // so we check the type first, and then ignore the status of the size check + auto status = CheckAttrType(attr, OrtOpAttrType::ORT_OP_ATTR_STRING); + if (!status.IsOK()) return status; + auto size = GetDataSize(attr, OrtOpAttrType::ORT_OP_ATTR_STRING); + std::string result; + if (size > 0) { + result.resize(size); + // some compilers in use do not support std::string::data() non-const + auto* buffer = &result[0]; + status = Status{GetApi().ReadOpAttr( + attr, OrtOpAttrType::ORT_OP_ATTR_STRING, buffer, size, &size)}; + if (!status.IsOK()) return status; + } + out.swap(result); + return status; + } + static Status GetValues(const OrtOpAttr* attr, std::vector& out) { + auto status = CheckAttrType(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS); + if (!status.IsOK()) return status; + + std::vector result; + size_t total_buffer_size = GetDataSize(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS); + if (total_buffer_size > 0) { + // Create a temporary buffer to hold the string data + std::vector buffer(total_buffer_size); + status = Status{GetApi().ReadOpAttr(attr, OrtOpAttrType::ORT_OP_ATTR_STRINGS, buffer.data(), + total_buffer_size, &total_buffer_size)}; + if (!status.IsOK()) return status; + + const char* data = buffer.data(); + const char* end = data + total_buffer_size; + while (data < end) { + result.emplace_back(data); + data += result.back().size() + 1; // Move past the null terminator + } + } + out.swap(result); + return status; + } +}; + +template +template +inline Status ConstOpAttrImpl::GetValue(R& out) const { + return GetValueImpl::GetValue(this->p_, out); +} + +template +template +inline Status ConstOpAttrImpl::GetValueArray(std::vector& out) const { + return GetValueImpl::GetValues(this->p_, out); +} + +template +inline Status ConstOpAttrImpl::GetTensorAttributeAsOrtValue(Value& out) const { + OrtValue* tensor_value = nullptr; + auto status = Status(GetApi().OpAttr_GetTensorAttributeAsOrtValue(this->p_, &tensor_value)); + if (!status.IsOK()) return status; + out = Value{tensor_value}; + return status; +} + +template +inline std::string ConstOpAttrImpl::GetName() const { + const char* name = nullptr; + ThrowOnError(GetApi().OpAttr_GetName(this->p_, &name)); + if (name != nullptr) { + return name; + } + return {}; +} + +template +inline OrtOpAttrType ConstOpAttrImpl::GetType() const { + OrtOpAttrType type; + ThrowOnError(GetApi().OpAttr_GetType(this->p_, &type)); + return type; +} +} // namespace detail + inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) { Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_)); } @@ -2775,115 +3000,69 @@ inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shap } inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - int64_t i = {}; - size_t out = {}; - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out)); - return i; + auto attr = GetAttrHdl(attr_name); + int64_t value; + Status status = attr.GetValue(value); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting int attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); + } + return value; } inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - int64_t i = {}; - size_t out = {}; - // first call to get the bytes needed - // 1. 
A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. - // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). - // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out); - if (status) { - size_t num_i = out / sizeof(int64_t); - ShapeInferContext::Ints ints(num_i, 0); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out)); - return ints; - } else { - if (out == 0u) { - return {}; - } - return {i}; + auto attr = GetAttrHdl(attr_name); + ShapeInferContext::Ints result; + auto status = attr.GetValueArray(result); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting ints attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return result; } inline float ShapeInferContext::GetAttrFloat(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - float f = {}; - size_t out = {}; - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out)); - return f; + auto attr = GetAttrHdl(attr_name); + float value; + Status status = attr.GetValue(value); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting float attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); + } + return value; } inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - float f = {}; - size_t out = {}; - // first call to get the bytes needed - // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. - // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). - // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. 
- auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out); - if (status) { - size_t num_f = out / sizeof(float); - ShapeInferContext::Floats floats(num_f, 0); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out)); - return floats; - } else { - if (out == 0u) { - return {}; - } - return {f}; + auto attr = GetAttrHdl(attr_name); + ShapeInferContext::Floats result; + auto status = attr.GetValueArray(result); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting floats attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return result; } inline std::string ShapeInferContext::GetAttrString(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - char c = {}; - size_t out = {}; - // first call to get the bytes needed - auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out); - if (status) { - std::vector chars(out, '\0'); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out)); - return std::string{chars.data(), out}; - } else { - return {c}; + auto attr = GetAttrHdl(attr_name); + std::string value; + Status status = attr.GetValue(value); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting string attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return value; } inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) { - const auto* attr = GetAttrHdl(attr_name); - char c = {}; - size_t out = {}; - // first call to get the bytes needed - // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. - // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). - // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. 
- auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out); - if (status) { - std::vector chars(out, '\0'); - Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out)); - ShapeInferContext::Strings strings; - char* char_st = chars.data(); - char* char_ed = char_st + out; - while (char_st < char_ed) { - strings.emplace_back(char_st); - while (*char_st != '\0') { - char_st++; - } - char_st++; - } - return strings; - } else { - if (out == 0u) { - return {}; - } - return {std::string{c}}; + auto attr = GetAttrHdl(attr_name); + ShapeInferContext::Strings result; + auto status = attr.GetValueArray(result); + if (!status.IsOK()) { + ORT_CXX_API_THROW("Getting strings attribute failed: " + status.GetErrorMessage(), status.GetErrorCode()); } + return result; } -inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const { +inline ConstOpAttr ShapeInferContext::GetAttrHdl(const char* attr_name) const { const OrtOpAttr* attr_hdl = {}; Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl)); - return attr_hdl; + return ConstOpAttr{attr_hdl}; } namespace detail { @@ -2897,6 +3076,136 @@ inline std::vector StringsToCharPtrs(const std::vector } } // namespace detail +namespace detail { +template +inline size_t ConstNodeImpl::GetId() const { + size_t id; + ThrowOnError(GetApi().Node_GetId(this->p_, &id)); + return id; +} + +template +inline std::string ConstNodeImpl::GetName() const { + const char* name; + ThrowOnError(GetApi().Node_GetName(this->p_, &name)); + return std::string(name); +} + +template +inline std::string ConstNodeImpl::GetOperatorType() const { + const char* type; + ThrowOnError(GetApi().Node_GetOperatorType(this->p_, &type)); + return std::string(type); +} + +template +inline std::string ConstNodeImpl::GetDomain() const { + const char* domain; + ThrowOnError(GetApi().Node_GetDomain(this->p_, &domain)); + return std::string(domain); +} + +template +inline int ConstNodeImpl::GetSinceVersion() const { + int since_version; + ThrowOnError(GetApi().Node_GetSinceVersion(this->p_, &since_version)); + return since_version; +} + +template +inline std::vector ConstNodeImpl::GetInputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Node_GetNumInputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Node_GetInputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstNodeImpl::GetOutputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Node_GetNumOutputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Node_GetOutputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstNodeImpl::GetImplicitInputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Node_GetNumImplicitInputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Node_GetImplicitInputs(this->p_, reinterpret_cast(result.data()), + num_vi)); + } + return result; +} + +template +inline std::vector ConstNodeImpl::GetAttributes() const { + static_assert(sizeof(const OrtOpAttr*) == sizeof(ConstOpAttr), "Must be 
the same size"); + size_t num_attrs; + ThrowOnError(GetApi().Node_GetNumAttributes(this->p_, &num_attrs)); + std::vector attrs; + if (num_attrs > 0) { + attrs.resize(num_attrs); + ThrowOnError(GetApi().Node_GetAttributes(this->p_, reinterpret_cast(attrs.data()), num_attrs)); + } + return attrs; +} + +template +inline Status ConstNodeImpl::GetAttributeByName(const std::string& name, ConstOpAttr& out) const { + const OrtOpAttr* attr = nullptr; + auto status = Status(GetApi().Node_GetAttributeByName(this->p_, name.c_str(), &attr)); + out = ConstOpAttr{attr}; + return status; +} + +template +inline std::vector ConstNodeImpl::GetSubgraphs() const { + size_t num_graphs; + ThrowOnError(GetApi().Node_GetNumSubgraphs(this->p_, &num_graphs)); + std::vector result; + if (num_graphs > 0) { + std::vector sub_graphs(num_graphs); + std::vector attr_names(num_graphs); + ThrowOnError(GetApi().Node_GetSubgraphs(this->p_, sub_graphs.data(), num_graphs, attr_names.data())); + result.reserve(num_graphs); + for (size_t i = 0; i < num_graphs; ++i) { + result.push_back({std::string(attr_names[i]), ConstGraph{sub_graphs[i]}}); + } + } + return result; +} + +template +inline ConstGraph ConstNodeImpl::GetGraph() const { + const OrtGraph* graph; + ThrowOnError(GetApi().Node_GetGraph(this->p_, &graph)); + return ConstGraph{graph}; +} + +template +inline std::string ConstNodeImpl::GetEpName() const { + const char* name; + ThrowOnError(GetApi().Node_GetEpName(this->p_, &name)); + return std::string(name); +} + +} // namespace detail + #if !defined(ORT_MINIMAL_BUILD) // static inline void Node::Init(const std::string& operator_name, const std::string& operator_domain, @@ -2938,97 +3247,294 @@ inline Node::Node(const std::string& operator_name, const std::string& operator_ std::vector empty_attributes; Init(operator_name, operator_domain, node_name, input_names, output_names, empty_attributes, p_); } +inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { + ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_)); +} +#endif // !defined(ORT_MINIMAL_BUILD) -inline Graph::Graph() { - ThrowOnError(GetModelEditorApi().CreateGraph(&p_)); +namespace detail { +template +inline std::string ConstValueInfoImpl::GetName() const { + const char* p = nullptr; + ThrowOnError(GetApi().GetValueInfoName(this->p_, &p)); + return std::string(p); } -inline Model::Model(const std::vector& opsets) { - std::vector domains; - std::vector versions; - domains.reserve(opsets.size()); - versions.reserve(opsets.size()); +template +inline ConstTypeInfo ConstValueInfoImpl::TypeInfo() const { + const OrtTypeInfo* type_info = nullptr; + ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info)); + return ConstTypeInfo{type_info}; +} - for (const auto& pair : opsets) { - domains.push_back(pair.first.c_str()); - versions.push_back(pair.second); +template +inline ValueInfoConsumerProducerInfo ConstValueInfoImpl::GetProducerNode() const { + ValueInfoConsumerProducerInfo info; + const OrtNode* producer; + size_t index; + ThrowOnError(GetApi().ValueInfo_GetValueProducer(this->p_, &producer, &index)); + info.node = ConstNode(producer); + info.index = static_cast(index); + return info; +} + +template +inline std::vector ConstValueInfoImpl::GetConsumers() const { + size_t num = 0; + ThrowOnError(GetApi().ValueInfo_GetValueNumConsumers(this->p_, &num)); + std::vector out; + if (num > 0) { + std::vector nodes(num); + std::vector indices(num); + ThrowOnError(GetApi().ValueInfo_GetValueConsumers(this->p_, 
nodes.data(), indices.data(), num)); + out.reserve(num); + for (size_t i = 0; i < num; ++i) { + out.push_back({ConstNode{nodes[i]}, indices[i]}); + } } + return out; +} - ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); +template +inline Status ConstValueInfoImpl::GetInitializer(ConstValue& value) const { + const OrtValue* out = nullptr; + auto status = Status(GetApi().ValueInfo_GetInitializerValue(this->p_, &out)); + if (!status.IsOK()) return status; + value = ConstValue{out}; + return status; } -inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { - ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_)); +template +inline Status ConstValueInfoImpl::GetExternalInitializerInfo(ExternalInitializerInfo& info) const { + OrtExternalInitializerInfo* out = nullptr; + auto status = Status(GetApi().ValueInfo_GetExternalInitializerInfo(this->p_, &out)); + if (!status.IsOK()) return status; + info = ExternalInitializerInfo{out}; + return status; } -#endif // !defined(ORT_MINIMAL_BUILD) -namespace detail { -template <> -inline std::string ValueInfoImpl::Name() const { - const char* name = nullptr; - ThrowOnError(GetApi().GetValueInfoName(this->p_, &name)); - return name; +template +inline bool ConstValueInfoImpl::IsRequiredGraphInput() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsRequiredGraphInput(this->p_, &out)); + return out; } -template <> -inline ConstTypeInfo ValueInfoImpl::TypeInfo() const { - const OrtTypeInfo* type_info = nullptr; - ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info)); - return ConstTypeInfo{type_info}; +template +inline bool ConstValueInfoImpl::IsOptionalGraphInput() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsOptionalGraphInput(this->p_, &out)); + return out; +} + +template +inline bool ConstValueInfoImpl::IsGraphOutput() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsGraphOutput(this->p_, &out)); + return out; +} + +template +inline bool ConstValueInfoImpl::IsConstantInitializer() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsConstantInitializer(this->p_, &out)); + return out; +} + +template +inline bool ConstValueInfoImpl::IsFromOuterScope() const { + bool out = false; + ThrowOnError(GetApi().ValueInfo_IsFromOuterScope(this->p_, &out)); + return out; +} + +template +inline ModelMetadata ConstGraphImpl::GetModelMetadata() const { + OrtModelMetadata* out; + ThrowOnError(GetApi().Graph_GetModelMetadata(this->p_, &out)); + return ModelMetadata{out}; +} + +template +inline std::string ConstGraphImpl::GetName() const { + const char* name; + ThrowOnError(GetApi().Graph_GetName(this->p_, &name)); + return std::string(name); +} + +template +inline std::basic_string ConstGraphImpl::GetModelPath() const { + const ORTCHAR_T* path; + ThrowOnError(GetApi().Graph_GetModelPath(this->p_, &path)); + return std::basic_string(path); +} + +template +inline int64_t ConstGraphImpl::GetOnnxIRVersion() const { + int64_t version; + ThrowOnError(GetApi().Graph_GetOnnxIRVersion(this->p_, &version)); + return version; +} + +template +inline std::vector ConstGraphImpl::GetOperatorSets() const { + size_t num_opsets; + ThrowOnError(GetApi().Graph_GetNumOperatorSets(this->p_, &num_opsets)); + std::vector result; + if (num_opsets > 0) { + std::vector domains; + std::vector versions; + domains.resize(num_opsets); + versions.resize(num_opsets); + ThrowOnError(GetApi().Graph_GetOperatorSets(this->p_, 
domains.data(), versions.data(), num_opsets)); + result.reserve(num_opsets); + for (size_t i = 0; i < num_opsets; ++i) { + result.push_back({domains[i], versions[i]}); + } + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetInputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Graph_GetNumInputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Graph_GetInputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetOutputs() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Graph_GetNumOutputs(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Graph_GetOutputs(this->p_, reinterpret_cast(result.data()), num_vi)); + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetInitializers() const { + static_assert(sizeof(const OrtValueInfo*) == sizeof(ConstValueInfo)); + size_t num_vi; + ThrowOnError(GetApi().Graph_GetNumInitializers(this->p_, &num_vi)); + std::vector result; + if (num_vi > 0) { + result.resize(num_vi); + ThrowOnError(GetApi().Graph_GetInitializers(this->p_, reinterpret_cast(result.data()), + num_vi)); + } + return result; +} + +template +inline std::vector ConstGraphImpl::GetNodes() const { + static_assert(sizeof(const OrtNode*) == sizeof(ConstNode)); + size_t num_nodes; + ThrowOnError(GetApi().Graph_GetNumNodes(this->p_, &num_nodes)); + std::vector result; + if (num_nodes > 0) { + result.resize(num_nodes); + ThrowOnError(GetApi().Graph_GetNodes(this->p_, reinterpret_cast(result.data()), num_nodes)); + } + return result; +} + +template +inline ConstNode ConstGraphImpl::GetParentNode() const { + const OrtNode* parent; + ThrowOnError(GetApi().Graph_GetParentNode(this->p_, &parent)); + return ConstNode{parent}; +} + +template +inline Graph ConstGraphImpl::GetGraphView(const std::vector& nodes) const { + OrtGraph* graph_viewer; + std::vector inputs_ptrs; + inputs_ptrs.reserve(nodes.size()); + std::transform(nodes.begin(), nodes.end(), std::back_inserter(inputs_ptrs), + [](ConstNode n) -> const OrtNode* { return n; }); + ThrowOnError(GetApi().Graph_GetGraphView(this->p_, inputs_ptrs.data(), + nodes.size(), &graph_viewer)); + return Graph{graph_viewer}; } #if !defined(ORT_MINIMAL_BUILD) -template <> -inline void GraphImpl::SetInputs(std::vector& inputs) { +template +inline void GraphImpl::SetInputs(std::vector& inputs) { std::vector inputs_ptrs; inputs_ptrs.reserve(inputs.size()); std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs), [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); - ThrowOnError(GetModelEditorApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size())); + ThrowOnError(GetModelEditorApi().SetGraphInputs(this->p_, inputs_ptrs.data(), inputs_ptrs.size())); // Graph now owns the inputs std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); }); } -template <> -inline void GraphImpl::SetOutputs(std::vector& outputs) { +template +inline void GraphImpl::SetOutputs(std::vector& outputs) { std::vector outputs_ptrs; outputs_ptrs.reserve(outputs.size()); std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs), [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); - ThrowOnError(GetModelEditorApi().SetGraphOutputs(p_, 
outputs_ptrs.data(), outputs_ptrs.size())); + ThrowOnError(GetModelEditorApi().SetGraphOutputs(this->p_, outputs_ptrs.data(), outputs_ptrs.size())); // Graph now owns the outputs std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); }); } -template <> -inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { +template +inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { // Graph takes ownership of `initializer` - ThrowOnError(GetModelEditorApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external)); + // On error the ownership is not transferred. + ThrowOnError(GetModelEditorApi().AddInitializerToGraph(this->p_, name.c_str(), initializer, data_is_external)); + initializer.release(); } -template <> -inline void GraphImpl::AddNode(Node& node) { +template +inline void GraphImpl::AddNode(Node& node) { // Graph takes ownership of `node` - ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release())); + ThrowOnError(GetModelEditorApi().AddNodeToGraph(this->p_, node.release())); } template -inline ModelMetadata GraphImpl::GetModelMetadata() const { - OrtModelMetadata* out; - ThrowOnError(GetApi().Graph_GetModelMetadata(this->p_, &out)); - return ModelMetadata{out}; -} - -template <> -inline void ModelImpl::AddGraph(Graph& graph) { +inline void ModelImpl::AddGraph(Graph& graph) { // Model takes ownership of `graph` - ThrowOnError(GetModelEditorApi().AddGraphToModel(p_, graph.release())); + ThrowOnError(GetModelEditorApi().AddGraphToModel(this->p_, graph.release())); } #endif // !defined(ORT_MINIMAL_BUILD) } // namespace detail + +#if !defined(ORT_MINIMAL_BUILD) +inline Graph::Graph() { + ThrowOnError(GetModelEditorApi().CreateGraph(&p_)); +} + +inline Model::Model(const std::vector& opsets) { + std::vector domains; + std::vector versions; + domains.reserve(opsets.size()); + versions.reserve(opsets.size()); + + for (const auto& pair : opsets) { + domains.push_back(pair.first.c_str()); + versions.push_back(pair.second); + } + + ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); +} +#endif + } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index 672103bedc437..bbd6a43bb7a41 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -12,4 +12,7 @@ static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; // Prefix for execution provider compatibility information stored in model metadata. // Used when generating EP context models to store compatibility strings for each EP. // Full key format: "ep_compatibility_info." 
-static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; \ No newline at end of file +static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; + +// Key for the execution provider library path (for dynamically loaded EPs) +static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path"; diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 7de9cfa14927d..550502cf3bc48 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -33,6 +33,7 @@ OrtCompileApiFlags, # noqa: F401 OrtEpDevice, # noqa: F401 OrtExecutionProviderDevicePolicy, # noqa: F401 + OrtExternalInitializerInfo, # noqa: F401 OrtHardwareDevice, # noqa: F401 OrtHardwareDeviceType, # noqa: F401 OrtMemoryInfo, # noqa: F401 diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 1a737f3a9d251..34410a5f42630 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -106,6 +106,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE); // ******** End: Quantization ******************* // #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -271,6 +273,8 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h new file mode 100644 index 0000000000000..eae96c186d471 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
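+//
+// MoEBaseCPU is a small base class shared by the CPU MoE kernels. It parses the attributes common to
+// MoE and quantized MoE (k, activation_type, normalize_routing_weights, use_sparse_mixer, and the
+// SwiGLU-related settings) and validates their combinations at kernel construction time.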
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/moe/moe_helper.h" +#include + +namespace onnxruntime { +namespace contrib { + +enum class ActivationType { + Relu = 0, + Gelu = 1, + Silu = 2, + Identity = 3, + SwiGLU = 4, +}; + +class MoEBaseCPU { + protected: + MoEBaseCPU(const OpKernelInfo& op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK()); + + std::string activation_type_str; + ORT_ENFORCE(op_kernel_info.GetAttr("activation_type", &activation_type_str).IsOK()); + if (activation_type_str == "relu") { + activation_type_ = ActivationType::Relu; + } else if (activation_type_str == "gelu") { + activation_type_ = ActivationType::Gelu; + } else if (activation_type_str == "silu") { + activation_type_ = ActivationType::Silu; + } else if (activation_type_str == "identity") { + activation_type_ = ActivationType::Identity; + } else if (activation_type_str == "swiglu") { + activation_type_ = ActivationType::SwiGLU; + } else { + ORT_THROW("Unsupported MoE activation type: ", activation_type_str); + } + + normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault("normalize_routing_weights", 0) == 1; + + use_sparse_mixer_ = op_kernel_info.GetAttrOrDefault("use_sparse_mixer", 0) == 1; + if (use_sparse_mixer_) { + ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2"); + } + + swiglu_fusion_ = op_kernel_info.GetAttrOrDefault("swiglu_fusion", 0); + swiglu_limit_ = op_kernel_info.GetAttrOrDefault("swiglu_limit", std::numeric_limits::infinity()); + activation_alpha_ = op_kernel_info.GetAttrOrDefault("activation_alpha", 1.0f); + activation_beta_ = op_kernel_info.GetAttrOrDefault("activation_beta", 0.0f); + } + + bool normalize_routing_weights_; + bool use_sparse_mixer_; + int64_t k_; + ActivationType activation_type_; + float activation_alpha_; + float activation_beta_; + float swiglu_limit_; + int64_t swiglu_fusion_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h new file mode 100644 index 0000000000000..e494719464d20 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
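+//
+// Shape and parameter validation helpers shared by the CPU MoE and quantized MoE (qMoE) kernels.
+// MoEParameters captures the dimensions derived from the inputs, and moe_helper::CheckInputs verifies
+// that the expert weight, bias, and scale tensors are consistent with those dimensions.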
+ +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "core/framework/tensor_shape.h" +#include "core/util/shape_checker.h" + +namespace onnxruntime { +namespace contrib { + +enum class MoEParallelType { + None = 0, + EP = 1, + TP = 2, + EPAndTP = 3, +}; + +struct MoEParameters { + MoEParameters() = default; + + explicit MoEParameters(int64_t tensor_shards) + : tensor_shards(tensor_shards) {} + + int64_t num_rows{0}; + int64_t num_experts{0}; + int64_t local_num_experts{0}; + int64_t hidden_size{0}; + int64_t inter_size{0}; + + MoEParallelType parallel_type{MoEParallelType::None}; + int64_t tensor_shards{1}; +}; +namespace moe_helper { + +template +Status CheckInputs(MoEParameters& parameters, + const Tensor* input, // required + const Tensor* router_probs, // required + const Tensor* fc1_experts_weights, // required + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_experts_weights, // required + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_experts_weights, // optional + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const bool is_fused_swiglu) { + // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. + ASSERT_TENSOR_2D_OR_3D(input); + ASSERT_TENSOR_3D(fc1_experts_weights); + ASSERT_TENSOR_3D(fc2_experts_weights); + ASSERT_TENSOR_2D(router_probs); + + const auto& input_dims = input->Shape().GetDims(); + const auto& router_probs_dims = router_probs->Shape().GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); + const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); + + int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; + int64_t hidden_size = input_dims[input_dims.size() - 1]; + int64_t local_num_experts = fc1_experts_weights_dims[0]; + int64_t num_experts = router_probs_dims[1]; + int64_t inter_size = (fc2_experts_weights_dims[1] * fc2_experts_weights_dims[2] * pack_size) / hidden_size; + + const bool legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || + (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); + + // Fused swiglu doubles the output dimension of FC1 since it fused two GEMMs into one. + const int64_t fc1_inter_size = is_fused_swiglu ? (inter_size + inter_size) : inter_size; + + if (legacy_shape) { + // legacy shape does not match column major memory layout. This is for backward compatibility. 
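+    // In the legacy layout the expert weights are stored as [num_experts, in_features, out_features / pack_size],
+    // whereas the current layout (checked in the else branch below) is [num_experts, out_features, in_features / pack_size].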
+ CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size); + CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size); + CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size); + } else { + CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size); + CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size); + CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size); + } + + CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts); + + CHECK_TENSOR_SHAPE(fc1_experts_bias, num_experts, fc1_inter_size); + CHECK_TENSOR_SHAPE(fc2_experts_bias, num_experts, hidden_size); + CHECK_TENSOR_SHAPE(fc3_experts_bias, num_experts, inter_size); + + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size); + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size); + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size); + + if (fc3_experts_weights == nullptr) { + ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr); + } else { + ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr); // MOE no scale, or qMOE need scales + } + + parameters.num_rows = num_rows; + parameters.num_experts = num_experts; + parameters.local_num_experts = local_num_experts; + parameters.hidden_size = hidden_size; + parameters.inter_size = inter_size; + if (num_experts == local_num_experts) { + if (parameters.tensor_shards == 1) { + parameters.parallel_type = MoEParallelType::None; + } else { + parameters.parallel_type = MoEParallelType::TP; + } + } else if (num_experts > local_num_experts) { + if (parameters.tensor_shards == 1) { + parameters.parallel_type = MoEParallelType::EP; + } else { + parameters.parallel_type = MoEParallelType::EPAndTP; + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_experts must be greater than or equal to local_num_experts, got ", num_experts, + " and ", local_num_experts); + } + + return Status::OK(); +} + +} // namespace moe_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc new file mode 100644 index 0000000000000..5c6c3b919b572 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -0,0 +1,400 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/moe/moe_quantization_cpu.h" + +#include "core/framework/allocator.h" +#include "core/framework/float16.h" +#include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "contrib_ops/cpu/moe/moe_utils.h" +#include "contrib_ops/cpu/moe/moe_helper.h" + +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +// Helper function to dequantize weights. Supports 4-bit and 8-bit symmetric quantization. +// The source quantized weights are stored as a row-major representation of the transposed +// logical weight matrix (W^T). This function dequantizes it into a float row-major W^T matrix. +template +void DequantizeBlock(const uint8_t* quantized_data, + const TScale* scales, + int64_t /*block_size*/, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data) { + const float zero_point = num_bits == 8 ? 
128.0f : 8.0f; + if (num_bits == 8) { + for (int64_t r = 0; r < rows; ++r) { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; ++c) { + // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) + dequantized_data[r * cols + c] = scale * (static_cast(quantized_data[r * cols + c]) - zero_point); + } + } + } else if (num_bits == 4) { + const int64_t packed_cols = (cols + 1) / 2; + for (int64_t r = 0; r < rows; ++r) { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; ++c) { + const uint8_t packed_val = quantized_data[r * packed_cols + c / 2]; + // Unpack the 4-bit value. Low nibble for even columns, high nibble for odd columns. + const uint8_t quantized_val = (c % 2 == 0) ? (packed_val & 0x0F) : (packed_val >> 4); + // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) + dequantized_data[r * cols + c] = scale * (static_cast(quantized_val) - zero_point); + } + } + } +} + +template +QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) + : OpKernel(op_kernel_info), + MoEBaseCPU(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); + ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, + "Attribute 'expert_weight_bits' must be 4 or 8."); + block_size_ = op_kernel_info.GetAttrOrDefault("block_size", 0); +} + +template +Status QMoECPU::Compute(OpKernelContext* context) const { + // --- 1. Get Inputs and Attributes --- + const auto* input = context->Input(0); + const auto* router_probs = context->Input(1); + const auto* fc1_experts_weights = context->Input(2); + const auto* fc1_scales = context->Input(3); + const auto* fc1_experts_bias = context->Input(4); + const auto* fc2_experts_weights = context->Input(5); + const auto* fc2_scales = context->Input(6); + const auto* fc2_experts_bias = context->Input(7); + const auto* fc3_experts_weights = context->Input(8); + const auto* fc3_scales = context->Input(9); + const auto* fc3_experts_bias = context->Input(10); + + MoEParameters moe_params; + ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias, fc1_scales, + fc2_experts_weights, fc2_experts_bias, fc2_scales, + fc3_experts_weights, fc3_experts_bias, fc3_scales, + expert_weight_bits_ == 4 ? 2 : 1, + true)); + + if (fc3_experts_weights || fc3_experts_bias || fc3_scales) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); + } + + const auto& input_shape = input->Shape(); + const int64_t num_tokens = moe_params.num_rows; + const int64_t hidden_size = moe_params.hidden_size; + const int64_t inter_size = moe_params.inter_size; + const int64_t num_experts = moe_params.num_experts; + const int64_t fc1_out_features = inter_size * (swiglu_fusion_ > 0 ? 2 : 1); + + auto* output = context->Output(0, input_shape); + auto* tp = context->GetOperatorThreadPool(); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + const size_t output_buffer_size = static_cast(output->Shape().Size()); + + const T* input_data = input->Data(); + const T* router_probs_data = router_probs->Data(); + + // --- 2. 
Routing Logic: Assign tokens to experts --- + IAllocatorUniquePtr router_logits_float_buffer; + const float* router_logits_float; + if constexpr (std::is_same_v) { + router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); + router_logits_float = router_logits_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs_data), + const_cast(router_logits_float), + static_cast(num_tokens * num_experts)); + } else { + router_logits_float = reinterpret_cast(router_probs_data); + } + + auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + int* route_expert = route_expert_ptr.get(); + auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + float* route_scale = route_scale_ptr.get(); + + // Parallelize the routing logic to improve performance for large token batches. + // Minor performance regression for single-token decoding is an acceptable trade-off + int num_routing_threads = (tp == nullptr || num_tokens < 4096) ? 1 : std::min(static_cast(num_tokens), concurrency::ThreadPool::DegreeOfParallelism(tp)); + + std::vector>> thread_local_expert_token_maps(num_routing_threads); + for (auto& map : thread_local_expert_token_maps) { + map.resize(static_cast(num_experts)); + } + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { + auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); + auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; + + // Pre-allocate buffers for this thread to reuse, avoiding allocations inside the loop. + std::vector> sorted_logits(static_cast(num_experts)); + std::vector top_k_exp(static_cast(k_)); + + for (int64_t i = work.start; i < work.end; ++i) { + const float* logits = router_logits_float + i * num_experts; + for (int64_t j = 0; j < num_experts; ++j) { + sorted_logits[static_cast(j)] = {logits[j], j}; + } + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), sorted_logits.end(), std::greater<>()); + + float max_logit = -std::numeric_limits::infinity(); + for (int64_t j = 0; j < k_; ++j) { + if (sorted_logits[static_cast(j)].first > max_logit) { + max_logit = sorted_logits[static_cast(j)].first; + } + } + + float sum_exp = 0.0f; + for (int64_t j = 0; j < k_; ++j) { + top_k_exp[static_cast(j)] = std::exp(sorted_logits[static_cast(j)].first - max_logit); + sum_exp += top_k_exp[static_cast(j)]; + } + + float scale = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); + for (int64_t j = 0; j < k_; ++j) { + int64_t expert_idx = sorted_logits[static_cast(j)].second; + int64_t route_idx = i * k_ + j; + route_expert[route_idx] = static_cast(expert_idx); + route_scale[route_idx] = top_k_exp[static_cast(j)] * scale; + if (route_scale[route_idx] > 0.0f) { + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } + } + } + }); + + // Merge the maps from each thread into a single global map. 
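The per-token routing above selects the top-k logits with partial_sort and then applies a softmax restricted to those k entries. A standalone single-token sketch with hypothetical logits follows; the merge of the per-thread maps continues after it:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

int main() {
  const std::vector<float> logits = {0.1f, 2.0f, -1.0f, 1.5f};  // hypothetical scores for 4 experts
  const int k = 2;

  std::vector<std::pair<float, int>> sorted(logits.size());
  for (size_t j = 0; j < logits.size(); ++j) sorted[j] = {logits[j], static_cast<int>(j)};
  std::partial_sort(sorted.begin(), sorted.begin() + k, sorted.end(), std::greater<>());

  const float max_logit = sorted[0].first;  // partial_sort puts the largest logit first
  float sum = 0.0f;
  std::vector<float> e(k);
  for (int j = 0; j < k; ++j) {
    e[j] = std::exp(sorted[j].first - max_logit);  // softmax over the top-k only
    sum += e[j];
  }
  for (int j = 0; j < k; ++j) {
    std::printf("expert %d weight %.3f\n", sorted[j].second, e[j] / sum);  // experts 1 and 3
  }
  return 0;
}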
+ std::vector> expert_token_map(static_cast(num_experts)); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + size_t total_tokens_for_expert = 0; + for (int t = 0; t < num_routing_threads; ++t) { + total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); + } + expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); + } + + for (int t = 0; t < num_routing_threads; ++t) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; + if (!local_tokens.empty()) { + expert_token_map[static_cast(expert_idx)].insert(expert_token_map[static_cast(expert_idx)].end(), local_tokens.begin(), local_tokens.end()); + } + } + } + + // --- 3. Parallel Expert Computation --- + IAllocatorUniquePtr input_float_buffer; + const float* input_float; + if constexpr (std::is_same_v) { + input_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + input_float = input_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(input_data), + const_cast(input_float), + static_cast(num_tokens * hidden_size)); + } else { + input_float = reinterpret_cast(input_data); + } + + int num_expert_threads = (tp == nullptr) ? 1 : std::min(static_cast(num_experts), concurrency::ThreadPool::DegreeOfParallelism(tp)); + if (num_expert_threads == 0) num_expert_threads = 1; + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); + float* thread_local_outputs = thread_local_outputs_ptr.get(); + memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); + + // Pre-calculate workspace size per thread to avoid allocations inside the loop + size_t max_tokens_per_expert = 0; + for (const auto& tokens : expert_token_map) { + if (tokens.size() > max_tokens_per_expert) { + max_tokens_per_expert = tokens.size(); + } + } + + const size_t A1_size = static_cast(max_tokens_per_expert * hidden_size); + const size_t C1_size = static_cast(max_tokens_per_expert * fc1_out_features); + const size_t A2_size = static_cast(max_tokens_per_expert * inter_size); + const size_t C2_size = static_cast(max_tokens_per_expert * hidden_size); + const size_t B1_dequant_size = static_cast(fc1_out_features * hidden_size); + const size_t B2_dequant_size = static_cast(hidden_size * inter_size); + const size_t bias1_size = static_cast(fc1_out_features); + const size_t bias2_size = static_cast(hidden_size); + + const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size + bias1_size + bias2_size; + auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * workspace_elements_per_thread); + float* workspace = workspace_ptr.get(); + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { + int thread_id = static_cast(thread_id_pd); + auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + + float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; + + for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + const auto& routes = expert_token_map[static_cast(expert_idx)]; + if (routes.empty()) { + continue; + } + + const int64_t num_expert_tokens = routes.size(); + + // Partition the workspace for 
the current expert + float* A1 = thread_workspace; + float* C1 = A1 + num_expert_tokens * hidden_size; + float* A2 = C1 + num_expert_tokens * fc1_out_features; + float* C2 = A2 + num_expert_tokens * inter_size; + float* B1_dequant = C2 + num_expert_tokens * hidden_size; + float* B2_dequant = B1_dequant + fc1_out_features * hidden_size; + float* bias1_float = B2_dequant + hidden_size * inter_size; + float* bias2_float = bias1_float + fc1_out_features; + + // --- Gather input tokens for the current expert --- + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + memcpy(A1 + i * hidden_size, + input_float + token_idx * hidden_size, + static_cast(hidden_size) * sizeof(float)); + } + + // --- FC1 GEMM (X * W1^T) --- + DequantizeBlock(fc1_experts_weights->Data() + expert_idx * fc1_out_features * (hidden_size / (8 / expert_weight_bits_)), + fc1_scales->Data() + expert_idx * fc1_out_features * (block_size_ > 0 ? hidden_size / block_size_ : 1), + block_size_, expert_weight_bits_, + fc1_out_features, hidden_size, B1_dequant); + + MlasGemm(CblasNoTrans, CblasTrans, + static_cast(num_expert_tokens), static_cast(fc1_out_features), static_cast(hidden_size), + 1.0f, A1, static_cast(hidden_size), + B1_dequant, static_cast(hidden_size), + 0.0f, C1, static_cast(fc1_out_features), + nullptr); + + const T* B1_bias = (fc1_experts_bias) ? fc1_experts_bias->Data() + expert_idx * fc1_out_features : nullptr; + if (B1_bias) { + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), bias1_float, static_cast(fc1_out_features)); + } else { + memcpy(bias1_float, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } + for (int64_t i = 0; i < num_expert_tokens; ++i) { + for (int64_t j = 0; j < fc1_out_features; ++j) { + C1[i * fc1_out_features + j] += bias1_float[j]; + } + } + } + + // --- Activation --- + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + + // --- FC2 GEMM (A2 * W2^T) --- + DequantizeBlock(fc2_experts_weights->Data() + expert_idx * hidden_size * (inter_size / (8 / expert_weight_bits_)), + fc2_scales->Data() + expert_idx * hidden_size * (block_size_ > 0 ? inter_size / block_size_ : 1), + block_size_, expert_weight_bits_, + hidden_size, inter_size, B2_dequant); + + MlasGemm(CblasNoTrans, CblasTrans, + static_cast(num_expert_tokens), static_cast(hidden_size), static_cast(inter_size), + 1.0f, A2, static_cast(inter_size), + B2_dequant, static_cast(inter_size), + 0.0f, C2, static_cast(hidden_size), + nullptr); + + const T* B2_bias = (fc2_experts_bias) ? 
fc2_experts_bias->Data() + expert_idx * hidden_size : nullptr; + if (B2_bias) { + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), bias2_float, static_cast(hidden_size)); + } else { + memcpy(bias2_float, B2_bias, static_cast(hidden_size) * sizeof(float)); + } + } + + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t route_idx = routes[static_cast(i)]; + const int64_t token_idx = route_idx / k_; + const float weight = route_scale[route_idx]; + + const size_t buffer_offset = static_cast(token_idx) * static_cast(hidden_size); + if (buffer_offset + static_cast(hidden_size) > output_buffer_size) { + // Skip this token to prevent buffer overflow + continue; + } + + float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; + const float* src = C2 + i * hidden_size; + for (int64_t j = 0; j < hidden_size; ++j) { + dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); + } + } + } + }); + + // --- 4. Final Reduction (accumulate expert outputs to a float buffer) --- + auto accumulate = [&](float* buffer) { + memset(buffer, 0, output_buffer_size * sizeof(float)); + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + for (size_t j = 0; j < output_buffer_size; ++j) { + buffer[j] += thread_local_outputs[thread_offset + j]; + } + } + }; + + if constexpr (std::is_same_v) { + auto final_output_float_ptr = IAllocator::MakeUniquePtr(allocator, output_buffer_size); + float* final_output_float = final_output_float_ptr.get(); + accumulate(final_output_float); + + // --- 5. Convert final float buffer to output type T --- + MlasConvertFloatToHalfBuffer(final_output_float, + reinterpret_cast(output->MutableData()), + static_cast(output_buffer_size)); + } else { // T is float + accumulate(output->MutableData()); + } + + return Status::OK(); +} + +// Explicit template instantiation +template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); +template Status QMoECPU::Compute(OpKernelContext* context) const; +template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); +template Status QMoECPU::Compute(OpKernelContext* context) const; + +// Kernel Registration +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, kMSDomain, 1, float, kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoECPU); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, kMSDomain, 1, MLFloat16, kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoECPU); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h new file mode 100644 index 0000000000000..890580e051a8e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/moe/moe_base_cpu.h" + +namespace onnxruntime { +namespace contrib { + +/** + * @brief QMoE is the templated CPU implementation of the Quantized Mixture of Experts operator. 
+ *
+ * This kernel supports both float and MLFloat16 data types for activations, scales, and outputs.
+ * It parallelizes expert computation using the ONNX Runtime thread pool and minimizes memory
+ * usage through on-the-fly block dequantization of weights.
+ *
+ * @tparam T The data type for the kernel (float or MLFloat16).
+ */
+template <typename T>
+class QMoECPU final : public OpKernel, public MoEBaseCPU {
+ public:
+  explicit QMoECPU(const OpKernelInfo& op_kernel_info);
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  int64_t expert_weight_bits_;
+  int64_t block_size_;
+};
+
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
new file mode 100644
index 0000000000000..2c59210bfabd4
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
@@ -0,0 +1,52 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/moe/moe_utils.h"
+#include <algorithm>
+#include <cmath>
+#include "core/common/common.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+float ApplyActivation(float x, ActivationType activation_type) {
+  switch (activation_type) {
+    case ActivationType::Relu:
+      return std::max(0.0f, x);
+    case ActivationType::Gelu:
+      return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
+    case ActivationType::Silu:
+      return x * (1.0f / (1.0f + std::exp(-x)));
+    case ActivationType::Identity:
+      return x;
+    case ActivationType::SwiGLU:
+      // SwiGLU is a special case handled by ApplySwiGLUActivation; this is just a placeholder.
+      return x;
+    default:
+      return x;
+  }
+}
+
+void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format,
+                           float activation_alpha, float activation_beta, float clamp_limit) {
+  if (is_interleaved_format) {
+    for (int64_t i = 0; i < inter_size; ++i) {
+      float gate_val = input_data[2 * i];
+      float linear_val = input_data[2 * i + 1];
+
+      gate_val = std::min(gate_val, clamp_limit);
+      linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit);
+
+      float sigmoid_arg = activation_alpha * gate_val;
+      float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
+      float swish_out = gate_val * sigmoid_out;
+
+      output_data[i] = swish_out * (linear_val + activation_beta);
+    }
+  } else {
+    ORT_NOT_IMPLEMENTED("Non-interleaved format not supported for SwiGLU activation");
+  }
+}
+
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.h b/onnxruntime/contrib_ops/cpu/moe/moe_utils.h
new file mode 100644
index 0000000000000..de238e8d7ae66
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.h
@@ -0,0 +1,17 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
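ApplySwiGLUActivation above computes swish(alpha * gate) * (linear + beta) over interleaved (gate, linear) pairs, clamping both inputs first. A standalone numeric sketch of that formula, with illustrative alpha/beta/limit values rather than the node's attribute defaults:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const float in[4] = {1.0f, 0.5f, -2.0f, 3.0f};  // two (gate, linear) pairs => inter_size == 2
  const float alpha = 1.702f, beta = 1.0f, limit = 7.0f;  // illustrative values only
  float out[2];
  for (int i = 0; i < 2; ++i) {
    const float gate = std::min(in[2 * i], limit);                 // gate clamped from above
    const float linear = std::clamp(in[2 * i + 1], -limit, limit);  // linear clamped on both sides
    const float swish = gate / (1.0f + std::exp(-alpha * gate));    // gate * sigmoid(alpha * gate)
    out[i] = swish * (linear + beta);
  }
  std::printf("%.4f %.4f\n", out[0], out[1]);
  return 0;
}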
+ +#pragma once +#include +#include "contrib_ops/cpu/moe/moe_base_cpu.h" + +namespace onnxruntime { +namespace contrib { + +float ApplyActivation(float x, ActivationType activation_type); + +void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t inter_size, bool is_interleaved_format, + float activation_alpha, float activation_beta, float clamp_limit); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 1a4a38282fcc1..51252dc2b0467 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -197,19 +197,33 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All has_zp_input_, nullptr, nullptr); is_packed = true; } else if (compute_type_ == SQNBIT_CompInt8) { -#ifdef MLAS_TARGET_AMD64_IX86 - if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - auto sptr = tensor.Data(); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, - has_zp_input_, nullptr, nullptr); - is_packed = false; - } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { - auto zptr = tensor.Data(); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, - has_zp_input_, zptr, nullptr); - is_packed = false; + // Packing scales and zero points + bool should_pack_scale_and_zp_inputs = [&]() { +#if defined(MLAS_TARGET_AMD64_IX86) + return true; +#else + return (nbits_ == 8); +#endif + }(); + + if (should_pack_scale_and_zp_inputs) { + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, + has_zp_input_, nullptr, nullptr); + is_packed = false; + } + + // Packing zero_point + if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, + has_zp_input_, zptr, nullptr); + is_packed = false; + } } -#elif defined(MLAS_TARGET_ARM64) + +#if defined(MLAS_TARGET_ARM64) if (input_idx == InputIndex::scales && packed_b_ != nullptr && MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, has_zp_input_)) { scales_are_packed_ = true; diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 1a4a63de38790..93d802ca05b42 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -71,15 +71,21 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* fc3_experts_bias_optional = context->Input(7); MoEParameters moe_params(tensor_shards_); - MoEQuantType quant_type = MoEQuantType::None; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, - fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, - fc3_experts_weights_optional, fc3_experts_bias_optional)); + ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias_optional, nullptr, + fc2_experts_weights, fc2_experts_bias_optional, nullptr, + fc3_experts_weights_optional, 
fc3_experts_bias_optional, nullptr, + 1, // no quantization so pack size is 1 + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_, use_sparse_mixer_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index f3346d4513261..36d6fc378d45e 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -15,6 +15,10 @@ using namespace onnxruntime::common; ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, type, name) #define CUDA_MS_OP_VERSIONED_CLASS_NAME(start_ver, end_ver, name) \ ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, name) +#define CUDA_MS_OP_TWO_TYPED_CLASS_NAME(ver, type1, type2, name) \ + ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, name) +#define CUDA_MS_OP_THREE_TYPED_CLASS_NAME(ver, type1, type2, type3, name) \ + ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type1, type2, type3, name) #define CUDA_ONNX_OP_TYPED_CLASS_NAME(ver, type, name) \ ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, ver, type, name) @@ -92,7 +96,9 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE); -class CUDA_MS_OP_CLASS_NAME(1, QMoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QMoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QMoE); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention); @@ -184,6 +190,25 @@ class CUDA_MS_OP_CLASS_NAME(1, GemmFloat8); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, float, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, MLFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, uint8_t, BFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int32_t, GatherBlockQuantized); +class 
CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, float, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, MLFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, UInt4x2, BFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int32_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, float, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, MLFloat16, int64_t, GatherBlockQuantized); +class CUDA_MS_OP_THREE_TYPED_CLASS_NAME(1, Int4x2, BFloat16, int64_t, GatherBlockQuantized); + #ifdef ENABLE_ATEN class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); #endif @@ -307,7 +332,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -404,6 +431,24 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index 36127054cfd5e..d5ad8161e100e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -52,6 +52,7 @@ enum class ActivationType { Gelu, GeGLU, ReGLU, SiGLU, + SwiGLU, Identity, InvalidType }; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu new file mode 100644 index 0000000000000..5f0a71147b366 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu new file mode 100644 index 0000000000000..4a84581127156 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu new file mode 100644 index 0000000000000..6c23127955ac2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, uint8_t>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index ef1f97b9e57a2..f855092670bc3 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -53,6 +53,8 @@ #include "cutlass_heuristic.h" #include "moe_gemm_kernels.h" +#include + #include #include #include @@ -66,8 +68,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, int* kernel_occupancy = nullptr) { - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float, bfloat16"); static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || @@ -76,12 +78,11 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. 
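The conditional typedefs that follow map half to cutlass::half_t and __nv_bfloat16 to cutlass::bfloat16_t while passing every other type through unchanged. A minimal sketch of the same selection pattern using std::conditional and placeholder tag types (assumed names, not the CUTLASS ones):

#include <type_traits>

struct half_tag {};     // stands in for half
struct bf16_tag {};     // stands in for __nv_bfloat16
struct cutlass_half {};  // stands in for cutlass::half_t
struct cutlass_bf16 {};  // stands in for cutlass::bfloat16_t

template <typename T>
using CutlassElement =
    typename std::conditional<std::is_same<T, half_tag>::value, cutlass_half,
        typename std::conditional<std::is_same<T, bf16_tag>::value, cutlass_bf16, T>::type>::type;

static_assert(std::is_same<CutlassElement<half_tag>, cutlass_half>::value, "half maps to the cutlass half type");
static_assert(std::is_same<CutlassElement<bf16_tag>, cutlass_bf16>::value, "bf16 maps to the cutlass bf16 type");
static_assert(std::is_same<CutlassElement<float>, float>::value, "float is passed through");

int main() { return 0; }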
using ElementType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + typename cutlass::platform::conditional::value, cutlass::half_t, typename cutlass::platform::conditional::value, cutlass::bfloat16_t, T>::type>::type; using ElementType = ElementType_; using CutlassWeightType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, - WeightType>::type; + typename cutlass::platform::conditional::value, cutlass::half_t, typename cutlass::platform::conditional::value, cutlass::bfloat16_t, WeightType>::type>::type; using CutlassWeightType = CutlassWeightType_; @@ -391,12 +392,10 @@ void MoeGemmRunner::dispatch_to_arch(const T* A, con dispatch_moe_gemm_to_cutlass( A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); - } else if (sm_ >= 80 && sm_ < 90) { + } else if (sm_ >= 80) { // Hopper and Blackwell will fallback to use Ampere kernels. dispatch_moe_gemm_to_cutlass( A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); - } else { - ORT_THROW("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); } } @@ -478,6 +477,7 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream) { + // Swiglu will use Identity to call this function so we not need to handle it here. switch (activation_type) { case ActivationType::Relu: run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index bfbe1d81b1c15..ce8c0270f5c32 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -38,12 +38,96 @@ #include "moe_kernel.h" +#include #include #include #include +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" + namespace ort_fastertransformer { static constexpr int WARP_SIZE = 32; + +// SwiGLU with interleaved is like the following python code using PyTorch: +// dim = x.shape[-1] +// x = x.view(-1, dim // 2, 2) +// x_glu, x_linear = x[..., 0], x[..., 1] +// y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) +template +__global__ void swiglu_kernel_interleaved(T* output, T const* input, int intermediate_size, int num_rows, float alpha, float limit) { + int const row = blockIdx.x; + if (row >= num_rows) { + return; + } + + T const* row_input = input + row * 2 * intermediate_size; + T* row_output = output + row * intermediate_size; + + for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) { + float glu = static_cast(row_input[2 * i]); + float linear = static_cast(row_input[2 * i + 1]); + + if constexpr (HasLimit) { + glu = fminf(glu, limit); + linear = fminf(fmaxf(linear, -limit), limit); + } + + float sigmoid_arg = alpha * glu; + float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg)); + + float swish_out = glu * sigmoid_out; + row_output[i] = static_cast(swish_out * (linear + 1.f)); + } +} + +// Non interleaved version of SwiGLU kernel, which splits each row into two chunks of same size. 
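The two SwiGLU kernels differ only in how a row of 2 * inter_size values is addressed: the interleaved layout stores (gate, linear) pairs, the chunked layout stores all gates followed by all linears. A host-side sketch of the two addressing schemes (clamping omitted for brevity); the chunked kernel itself follows below:

#include <cassert>
#include <cmath>

// Same math as the kernels: swish(alpha * gate) * (linear + 1).
float swiglu(float gate, float linear, float alpha) {
  return gate / (1.0f + std::exp(-alpha * gate)) * (linear + 1.0f);
}

int main() {
  const int inter_size = 4;
  const float interleaved[8] = {1, 5, 2, 6, 3, 7, 4, 8};  // (gate, linear) pairs
  const float chunked[8] = {1, 2, 3, 4, 5, 6, 7, 8};      // all gates, then all linears
  for (int i = 0; i < inter_size; ++i) {
    const float a = swiglu(interleaved[2 * i], interleaved[2 * i + 1], 1.702f);
    const float b = swiglu(chunked[i], chunked[i + inter_size], 1.702f);
    assert(a == b);  // same math, different addressing
  }
  return 0;
}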
+template +__global__ void swiglu_kernel_chunked(T* output, T const* input, int intermediate_size, int num_rows, float alpha, float limit) { + int const row = blockIdx.x; + if (row >= num_rows) { + return; + } + + T const* row_input = input + row * 2 * intermediate_size; + T* row_output = output + row * intermediate_size; + + for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) { + float glu = static_cast(row_input[i]); + float linear = static_cast(row_input[i + intermediate_size]); + + if constexpr (HasLimit) { + glu = fminf(glu, limit); + linear = fminf(fmaxf(linear, -limit), limit); + } + + float sigmoid_arg = alpha * glu; + float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg)); + + float swish_out = glu * sigmoid_out; + row_output[i] = static_cast(swish_out * (linear + 1.f)); + } +} + +template +void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float alpha, float limit, cudaStream_t stream) { + if (num_rows == 0) { + return; + } + dim3 block(std::min(intermediate_size, 1024)); + dim3 grid(num_rows); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("swiglu input", input, num_rows, 2 * intermediate_size); + + if constexpr (IsInterLeaved) { + swiglu_kernel_interleaved<<>>(output, input, intermediate_size, num_rows, alpha, limit); + } else { + swiglu_kernel_chunked<<>>(output, input, intermediate_size, num_rows, alpha, limit); + } + + DUMP_TENSOR("swiglu output", output, num_rows, intermediate_size); +} + // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. @@ -456,7 +540,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ if (normalize_routing_weights && k_idx == k - 1) { #pragma unroll for (int ki = 0; ki < k; ++ki) { - output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + float old_val = static_cast(output[idx - ki]); + output[idx - ki] = T(old_val / output_row_sum); } } } @@ -666,9 +751,14 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i } template -CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer) - : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) { + : activation_type_(activation_type), + has_fc3_(has_fc3), + total_past_rows_(0), + total_covered_rows_(0), + normalize_routing_weights_(normalize_routing_weights), + use_sparse_mixer_(use_sparse_mixer) { moe_gemm_runner_.initialize(sm_version); } @@ -695,8 +785,16 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro total_ws_bytes += buf_size * sizeof(T); // permuted_data total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ total_ws_bytes += num_softmax_outs * sizeof(T); - const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); - const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); + + size_t bytes_for_fc1_result; + if (activation_type_ == ActivationType::SwiGLU) { + // Space for both fc1_result_ and act_result_. + bytes_for_fc1_result = (2 * interbuf_size + interbuf_size) * sizeof(T); + } else { + bytes_for_fc1_result = has_fc3_ ? 
2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); + } + + const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)); sorter_.update_num_experts(static_cast(num_experts)); size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; @@ -705,7 +803,7 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro bytes_for_intermediate_and_sorting += remaining_bytes; } - total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + total_ws_bytes += bytes_for_intermediate_and_sorting; return total_ws_bytes; } @@ -725,27 +823,49 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); + char* current_ptr = reinterpret_cast(total_rows_before_expert_ + padded_experts); + + if (activation_type_ == ActivationType::SwiGLU) { + // fc1_result_ is used for GEMM1 output (2 * inter_size) + fc1_result_ = reinterpret_cast(current_ptr); + current_ptr += 2 * interbuf_size * sizeof(T); + + // act_result_ is used for SwiGLU output (inter_size) + act_result_ = reinterpret_cast(current_ptr); + current_ptr += interbuf_size * sizeof(T); + + ORT_ENFORCE(!has_fc3_, "SwiGLU activation is not supported with fc3"); + } else { + fc1_result_ = reinterpret_cast(current_ptr); + act_result_ = nullptr; // No extra buffer for activation since it is done inplace. + current_ptr += interbuf_size * sizeof(T); + } + if (has_fc3_) { - fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); - fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); + fc3_result_ = reinterpret_cast(current_ptr); + current_ptr += interbuf_size * sizeof(T); } else { - fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + fc3_result_ = nullptr; } const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { - softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + softmax_out_ = reinterpret_cast(current_ptr); } else { softmax_out_ = nullptr; } } namespace { - -struct __align__(8) Half4 { +typedef struct __CUDA_ALIGN__(8) { half2 x; half2 y; -}; +} half2_2; + +typedef struct __CUDA_ALIGN__(8) { + __nv_bfloat162 x; + __nv_bfloat162 y; +} __nv_bfloat162_2; // TODO(wy): move to common header template @@ -756,7 +876,11 @@ struct T4 { }; template <> struct T4 { - using Type = Half4; + using Type = half2_2; +}; +template <> +struct T4<__nv_bfloat16> { + using Type = __nv_bfloat162_2; }; template @@ -769,6 +893,10 @@ template <> struct T2 { using Type = half2; }; +template <> +struct T2<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); } @@ -785,15 +913,27 @@ inline __device__ half2 operator*(const half2 a, const half2 b) { return make_ha #endif // TODO(wy): use cuda common header and investigate pipeline build issue. 
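The operator* overloads below multiply the 4-wide pair types (half2_2, __nv_bfloat162_2) component-wise via __hmul2. A host-side sketch of the same pattern with plain float stand-ins:

#include <cassert>

struct Vec2 { float x, y; };
struct Vec2x2 { Vec2 x, y; };

// Component-wise products, mirroring the device overloads for half2 and the pair-of-pairs type.
inline Vec2 operator*(Vec2 a, Vec2 b) { return {a.x * b.x, a.y * b.y}; }
inline Vec2x2 operator*(Vec2x2 a, Vec2x2 b) { return {a.x * b.x, a.y * b.y}; }

int main() {
  const Vec2x2 a{{1, 2}, {3, 4}}, b{{5, 6}, {7, 8}};
  const Vec2x2 c = a * b;
  assert(c.x.x == 5 && c.x.y == 12 && c.y.x == 21 && c.y.y == 32);
  return 0;
}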
-inline __device__ Half4 operator*(const Half4 a, const Half4 b) { +inline __device__ half2_2 operator*(const half2_2 a, const half2_2 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) - Half4 result; + half2_2 result; + result.x = a.x * b.x; + result.y = a.y * b.y; + return result; +#else + return half2_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; +#endif +} + +inline __device__ __nv_bfloat162_2 operator*(const __nv_bfloat162_2 a, const __nv_bfloat162_2 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && \ + ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + __nv_bfloat162_2 result; result.x = a.x * b.x; result.y = a.y * b.y; return result; #else - return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; + return __nv_bfloat162_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; #endif } @@ -880,8 +1020,54 @@ void CutlassMoeFCRunner::run_moe_fc( stream); } - // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, - // expanded_active_expert_rows); + if (fc1_activation_type == ActivationType::SwiGLU) { + T* gemm1_output_buffer = fc1_result_; + T* swiglu_output_buffer = act_result_; + + moe_gemm_runner_.moe_gemm_bias_act( + permuted_data_ + total_past_rows_ * hidden_size, + fc1_expert_weights, + fc1_scales, + fc1_expert_biases, + gemm1_output_buffer + total_past_rows_ * 2 * inter_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, + 2 * inter_size, + hidden_size, + local_num_experts, + ActivationType::Identity, + stream); + + constexpr bool swiglu_interleaved = true; + constexpr bool swiglu_has_limit = true; + constexpr float swiglu_alpha = 1.702f; + constexpr float swiglu_limit = 7.0f; + invokeSwiGLU( + swiglu_output_buffer + total_past_rows_ * inter_size, + gemm1_output_buffer + total_past_rows_ * 2 * inter_size, + inter_size, + static_cast(total_covered_rows_), + swiglu_alpha, + swiglu_limit, + stream); + + moe_gemm_runner_.moe_gemm( + swiglu_output_buffer + total_past_rows_ * inter_size, + fc2_expert_weights, + fc2_scales, + nullptr, + fc2_result + total_past_rows_ * hidden_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, + hidden_size, + inter_size, + local_num_experts, + stream); + + // No fc3 for SwiGLU + return; + } + moe_gemm_runner_.moe_gemm_bias_act( permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index, @@ -1151,18 +1337,26 @@ template void topk_gating_softmax_kernelLauncher(const float*, const bool*, floa int, bool, bool, cudaStream_t); template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int, int, bool, bool, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*, const bool*, __nv_bfloat16*, __nv_bfloat16*, int*, int*, int, int, + int, bool, bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>; +// For qMoE: template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>; +template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>; // 
===================== Specializations for init routing ========================= template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int, cudaStream_t); template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int, cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const int*, int*, int, int, int, int, + cudaStream_t); // ==================== Specializations for final routing =================================== template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, @@ -1177,5 +1371,10 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl const float*, const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const half*, const int*, const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, + const __nv_bfloat16*, const int*, const int*, int, int, int, cudaStream_t); + +template void invokeSwiGLU(float*, float const*, int, int, float, float, cudaStream_t); +template void invokeSwiGLU(half*, half const*, int, int, float, float, cudaStream_t); } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index c457b608decbf..de11d357a8c07 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -54,7 +54,10 @@ static inline size_t pad_to_multiple_of_16(size_t input) { template void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out, int* indices, int* source_row, int num_rows, int num_experts, int k, - cudaStream_t stream); + bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream); + +template +void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha, cudaStream_t stream); class CubKeyValueSorter { public: @@ -109,7 +112,7 @@ template class CutlassMoeFCRunner { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); + CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k); @@ -157,8 +160,10 @@ class CutlassMoeFCRunner { int64_t* total_rows_before_expert_; T* fc1_result_; + T* act_result_; T* fc3_result_; + ActivationType activation_type_; bool has_fc3_; bool normalize_routing_weights_; bool use_sparse_mixer_; @@ -173,14 +178,4 @@ class CutlassMoeFCRunner { std::vector total_rows_before_expert_host_; }; -template -class CutlassMoeFCRunner::value>> { - public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); - - size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) { - return 0; - } -}; - } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index c5352d931ce2c..a5b9d483d5ad1 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ 
b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -3,6 +3,7 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_type_conversion.h" #include "moe.h" using namespace onnxruntime::cuda; @@ -20,6 +21,7 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { @@ -37,19 +39,25 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { const Tensor* fc3_experts_bias_optional = context->Input(7); MoEParameters moe_params; - MoEQuantType quant_type = MoEQuantType::None; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, - fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, - fc3_experts_weights_optional, fc3_experts_bias_optional)); - - typedef typename ToCudaType::MappedType CudaT; + ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias_optional, nullptr, + fc2_experts_weights, fc2_experts_bias_optional, nullptr, + fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, + 1, // no quantization so pack size is 1 + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + + using CudaT = typename OrtToCudaType::type; auto stream = context->GetComputeStream(); auto& device_prop = GetDeviceProp(); const int sm = device_prop.major * 10 + device_prop.minor; - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_, use_sparse_mixer_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index 6b65557444a66..5f0c30b16a8f4 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -7,206 +7,13 @@ #include "core/framework/tensor_shape.h" #include "core/framework/op_kernel.h" #include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h" +#include "contrib_ops/cpu/moe/moe_helper.h" namespace onnxruntime { namespace contrib { namespace cuda { -enum class MoEParallelType { - None = 0, - EP = 1, - TP = 2, - EPAndTP = 3, -}; - -enum class MoEQuantType { - None = 0, - UINT4 = 1, - UINT8 = 2, -}; - -struct MoEParameters { - MoEParameters() {} - explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {} - int64_t num_rows; - int64_t num_experts; - int64_t local_num_experts; - int64_t hidden_size; - int64_t inter_size; - - MoEParallelType parallel_type; - int64_t tensor_shards{1}; -}; - class MoEBase { - public: - Status CheckInputs(MoEParameters& parameters, MoEQuantType& quant_type, const Tensor* input, - const Tensor* router_probs, const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional) const { - const auto& input_dims = input->Shape().GetDims(); - const auto& router_probs_dims = router_probs->Shape().GetDims(); - const auto& 
fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); - const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); - - int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; - int64_t hidden_size = input_dims[input_dims.size() - 1]; - int64_t local_num_experts = fc1_experts_weights_dims[0]; - int64_t num_experts = router_probs_dims[1]; - int64_t inter_size = fc2_experts_weights_dims[1]; - - if (fc1_experts_weights_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", - fc1_experts_weights_dims.size()); - } - if (fc2_experts_weights_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ", - fc2_experts_weights_dims.size()); - } - if (fc1_experts_weights_dims[1] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_weights_dims[1] must be equal to hidden_size, got ", - fc1_experts_weights_dims[1], " and ", hidden_size); - } - if (fc2_experts_weights_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_weights_dims[1] must be equal to inter_size, got ", - fc2_experts_weights_dims[1], " and ", inter_size); - } - - const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1; - if (fc1_experts_weights_dims[2] != inter_size / coe) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_weights_dims[2] must be equal to inter_size, got ", - fc1_experts_weights_dims[2], " and ", inter_size); - } - if (fc2_experts_weights_dims[2] != hidden_size / coe) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", - fc2_experts_weights_dims[2], " and ", hidden_size); - } - - if (router_probs_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", - router_probs_dims.size()); - } - if (router_probs_dims[0] != num_rows) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ", - router_probs_dims[0], " and ", num_rows); - } - if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) { - const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims(); - const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims(); - if (fc1_experts_bias_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ", - fc1_experts_bias_dims.size()); - } - if (fc2_experts_bias_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ", - fc2_experts_bias_dims.size()); - } - if (fc1_experts_bias_dims[0] != local_num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[0] must be equal to local_num_experts, got ", - fc1_experts_bias_dims[0], " and ", local_num_experts); - } - if (fc2_experts_bias_dims[0] != num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], - " and ", num_experts); - } - if (fc1_experts_bias_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], - " and ", inter_size); - } - if (fc2_experts_bias_dims[1] != 
hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", fc2_experts_bias_dims[1], - " and ", hidden_size); - } - } - - if (fc3_experts_weights_optional != nullptr && - fc3_experts_weights_optional->Shape().GetDims() != fc1_experts_weights_dims) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc3_experts_weights_dims must be equal to fc1_experts_weights_dims, got ", - fc3_experts_weights_optional->Shape(), " and ", TensorShape(fc1_experts_weights_dims)); - } - - if (fc3_experts_bias_optional != nullptr && fc1_experts_bias_optional != nullptr && - fc3_experts_bias_optional->Shape().GetDims() != fc1_experts_bias_optional->Shape().GetDims()) { - return ORT_MAKE_STATUS( - ONNXRUNTIME, INVALID_ARGUMENT, "fc3_experts_bias_dims must be equal to fc1_experts_bias_dims, got ", - fc3_experts_bias_optional->Shape(), " and ", fc1_experts_bias_optional->Shape()); - } - - parameters.num_rows = num_rows; - parameters.num_experts = num_experts; - parameters.local_num_experts = local_num_experts; - parameters.hidden_size = hidden_size; - parameters.inter_size = inter_size; - if (num_experts == local_num_experts) { - if (parameters.tensor_shards == 1) { - parameters.parallel_type = MoEParallelType::None; - } else { - parameters.parallel_type = MoEParallelType::TP; - } - } else if (num_experts > local_num_experts) { - if (parameters.tensor_shards == 1) { - parameters.parallel_type = MoEParallelType::EP; - } else { - parameters.parallel_type = MoEParallelType::EPAndTP; - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_experts must be greater than or equal to local_num_experts, got ", num_experts, - " and ", local_num_experts); - } - - return Status::OK(); - } - - Status CheckInputScales(const Tensor* fc1_experts_scales, const Tensor* fc2_experts_scales, - const Tensor* fc3_experts_scales, int64_t num_experts, int64_t hidden_size, - int64_t inter_size) const { - const auto& fc1_experts_scales_dims = fc1_experts_scales->Shape().GetDims(); - const auto& fc2_experts_scales_dims = fc2_experts_scales->Shape().GetDims(); - - if (fc1_experts_scales_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales must be 2D, got ", - fc1_experts_scales->Shape().GetDims().size()); - } - if (fc1_experts_scales_dims[0] != num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ", - fc1_experts_scales_dims[0], " and ", num_experts); - } - if (fc1_experts_scales_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to inter_size, got ", - fc1_experts_scales_dims[1], " and ", inter_size); - } - if (fc2_experts_scales_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ", - fc2_experts_scales->Shape().GetDims().size()); - } - if (fc2_experts_scales_dims[0] != num_experts) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[0] must be equal to num_experts, got ", - fc2_experts_scales_dims[0], " and ", num_experts); - } - if (fc2_experts_scales_dims[1] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[1] must be equal to hidden_size, got ", - fc2_experts_scales_dims[1], " and ", hidden_size); - } - if (fc3_experts_scales != nullptr && fc1_experts_scales_dims != fc3_experts_scales->Shape().GetDims()) { - return 
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc3_experts_scales must be equal to fc1_experts_scales, got ", - fc3_experts_scales->Shape(), " and ", TensorShape(fc1_experts_scales_dims)); - } - - return Status::OK(); - } - protected: MoEBase(const OpKernelInfo& op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK()); @@ -219,6 +26,8 @@ class MoEBase { activation_type_ = ort_fastertransformer::ActivationType::Gelu; } else if (activation_type_str == "silu") { activation_type_ = ort_fastertransformer::ActivationType::Silu; + } else if (activation_type_str == "swiglu") { + activation_type_ = ort_fastertransformer::ActivationType::SwiGLU; } else if (activation_type_str == "identity") { activation_type_ = ort_fastertransformer::ActivationType::Identity; } else { diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc new file mode 100644 index 0000000000000..bad44b260b7b2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/quantization/gather_block_quantized.h" +#include "contrib_ops/cuda/quantization/gather_block_quantized.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +#define REGISTER_GATHERBLOCKQUANTIZED(T1, T2, Tind) \ + ONNX_OPERATOR_THREE_TYPED_KERNEL_EX( \ + GatherBlockQuantized, \ + kMSDomain, 1, \ + T1, T2, Tind, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ + GatherBlockQuantized); + +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, float, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, float, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, float, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, float, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, float, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, float, int64_t); + +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, MLFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, MLFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, MLFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, MLFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, MLFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, MLFloat16, int64_t); + +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, BFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, BFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, BFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, BFloat16, int64_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, BFloat16, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(Int4x2, BFloat16, int64_t); + +template +GatherBlockQuantized::GatherBlockQuantized(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(info.GetAttr("bits", &bits_).IsOK()); + + block_size_ = info.GetAttrOrDefault("block_size", 0); + gather_axis_ = info.GetAttrOrDefault("gather_axis", 0); + quantize_axis_ = info.GetAttrOrDefault("quantize_axis", 0); + + // If block size is set, it has to be no smaller than 16 and must be power of 2 + // block_size_ & (block_size_ - 1) == 0 checks if block_size_ only has 1 bit set + ORT_ENFORCE(block_size_ == 0 || (block_size_ >= 16 && ((block_size_ & 
(block_size_ - 1)) == 0))); +} + +template +Status GatherBlockQuantized::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* data = ctx->Input(0); + const Tensor* indices = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + auto data_shape = data->Shape().GetDims(); + auto data_rank = data->Shape().NumDimensions(); + + auto indices_shape = indices->Shape().GetDims(); + auto indices_rank = indices->Shape().NumDimensions(); + + ORT_ENFORCE(quantize_axis_ == data_rank - 1); + + TensorShapeVector output_shape; + output_shape.reserve(data_rank - 1 + indices_rank); + + // Dimension after gather axis + int64_t after_gather_dim = 1; + + // Dimension of indices + int64_t ind_dim = 1; + + // 1) dims before gather_axis + for (int64_t i = 0; i < gather_axis_; ++i) { + output_shape.push_back(data_shape[i]); + } + + // 2) all of indices.shape + for (auto dim : indices_shape) { + output_shape.push_back(dim); + ind_dim *= dim; + } + + // 3) dims after gather_axis + for (int64_t i = gather_axis_ + 1; i < data_rank; ++i) { + output_shape.push_back(data_shape[i]); + after_gather_dim *= data_shape[i]; + } + + // Special int4‐in‐uint8 packing tweak: expand the last dim by components + if constexpr (std::is_same_v) { + uint32_t components = 8 / static_cast(bits_); + if (components > 1) { + output_shape.back() *= components; + } + } + + Tensor* output = ctx->Output(0, TensorShape(output_shape)); + + int64_t N = 1; + for (auto dim : output_shape) { + N *= dim; + } + + const auto* data_ptr = data->Data(); + const auto* indices_ptr = indices->Data(); + const T1* zero_points_ptr = nullptr; + if (zero_points != nullptr) { + zero_points_ptr = zero_points->Data(); + } + + GatherBlockQuantizedParam param; + param.stream = Stream(ctx); + param.after_gather_dim = after_gather_dim; + param.gather_axis_dim = data_shape[gather_axis_]; + param.ind_dim = ind_dim; + param.bits = bits_; + param.block_size = block_size_; + param.gather_axis = gather_axis_; + param.N = N; + + const auto dequantized_type = scales->GetElementType(); + if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + const auto* scales_ptr = static_cast(scales->DataRaw()); + auto* output_ptr = static_cast(output->MutableDataRaw()); + LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param); + } else if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + const auto* scales_ptr = static_cast(scales->DataRaw()); + auto* output_ptr = static_cast(output->MutableDataRaw()); + LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param); + } else if (dequantized_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) { + const auto* scales_ptr = static_cast(scales->DataRaw()); + auto* output_ptr = static_cast(output->MutableDataRaw()); + LaunchGatherBlockQuantizedKernel(data_ptr, indices_ptr, scales_ptr, zero_points_ptr, output_ptr, param); + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu new file mode 100644 index 0000000000000..39286c63e9a08 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
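// --- Illustrative sketch (not part of the patch) ---------------------------
// Host-side sketch of the shape/index bookkeeping used by the
// GatherBlockQuantized::ComputeInternal implementation above: the output shape
// is data.shape[:gather_axis] + indices.shape + data.shape[gather_axis+1:],
// and every flat output index decomposes into (block before the gather axis,
// position inside indices, offset after the gather axis). When the data tensor
// stores packed 4-bit codes in uint8, ComputeInternal additionally multiplies
// the last output dimension by 8 / bits. Names and the std::vector-based
// interface are illustrative only; the CUDA kernel does the same arithmetic.
#include <cstdint>
#include <vector>

struct GatherIndexing {
  std::vector<int64_t> output_shape;
  int64_t after_gather_dim = 1;  // product of data dims after gather_axis
  int64_t ind_dim = 1;           // total number of gathered indices
};

inline GatherIndexing MakeGatherIndexing(const std::vector<int64_t>& data_shape,
                                         const std::vector<int64_t>& indices_shape,
                                         int64_t gather_axis) {
  GatherIndexing g;
  for (int64_t i = 0; i < gather_axis; ++i) {
    g.output_shape.push_back(data_shape[i]);
  }
  for (int64_t d : indices_shape) {
    g.output_shape.push_back(d);
    g.ind_dim *= d;
  }
  for (int64_t i = gather_axis + 1; i < static_cast<int64_t>(data_shape.size()); ++i) {
    g.output_shape.push_back(data_shape[i]);
    g.after_gather_dim *= data_shape[i];
  }
  return g;
}

// Maps a flat output index to the corresponding flat input (data) index,
// mirroring the per-thread decomposition in GatherBlockQuantizedKernel.
inline int64_t OutputToInputIndex(int64_t out_idx, const GatherIndexing& g,
                                  int64_t gather_axis_dim, const int64_t* indices) {
  const int64_t idx_before = out_idx / (g.after_gather_dim * g.ind_dim);
  const int64_t idx_after = out_idx % g.after_gather_dim;
  const int64_t which_index = (out_idx % (g.after_gather_dim * g.ind_dim)) / g.after_gather_dim;
  return idx_before * gather_axis_dim * g.after_gather_dim +
         indices[which_index] * g.after_gather_dim + idx_after;
}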
+ +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "gather_block_quantized.cuh" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__device__ inline int64_t get_val(const T1* data, int64_t idx, int64_t bits, bool sign) { + const uint32_t mask = (1U << bits) - 1; + const int64_t elems_per_byte = 8 / bits; + const int64_t byte_idx = idx / elems_per_byte; + const int64_t bit_offset = (idx % elems_per_byte) * bits; + const uint8_t byte = reinterpret_cast(data)[byte_idx]; + int64_t val = (byte >> bit_offset) & mask; + + // Sign-extend based on bit width + if (sign) { + if (val & (1 << (bits - 1))) { + val |= -1LL << bits; + } + } + + return val; +} + +template +__global__ void GatherBlockQuantizedKernel( + const T1* data, // packed 4-bit codes, one code per element + const Tind* indices, + const T2* scales, // one float scale per block + const T1* zero_points, // packed 4-bit zero-points, one per block + T2* output, + int64_t after_gather_dim, + int64_t gather_axis_dim, + int64_t ind_dim, + int64_t bits, + int64_t block_size, + int64_t gather_axis, + int64_t N, + bool sign) { + int64_t out_idx = blockDim.x * blockIdx.x + threadIdx.x; + if (out_idx >= N) return; + + // compute which input element this thread corresponds to: + int64_t idx_before = out_idx / (after_gather_dim * ind_dim); + int64_t idx_after = out_idx % after_gather_dim; + int64_t idx = (out_idx % (after_gather_dim * ind_dim)) / after_gather_dim; + int64_t idx_at_g = indices[idx]; + int64_t in_idx = idx_before * gather_axis_dim * after_gather_dim + idx_at_g * after_gather_dim + idx_after; + + int64_t block_id = in_idx / block_size; + + // unpack zero_point for this block: + int64_t offset = 0; + if (zero_points) { + offset = get_val(zero_points, block_id, bits, sign); + } + + // unpack the raw quantized code for this element: + int64_t weight = get_val(data, in_idx, bits, sign); + + // apply dequantization: + output[out_idx] = static_cast(weight - offset) * scales[block_id]; +} + +template +void LaunchGatherBlockQuantizedKernel(const T1* data, + const Tind* indices, + const T2* scales, + const T1* zero_points, + T2* output, + GatherBlockQuantizedParam param) { + // Require quant_axis is last dim + int blocksPerGrid = (int)(ceil(static_cast(param.N) / GridDim::maxThreadsPerBlock)); + bool sign = std::is_same::value; + + GatherBlockQuantizedKernel<<>>(data, indices, scales, zero_points, output, + param.after_gather_dim, param.gather_axis_dim, param.ind_dim, param.bits, param.block_size, param.gather_axis, param.N, sign); +} + +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const float*, const uint8_t*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const float*, const uint8_t*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const float*, const UInt4x2*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const float*, const UInt4x2*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const float*, const Int4x2*, float*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const float*, const Int4x2*, float*, GatherBlockQuantizedParam); + +template void 
LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const half*, const uint8_t*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const half*, const uint8_t*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const half*, const UInt4x2*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const half*, const UInt4x2*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const half*, const Int4x2*, half*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const half*, const Int4x2*, half*, GatherBlockQuantizedParam); + +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int32_t*, const BFloat16*, const uint8_t*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const uint8_t*, const int64_t*, const BFloat16*, const uint8_t*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int32_t*, const BFloat16*, const UInt4x2*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const UInt4x2*, const int64_t*, const BFloat16*, const UInt4x2*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int32_t*, const BFloat16*, const Int4x2*, BFloat16*, GatherBlockQuantizedParam); +template void LaunchGatherBlockQuantizedKernel(const Int4x2*, const int64_t*, const BFloat16*, const Int4x2*, BFloat16*, GatherBlockQuantizedParam); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh new file mode 100644 index 0000000000000..f5dea3b1f2d9d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cuh @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +struct GatherBlockQuantizedParam { + cudaStream_t stream; + int64_t after_gather_dim; + int64_t gather_axis_dim; + int64_t ind_dim; + int64_t bits; + int64_t block_size; + int64_t gather_axis; + int64_t N; +}; + +template +void LaunchGatherBlockQuantizedKernel(const T1* data, + const Tind* indices, + const T2* scales, + const T1* zero_points, + T2* output, + GatherBlockQuantizedParam param); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h new file mode 100644 index 0000000000000..7718b6dd06765 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
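// --- Illustrative sketch (not part of the patch) ---------------------------
// CPU reference for the per-element dequantization performed by
// GatherBlockQuantizedKernel above: unpack the bit-packed code at `idx`,
// sign-extend it for signed types, subtract the block's zero point (0 when no
// zero points are given, matching the kernel), and scale by the block's scale.
// Function names are illustrative; float stands in for the T2 scale type.
#include <cstdint>

inline int64_t UnpackQuantizedValue(const uint8_t* packed, int64_t idx,
                                    int64_t bits, bool is_signed) {
  const uint32_t mask = (1U << bits) - 1;
  const int64_t elems_per_byte = 8 / bits;         // 2 for 4-bit, 1 for 8-bit
  const uint8_t byte = packed[idx / elems_per_byte];
  int64_t val = (byte >> ((idx % elems_per_byte) * bits)) & mask;
  if (is_signed && (val & (1LL << (bits - 1)))) {  // sign-extend, e.g. 0xF -> -1
    val |= -1LL << bits;
  }
  return val;
}

// Dequantize one element. As in the kernel, block_id = idx / block_size, which
// assumes the quantized axis is the last (contiguous) axis of the data tensor.
inline float DequantizeElement(const uint8_t* packed, const uint8_t* zero_points,
                               const float* scales, int64_t idx,
                               int64_t block_size, int64_t bits, bool is_signed) {
  const int64_t block_id = idx / block_size;
  const int64_t zp = zero_points != nullptr
                         ? UnpackQuantizedValue(zero_points, block_id, bits, is_signed)
                         : 0;
  const int64_t q = UnpackQuantizedValue(packed, idx, bits, is_signed);
  return static_cast<float>(q - zp) * scales[block_id];
}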
+ +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" + +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class GatherBlockQuantized final : public CudaKernel { + public: + GatherBlockQuantized(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t bits_; + int64_t block_size_; + int64_t gather_axis_; + int64_t quantize_axis_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 4dd5a079d1a29..dcf32bb3c5ae4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -5,6 +5,7 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/quantization/moe_quantization.h" +#include "core/providers/cuda/cuda_type_conversion.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -14,16 +15,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL() \ - ONNX_OPERATOR_KERNEL_EX(QMoE, kMSDomain, 1, kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(0, 0) \ - .TypeConstraint("T", BuildKernelDefConstraints()) \ - .TypeConstraint("T1", BuildKernelDefConstraints()), \ - QMoE); - -REGISTER_KERNEL() - namespace { template struct ToCudaTypeWrapper : public ToCudaType {}; @@ -40,27 +31,29 @@ struct ToCudaTypeWrapper { } // anonymous namespace -QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { +template +QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); ORT_ENFORCE(expert_weight_bits_ == 8 || expert_weight_bits_ == 4, "expert_weight_bits must be 4 or 8, but got ", expert_weight_bits_); } +template template -Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional, - const cudaDeviceProp& device_prop) const { +Status QMoE::QuantizedMoEImpl(OpKernelContext* context, + MoEParameters& moe_params, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias_optional, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias_optional, + const Tensor* fc3_experts_weights_optional, + const Tensor* fc3_experts_bias_optional, + const Tensor* fc1_scales, + const Tensor* fc2_scales, + const Tensor* fc3_scales_optional, + const cudaDeviceProp& device_prop) const { auto stream = context->GetComputeStream(); const int sm = device_prop.major * 10 + device_prop.minor; @@ -68,10 +61,10 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - using T = MLFloat16; - using CudaT = typename 
ToCudaType::MappedType; + using CudaT = typename OrtToCudaType::type; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, fc3_experts_weights_optional != nullptr, normalize_routing_weights_, use_sparse_mixer_); @@ -136,7 +129,8 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, return Status::OK(); } -Status QMoE::ComputeInternal(OpKernelContext* context) const { +template +Status QMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* router_probs = context->Input(1); const Tensor* fc1_experts_weights = context->Input(2); @@ -149,20 +143,21 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* fc3_scales_optional = context->Input(9); const Tensor* fc3_experts_bias_optional = context->Input(10); - MoEQuantType quant_type = expert_weight_bits_ == 4 ? MoEQuantType::UINT4 : MoEQuantType::UINT8; MoEParameters moe_params; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, - fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, - fc3_experts_weights_optional, fc3_experts_bias_optional)); - ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts, - moe_params.hidden_size, moe_params.inter_size)); + ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias_optional, fc1_scales, + fc2_experts_weights, fc2_experts_bias_optional, fc2_scales, + fc3_experts_weights_optional, fc3_experts_bias_optional, fc3_scales_optional, + expert_weight_bits_ == 4 ? 2 : 1, + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); #if defined(__GNUC__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // Mute "maybe used uninitialized" warning for MoEParameters. 
#endif - if (quant_type == MoEQuantType::UINT4) { + if (expert_weight_bits_ == 4) { using CudaWeightT = typename ToCudaTypeWrapper::MappedType; return QuantizedMoEImpl(context, moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, @@ -183,6 +178,32 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { #endif } +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .MayInplace(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoE); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, + kMSDomain, + 1, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .MayInplace(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoE); + } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h index c0164576d7c7f..c4698a1f277ef 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h @@ -14,6 +14,7 @@ namespace cuda { using namespace onnxruntime::cuda; +template class QMoE final : public CudaKernel, public MoEBase { public: explicit QMoE(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc new file mode 100644 index 0000000000000..abfd3cf89cecf --- /dev/null +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include "core/common/common.h" +#include "core/framework/ep_context_options.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +namespace onnxruntime { +namespace epctx { +// class ModelGenOptions + +ModelGenOptions::ModelGenOptions() = default; + +ModelGenOptions::ModelGenOptions(const ConfigOptions& config_options) { + enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + + std::string output_model_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + if (!output_model_path.empty()) { + output_model_location = std::filesystem::path(output_model_path); + } else { + output_model_location = std::monostate{}; + } + + std::string external_initializers_file_path = config_options.GetConfigOrDefault( + kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); + if (!external_initializers_file_path.empty()) { + ExternalInitializerFileInfo ext_info = {}; + ext_info.file_path = external_initializers_file_path; + ext_info.size_threshold = 0; + initializers_location = std::move(ext_info); + } + + embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; +} + +bool ModelGenOptions::HasOutputModelLocation() const { + return !std::holds_alternative(output_model_location); +} + +const std::filesystem::path* ModelGenOptions::TryGetOutputModelPath() const { + return std::get_if(&output_model_location); +} + +const BufferHolder* ModelGenOptions::TryGetOutputModelBuffer() const { + return std::get_if(&output_model_location); +} + +const BufferWriteFuncHolder* ModelGenOptions::TryGetOutputModelWriteFunc() const { + return std::get_if(&output_model_location); +} + +bool ModelGenOptions::AreInitializersEmbeddedInOutputModel() const { + return std::holds_alternative(initializers_location); +} + +const ExternalInitializerFileInfo* ModelGenOptions::TryGetExternalInitializerFileInfo() const { + return std::get_if(&initializers_location); +} + +const InitializerHandler* ModelGenOptions::TryGetInitializerHandler() const { + return std::get_if(&initializers_location); +} + +} // namespace epctx +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h new file mode 100644 index 0000000000000..6643516bfb4c3 --- /dev/null +++ b/onnxruntime/core/framework/ep_context_options.h @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/framework/allocator.h" +#include "core/framework/config_options.h" + +namespace onnxruntime { +namespace epctx { +/// +/// Holds the buffer that will store the output model and the allocator used to allocate the memory. +/// +struct BufferHolder { + void** buffer_ptr = nullptr; + size_t* buffer_size_ptr = nullptr; + AllocatorPtr buffer_allocator = nullptr; +}; + +/// +/// Holds the opaque stream state and the write function that ORT calls to write out the output model. +/// +struct BufferWriteFuncHolder { + OrtWriteBufferFunc write_func = nullptr; + void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func. +}; + +/// +/// Holds path and size threshold used to write out initializers to an external file. 
+/// +struct ExternalInitializerFileInfo { + std::filesystem::path file_path; + size_t size_threshold = 0; +}; + +/// +/// Holds function and state provided by user to handle initializer data (i.e., write to stream or embed in model). +/// +struct InitializerHandler { + OrtGetInitializerLocationFunc handle_initializer_func = nullptr; + void* state = nullptr; +}; + +/// +/// Stores EPContext model generation options. Used in SessionOptions. +/// +struct ModelGenOptions { + // Action to take if the output model does not have compiled (EPContext) nodes. + enum class ActionIfNoCompiledNodes { + // Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior + // to maintain compatibility. The explicit compile API does *not* use this action. + kDontGenerateModel = 0, + + // Generate an output model even if it doesn't have compiled nodes. + // The explicit Compile API defaults to this value. + kGenerateModel, + + // Return an error if the model does not have compiled nodes. + // The explicit Compile API can be configured to this value. + kReturnError, + }; + + ModelGenOptions(); + + // Initializes from string key/value pairs in session config options. + explicit ModelGenOptions(const ConfigOptions& config_options); + + bool enable = false; + bool error_if_output_file_exists = true; + bool error_if_no_compiled_nodes = false; + bool embed_ep_context_in_model = false; + ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel; + + std::variant // Function to write the output model to a user's stream. + output_model_location = std::monostate{}; + + std::variant // Custom function called for every initializer to determine location. + initializers_location = std::monostate{}; + + bool HasOutputModelLocation() const; + const std::filesystem::path* TryGetOutputModelPath() const; + const BufferHolder* TryGetOutputModelBuffer() const; + const BufferWriteFuncHolder* TryGetOutputModelWriteFunc() const; + + bool AreInitializersEmbeddedInOutputModel() const; + const ExternalInitializerFileInfo* TryGetExternalInitializerFileInfo() const; + const InitializerHandler* TryGetInitializerHandler() const; +}; + +} // namespace epctx +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/ep_context_utils.cc b/onnxruntime/core/framework/ep_context_utils.cc new file mode 100644 index 0000000000000..3f02c54538526 --- /dev/null +++ b/onnxruntime/core/framework/ep_context_utils.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) +#include +#include +#include "core/framework/ep_context_utils.h" +#include "core/framework/error_code_helper.h" +#include "core/graph/model_saving_options.h" + +namespace onnxruntime { +namespace epctx { + +// Serialize an EPContext model into a onnx::ModelProto. +Status EpContextModelToProto(const onnxruntime::Model& ep_context_model, + const std::filesystem::path& validated_model_path, + const epctx::ModelGenOptions& ep_context_gen_options, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { + // Handle case where initializers are stored inline within the ONNX model. + if (ep_context_gen_options.AreInitializersEmbeddedInOutputModel()) { + // if no external ini file specified, set force_embed_external_ini to true to avoid intermediate file creation + // and force all initializers embed into the ONNX file. 
+ ModelSavingOptions model_saving_options{/*size_threshold*/ SIZE_MAX}; + model_saving_options.force_embed_external_ini = true; + + model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(std::filesystem::path{}, + validated_model_path, + model_saving_options); + return Status::OK(); + } + + // Handle case where initializers (with size > threshold) are stored in an external file. + if (const epctx::ExternalInitializerFileInfo* ext_info = ep_context_gen_options.TryGetExternalInitializerFileInfo(); + ext_info != nullptr) { + ModelSavingOptions model_saving_options{ext_info->size_threshold}; + + model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(ext_info->file_path, + validated_model_path, + model_saving_options); + return Status::OK(); + } + + // Handle case where user specified a custom handler function that determines how each initializer is saved. + if (const epctx::InitializerHandler* custom_handler = ep_context_gen_options.TryGetInitializerHandler(); + custom_handler != nullptr) { + ORT_RETURN_IF_ERROR(ep_context_model.ToGraphProtoWithCustomInitializerHandling( + custom_handler->handle_initializer_func, + custom_handler->state, + model_proto)); + return Status::OK(); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected location for initializers while generating ", + validated_model_path); +} + +// +// OutStreamBuf class: +// + +OutStreamBuf::OutStreamBuf(BufferWriteFuncHolder write_func_holder) + : write_func_holder_(write_func_holder), buffer_(65536) { + setp(buffer_.data(), buffer_.data() + buffer_.size()); +} + +OutStreamBuf::~OutStreamBuf() { + sync(); +} + +// Called when the buffer_ is full. Flushes the buffer_ (via sync()) and then writes the overflow character to buffer_. +std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) { + if (sync() == -1) { + return traits_type::eof(); + } + + if (ch != traits_type::eof()) { + *pptr() = static_cast(ch); + pbump(1); + } + + return ch; +} + +// Flushes the entire buffer_ to the user's write function. +int OutStreamBuf::sync() { + if (!last_status_.IsOK()) { + return -1; + } + + std::ptrdiff_t num_bytes = pptr() - pbase(); + if (num_bytes == 0) { + return 0; + } + + // Can only call pbump() with an int, so can only write at most (2^31 - 1) bytes. + if (num_bytes > std::numeric_limits::max()) { + num_bytes = std::numeric_limits::max(); + } + + char* ptr = pbase(); + + Status status = Status::OK(); + + ORT_TRY { + status = ToStatusAndRelease(write_func_holder_.write_func(write_func_holder_.stream_state, + ptr, num_bytes)); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Caught exception while calling user's OrtOutStreamWriteFunc callback: ", e.what()); + }); + } + + if (!status.IsOK()) { + last_status_ = std::move(status); + return -1; + } + + pbump(-static_cast(num_bytes)); // Reset internal pointer to point to the beginning of the buffer_ + return 0; +} + +} // namespace epctx +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/ep_context_utils.h b/onnxruntime/core/framework/ep_context_utils.h new file mode 100644 index 0000000000000..b3c76565982ff --- /dev/null +++ b/onnxruntime/core/framework/ep_context_utils.h @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
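// --- Illustrative sketch (not part of the patch) ---------------------------
// How a consumer of epctx::ModelGenOptions is expected to pick the output
// destination: at most one alternative of output_model_location is active, and
// the TryGet* helpers return nullptr for the others. This mirrors the dispatch
// done in graph_partitioner.cc later in this patch; DescribeOutputDestination
// itself is a hypothetical helper used only for illustration.
#include <filesystem>
#include "core/framework/ep_context_options.h"

namespace onnxruntime {
namespace epctx {

inline const char* DescribeOutputDestination(const ModelGenOptions& options) {
  if (options.TryGetOutputModelPath() != nullptr) {
    return "write the compiled model to a file path";
  }
  if (options.TryGetOutputModelBuffer() != nullptr) {
    return "serialize into an ORT-allocated buffer returned to the caller";
  }
  if (options.TryGetOutputModelWriteFunc() != nullptr) {
    return "stream the serialized bytes through the user's write callback";
  }
  return "no output destination configured (std::monostate)";
}

}  // namespace epctx
}  // namespace onnxruntime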
+ +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include + +#include "core/common/status.h" +#include "core/framework/ep_context_options.h" +#include "core/graph/model.h" + +namespace onnxruntime { +namespace epctx { + +/// +/// Serialize an EPContext model into a onnx::ModelProto based on the provided options. +/// +/// The EP Context model to serialize. +/// The path into which to save the model. May be empty if serialized into a +/// buffer or output stream. +/// The model generation options. +/// Output parameter set to the serialized onnx::ModelProto. +/// A status indicating success or an error. +Status EpContextModelToProto(const onnxruntime::Model& ep_context_model, + const std::filesystem::path& validated_model_path, + const epctx::ModelGenOptions& ep_context_gen_options, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto); + +// Class that wraps the user's OrtBufferWriteFunc function to enable use with +// C++'s std::ostream. +// Example: +// BufferWriteFuncHolder write_func_holder{write_func, stream_state}; +// std::unique_ptr out_stream_buf = std::make_unique(write_func_holder); +// std::ostream out_stream(out_stream_buf.get()); +class OutStreamBuf : public std::streambuf { + public: + explicit OutStreamBuf(BufferWriteFuncHolder write_func_holder); + ~OutStreamBuf(); + + const Status& GetStatus() const { + return last_status_; + } + + protected: + int_type overflow(int_type ch) override; + int sync() override; + + private: + BufferWriteFuncHolder write_func_holder_{}; + std::vector buffer_; + Status last_status_{}; +}; + +} // namespace epctx +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 421e5a6db51b7..43caf4766d5c0 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -5,10 +5,12 @@ #include #include +#include #include "core/common/inlined_containers.h" #include "core/common/string_utils.h" #include "core/framework/compute_capability.h" +#include "core/framework/ep_context_utils.h" #include "core/framework/execution_providers.h" #include "core/framework/func_kernel.h" #include "core/framework/kernel_lookup.h" @@ -20,9 +22,9 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" -#include "core/graph/model_saving_options.h" -#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/util/protobuf_parsing_utils.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -766,6 +768,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide } // Validate the ep_context_path to make sure it is file path and check whether the file exist already +// TODO: Move function to ep_context_utils.h/cc static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_path, const std::filesystem::path& model_path, std::filesystem::path& context_cache_path, @@ -794,9 +797,10 @@ static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_ return Status::OK(); } +// TODO: Move function to ep_context_utils.h/cc static Status CreateEpContextModel(const ExecutionProviders& execution_providers, const Graph& graph, - const EpContextModelGenerationOptions& ep_context_gen_options, 
+ const epctx::ModelGenOptions& ep_context_gen_options, const logging::Logger& logger) { InlinedVector all_ep_context_nodes; for (const auto& ep : execution_providers) { @@ -807,11 +811,11 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers if (all_ep_context_nodes.size() < 1) { auto action_if_no_compiled_nodes = ep_context_gen_options.action_if_no_compiled_nodes; - ORT_RETURN_IF(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError, + ORT_RETURN_IF(action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kReturnError, "Unable to compile any nodes. Check that the session EPs support compilation and can execute " "at least one subgraph in the model."); - if (action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kDontGenerateModel) { + if (action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kDontGenerateModel) { LOGS(logger, WARNING) << "Unable to compile any nodes. ONNX Runtime will not generate a compiled model. " "Either the session EPs do not support compilation or the model is already compiled."; // Note: this path is only taken if a model is compiled with the original compilation approach that uses @@ -821,7 +825,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } // Assert so that this is caught in a test in DEBUG builds (in case a new enum value is added) - assert(action_if_no_compiled_nodes == EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel); + assert(action_if_no_compiled_nodes == epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel); LOGS(logger, INFO) << "Unable to compile any nodes but will still generate an output model. " "Either the session EPs do not support compilation or the model is already compiled."; } @@ -835,15 +839,17 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers return std::make_pair(false, static_cast(nullptr)); }; - bool saving_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr && - ep_context_gen_options.output_model_buffer_size_ptr != nullptr && - ep_context_gen_options.output_model_buffer_allocator != nullptr; + const epctx::BufferHolder* output_buffer_holder = ep_context_gen_options.TryGetOutputModelBuffer(); + const epctx::BufferWriteFuncHolder* output_write_func_holder = ep_context_gen_options.TryGetOutputModelWriteFunc(); + const std::filesystem::path* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); - std::filesystem::path context_cache_path; - if (!saving_to_buffer || !graph.ModelPath().empty()) { - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, + std::filesystem::path valid_output_model_path; + if (output_model_path_ptr != nullptr || !graph.ModelPath().empty()) { + std::filesystem::path output_model_path = (output_model_path_ptr != nullptr) ? 
*output_model_path_ptr + : std::filesystem::path(""); + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(output_model_path, graph.ModelPath(), - context_cache_path, + valid_output_model_path, ep_context_gen_options.error_if_output_file_exists)); } @@ -910,10 +916,11 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } } + ORT_RETURN_IF_ERROR(ep_graph.Resolve()); + // Generate EP compatibility strings for OrtEp types and add to model metadata // At this point, the graph has been populated with all the EPContext nodes { - ORT_RETURN_IF_ERROR(ep_graph.Resolve()); const GraphViewer graph_viewer(ep_graph); for (const auto& ep : execution_providers) { try { @@ -938,39 +945,60 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } } - size_t ini_size_threshold = ep_context_gen_options.output_external_initializer_size_threshold; - std::filesystem::path external_ini_path = ep_context_gen_options.output_external_initializers_file_path; - bool force_embed_external_ini = false; - if (external_ini_path.empty()) { - // if no external ini file specified, set force_embed_external_ini to true to avoid intermedia file creation - // and force all initializers embed into the Onnx file - ini_size_threshold = SIZE_MAX; - force_embed_external_ini = true; - } - - ModelSavingOptions model_saving_options{ini_size_threshold}; - model_saving_options.force_embed_external_ini = force_embed_external_ini; + ONNX_NAMESPACE::ModelProto model_proto; + ORT_RETURN_IF_ERROR(EpContextModelToProto(ep_context_model, valid_output_model_path, ep_context_gen_options, + /*out*/ model_proto)); - if (saving_to_buffer) { - ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve()); - // TODO(adrianlizarraga): Investigate if we can make this more memory efficient. - // May be able to use allocator to directly allocate the ModelProto to avoid a copy. - ONNX_NAMESPACE::ModelProto model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(external_ini_path, - context_cache_path, - model_saving_options); + if (output_buffer_holder != nullptr) { + // Write output model into a buffer ORT allocates for the user. size_t buffer_size = model_proto.ByteSizeLong(); ORT_RETURN_IF(buffer_size > static_cast(std::numeric_limits::max()), "Cannot serialize ONNX ModelProto larger than 2GB"); - AllocatorPtr allocator = ep_context_gen_options.output_model_buffer_allocator; + AllocatorPtr allocator = output_buffer_holder->buffer_allocator; IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(allocator, buffer_size); model_proto.SerializeToArray(buffer.get(), static_cast(buffer_size)); - *ep_context_gen_options.output_model_buffer_size_ptr = buffer_size; - *ep_context_gen_options.output_model_buffer_ptr = buffer.release(); + *output_buffer_holder->buffer_size_ptr = buffer_size; + *output_buffer_holder->buffer_ptr = buffer.release(); + } else if (output_write_func_holder != nullptr) { + // Write output model to user's output stream. 
+ size_t buffer_size = model_proto.ByteSizeLong(); + ORT_RETURN_IF(buffer_size > static_cast(std::numeric_limits::max()), + "Cannot serialize ONNX ModelProto larger than 2GB"); + + auto out_stream_buf = std::make_unique(*output_write_func_holder); + std::ostream out_stream(out_stream_buf.get()); + + model_proto.SerializeToOstream(&out_stream); + out_stream.flush(); + ORT_RETURN_IF_ERROR(out_stream_buf->GetStatus()); } else { - ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(ep_context_model, context_cache_path, - external_ini_path, model_saving_options)); + // Write output model to a file. + int fd = 0; + Status status = Env::Default().FileOpenWr(valid_output_model_path, fd); + ORT_RETURN_IF_ERROR(status); + + ORT_TRY { + google::protobuf::io::FileOutputStream output(fd); + bool serialize_result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); + if (!serialize_result) { + status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_PROTOBUF, + "Protobuf serialization failed when generating EPContext model ", + valid_output_model_path); + } + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what()); + }); + } + if (!status.IsOK()) { + GSL_SUPPRESS(es .84) + ORT_IGNORE_RETURN_VALUE(Env::Default().FileClose(fd)); + return status; + } + ORT_RETURN_IF_ERROR(Env::Default().FileClose(fd)); } return Status::OK(); @@ -1221,7 +1249,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const ConfigOptions& config_options, const logging::Logger& logger, Mode mode, - const EpContextModelGenerationOptions& ep_context_gen_options, + const epctx::ModelGenOptions& ep_context_gen_options, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. // 1. Execution providers' capabilities are checked one by one. 
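// --- Illustrative sketch (not part of the patch) ---------------------------
// Minimal usage of epctx::OutStreamBuf, mirroring the streaming path added
// above: wrap the caller-provided write callback in a std::streambuf,
// serialize the ModelProto through a std::ostream, then surface any error
// reported by the callback via GetStatus(). WriteModelToUserStream is a
// hypothetical helper; the BufferWriteFuncHolder is assumed to already hold
// the user's write function and opaque stream state.
#include <memory>
#include <ostream>
#include "core/framework/ep_context_utils.h"

namespace onnxruntime {
namespace epctx {

inline Status WriteModelToUserStream(const ONNX_NAMESPACE::ModelProto& model_proto,
                                     const BufferWriteFuncHolder& write_func_holder) {
  auto stream_buf = std::make_unique<OutStreamBuf>(write_func_holder);
  std::ostream out_stream(stream_buf.get());

  model_proto.SerializeToOstream(&out_stream);  // buffered writes go through overflow()/sync()
  out_stream.flush();                           // push any remaining bytes to the callback

  return stream_buf->GetStatus();               // non-OK if the user's callback failed
}

}  // namespace epctx
}  // namespace onnxruntime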
@@ -1268,12 +1296,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - if (ep_context_gen_options.enable && ep_context_gen_options.output_model_buffer_ptr == nullptr) { - // Check before EP compile graphs - std::filesystem::path context_cache_path; - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, graph.ModelPath(), - context_cache_path, - ep_context_gen_options.error_if_output_file_exists)); + if (ep_context_gen_options.enable) { + if (const std::filesystem::path* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); + output_model_path_ptr != nullptr) { + // Check before EP compile graphs + std::filesystem::path context_cache_path; + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(*output_model_path_ptr, graph.ModelPath(), + context_cache_path, + ep_context_gen_options.error_if_output_file_exists)); + } } // We use this only if Resource Aware Partitioning is enabled for any of the EPs diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 6e36d79701fd7..abe46cea58ab2 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -15,7 +15,10 @@ class ExecutionProviders; class KernelRegistryManager; class Model; struct ConfigOptions; -struct EpContextModelGenerationOptions; + +namespace epctx { +struct ModelGenOptions; +} class GraphPartitioner { public: @@ -50,7 +53,7 @@ class GraphPartitioner { const ConfigOptions& config_options, const logging::Logger& logger, Mode mode = Mode::kNormal, - const EpContextModelGenerationOptions& ep_context_gen_options = {}, + const epctx::ModelGenOptions& ep_context_gen_options = {}, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; bool IsLoadCancellationFlagSet() const { diff --git a/onnxruntime/core/framework/plugin_data_transfer.cc b/onnxruntime/core/framework/plugin_data_transfer.cc index f753f00206c5d..d6b1680176815 100644 --- a/onnxruntime/core/framework/plugin_data_transfer.cc +++ b/onnxruntime/core/framework/plugin_data_transfer.cc @@ -41,7 +41,7 @@ Status DataTransfer::CopyTensors(const std::vector& src_dst_pairs) c for (size_t i = 0; i < src_dst_pairs.size(); ++i) { src_values.push_back(&values[i * 2]); dst_values.push_back(&values[i * 2 + 1]); - streams.push_back(nullptr); // static_cast(src_dst_pairs[i].src_stream)); + streams.push_back(reinterpret_cast(src_dst_pairs[i].src_stream)); } auto* status = impl_.CopyTensors(&impl_, src_values.data(), dst_values.data(), streams.data(), diff --git a/onnxruntime/core/framework/session_options.cc b/onnxruntime/core/framework/session_options.cc index 231eb47603838..63f928d52d788 100644 --- a/onnxruntime/core/framework/session_options.cc +++ b/onnxruntime/core/framework/session_options.cc @@ -99,20 +99,11 @@ void SessionOptions::AddCustomOpLibraryHandle(PathString library_name, void* lib } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -EpContextModelGenerationOptions::EpContextModelGenerationOptions(const ConfigOptions& config_options) { - enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; - output_model_file_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); - output_external_initializers_file_path = config_options.GetConfigOrDefault( - 
kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); - output_external_initializer_size_threshold = 0; - embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; -} - -EpContextModelGenerationOptions SessionOptions::GetEpContextGenerationOptions() const { +epctx::ModelGenOptions SessionOptions::GetEpContextGenerationOptions() const { if (this->has_explicit_ep_context_gen_options) { return this->ep_context_gen_options; } - return EpContextModelGenerationOptions(this->config_options); + return epctx::ModelGenOptions(this->config_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index b75eeb217e7f0..b328fc916f885 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -13,6 +13,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/allocator.h" #include "core/framework/config_options.h" +#include "core/framework/ep_context_options.h" #include "core/framework/ort_value.h" #include "core/session/onnxruntime_c_api.h" #include "core/optimizer/graph_transformer_level.h" @@ -70,53 +71,6 @@ struct FreeDimensionOverride { using CheckLoadCancellationFn = std::function; -/// -/// Options that configure the generation of a compiled model (i.e., a model with EPContext nodes). -/// There are two ways to compile a model: -/// 1. By specifying the correct session option configurations and creating an inference session. -/// The compiled model is generated as a side-effect of session creation. -/// 2. Using an explicit compile API (see OrtCompileApi struct in onnxruntime_c_api.h). -/// -/// The default values in this struct are set to match the current/default behavior of approach 1 to maintain -/// compatibility with the older way of compiling. The explicit compile API overrides some of these values to -/// provide its own defaults (see core/session/model_compilation_options.h/cc). -/// -struct EpContextModelGenerationOptions { - // Action to take if the output model does not have compiled (EPContext) nodes. - enum class ActionIfNoCompiledNodes { - // Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior - // to maintain compatibility. The explicit compile API does *not* use this action. - kDontGenerateModel = 0, - - // Generate an output model even if it doesn't have compiled nodes. - // The explicit Compile API defaults to this value. - kGenerateModel, - - // Return an error if the model does not have compiled nodes. - // The explicit Compile API can be configured to this value. - kReturnError, - }; - - EpContextModelGenerationOptions() = default; - - // Initializes from string key/value pairs in session config options. - // This initializes this struct from options set via the older compiling approach #1 above. 
- explicit EpContextModelGenerationOptions(const ConfigOptions& config_options); - - bool enable = false; - bool error_if_output_file_exists = true; - ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel; - bool embed_ep_context_in_model = false; - - std::string output_model_file_path; - void** output_model_buffer_ptr = nullptr; - size_t* output_model_buffer_size_ptr = nullptr; - AllocatorPtr output_model_buffer_allocator = nullptr; - - std::string output_external_initializers_file_path; - size_t output_external_initializer_size_threshold = 0; -}; - struct EpSelectionPolicy { // flag to detect that a policy was set by the user. // need to preserve current behavior of defaulting to CPU EP if no EPs are explicitly registered @@ -270,8 +224,8 @@ struct SessionOptions { // The function GetEpContextGenerationOptions() handles conversion of string key/value pairs to the new // struct type. bool has_explicit_ep_context_gen_options = false; - EpContextModelGenerationOptions ep_context_gen_options = {}; - EpContextModelGenerationOptions GetEpContextGenerationOptions() const; + epctx::ModelGenOptions ep_context_gen_options = {}; + epctx::ModelGenOptions GetEpContextGenerationOptions() const; }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/framework/tensor_external_data_info.cc b/onnxruntime/core/framework/tensor_external_data_info.cc index d7f5b23d56c70..dfdb3ba962609 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.cc +++ b/onnxruntime/core/framework/tensor_external_data_info.cc @@ -18,6 +18,13 @@ using ::google::protobuf::RepeatedPtrField; using ::ONNX_NAMESPACE::StringStringEntryProto; namespace onnxruntime { +ExternalDataInfo::ExternalDataInfo() = default; + +#if !defined(ORT_MINIMAL_BUILD) +ExternalDataInfo::ExternalDataInfo(const PathString& rel_path, OFFSET_TYPE offset, size_t length) + : rel_path_(rel_path), offset_(offset), length_(length) {} +#endif + Status ExternalDataInfo::Create(const RepeatedPtrField& input, std::unique_ptr& external_data_info_result) { auto external_data_info = std::make_unique(); diff --git a/onnxruntime/core/framework/tensor_external_data_info.h b/onnxruntime/core/framework/tensor_external_data_info.h index 784b3f352a78e..aa9bb32922bd7 100644 --- a/onnxruntime/core/framework/tensor_external_data_info.h +++ b/onnxruntime/core/framework/tensor_external_data_info.h @@ -25,6 +25,12 @@ class ExternalDataInfo { using OFFSET_TYPE = off_t; #endif + ExternalDataInfo(); + +#if !defined(ORT_MINIMAL_BUILD) + ExternalDataInfo(const PathString& rel_path, OFFSET_TYPE offset, size_t length); +#endif + const PathString& GetRelPath() const { return rel_path_; } OFFSET_TYPE GetOffset() const { return offset_; } diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 2ef7c4a9091f3..c5d7d4cc4e68c 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -31,8 +31,10 @@ enum class OrtGraphIrApi { kEpApi, }; -// Alias OrtExternalInitializerInfo to the internal type. -struct OrtExternalInitializerInfo : onnxruntime::ExternalDataInfo {}; +// Alias OrtExternalInitializerInfo to the internal onnxruntime::ExternalDataInfo type. +struct OrtExternalInitializerInfo : onnxruntime::ExternalDataInfo { + using onnxruntime::ExternalDataInfo::ExternalDataInfo; // inherit constructors +}; /// /// Public type that represents an ONNX value info. 
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 46a52e042ba13..b48fe8c1e1839 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1387,33 +1387,49 @@ constexpr const char* MoE_ver1_doc = R"DOC( Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral). + + The SwiGLU (Swish-Gated Linear Unit) activation function is like: + g = xW + b + l = xV + c + G = clamp(g, max=limit) + L = clamp(l, min=-limit, max=limit) + swiglu = G * sigmoid(alpha * G) * (L + beta) + where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters. + When swiglu_fusion=0, two GEMMs are not fused, and they are FC1 and FC3 in the inputs. + When swiglu_fusion=1, two GEMMs are fused so that g and l are computed in a single GEMM (FC1), and g and l are interleaved on each row of size 2 * inter_size. + When swiglu_fusion=2, two GEMMs are fused, and g and l are concatenated on each row. )DOC"; -ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, - OpSchema() - .SetDoc(MoE_ver1_doc) - .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) - .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) - .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) - .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) - .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") - .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") - .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") - .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) - .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") - .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) - .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional) - .Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) - .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +ONNX_MS_OPERATOR_SET_SCHEMA( + MoE, 1, + OpSchema() + .SetDoc(MoE_ver1_doc) + .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) + .Attr("swiglu_fusion", "0: not fused, 1: fused and interleaved. 
2: fused and not interleaved.", AttributeProto::INT, static_cast(0)) + .Attr("swiglu_limit", "The limit used to clamp in SwiGLU. No clamp when limit is not provided.", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("activation_alpha", "Alpha parameter used in activation function.", AttributeProto::FLOAT, 1.0f) + .Attr("activation_beta", "Beta parameter used in activation function.", AttributeProto::FLOAT, 0.0f) + .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) + .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) + .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) + .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu", "T") + .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional) + .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") + .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) + .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, inter_size, hidden_size)", "T", OpSchema::Optional) + .Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); ONNX_MS_OPERATOR_SET_SCHEMA( QMoE, 1, OpSchema() .SetDoc("Quantized MoE") .Attr("activation_type", - "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", + "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", @@ -1429,6 +1445,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Number of bits used in quantized weights. Default is 4 bits", AttributeProto::INT, static_cast(4)) + .Attr("swiglu_fusion", "0: not fused, 1: fused and interleaved. 2: fused and not interleaved.", AttributeProto::INT, static_cast(0)) + .Attr("swiglu_limit", "The limit used to clamp inputs in SwiGLU. 
It is infinite when limit is not provided.", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("activation_alpha", "Alpha parameter used in activation function.", AttributeProto::FLOAT, 1.0f) + .Attr("activation_beta", "Beta parameter used in activation function.", AttributeProto::FLOAT, 0.0f) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " @@ -1437,19 +1457,21 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") .Input(2, "fc1_experts_weights", - "3D input tensor with shape (num_experts, hidden_size, inter_size) " - "or (num_experts, hidden_size, inter_size / 2)", + "3D input tensor with shape (num_experts, inter_size, hidden_size), " + "or (num_experts, inter_size, hidden_size / 2) for 4 bits. " + "For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), " + "or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.", "T1") - .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T") + .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T2") .Input(4, "fc1_experts_bias", - "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional) .Input(5, "fc2_experts_weights", - "3D input tensor with shape (num_experts, inter_size, hidden_size) " - "or (num_experts, inter_size, hidden_size / 2)", + "3D input tensor with shape (num_experts, hidden_size, inter_size) " + "or (num_experts, hidden_size, inter_size / 2) for 4 bits", "T1") - .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T") + .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T2") .Input(7, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", @@ -1457,14 +1479,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(8, "fc3_experts_weights", - "3D optional input tensor with shape (num_experts, hidden_size, inter_size) " - "or (num_experts, hidden_size, inter_size / 2)", + "3D optional input tensor with shape (num_experts, inter_size, hidden_size) " + "or (num_experts, inter_size, hidden_size / 2)", "T1", OpSchema::Optional) .Input(9, "fc3_scales", "2D optional input tensor with shape (num_experts, inter_size)", - "T", + "T2", OpSchema::Optional) .Input(10, "fc3_experts_bias", @@ -1476,8 +1498,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " "(batch_size, sequence_length, hidden_size)", "T") - .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.") + .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain scales type to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 0d9b93631ee8a..92eb31f0ad385 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ 
b/onnxruntime/core/graph/ep_api_types.cc @@ -327,6 +327,9 @@ static Status GetInputIndices(const EpNode& consumer_node, [&found, &value_info_name, &indices](gsl::span input_value_infos, bool is_implicit) -> void { for (size_t i = 0; i < input_value_infos.size(); i++) { + if (input_value_infos[i] == nullptr) { // input_value_info == nullptr means the input is optional + continue; + } if (input_value_infos[i]->GetName() == value_info_name) { indices.push_back(is_implicit ? -1 : static_cast(i)); found = true; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 0a228176175eb..9a97711996343 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -17,6 +17,7 @@ #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" +#include "core/framework/error_code_helper.h" #include "core/framework/tensor_type_and_shape.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/framework/tensor_external_data_info.h" @@ -4357,14 +4358,23 @@ Status Graph::RegenerateInitializersAndReplaceInMemory(gsl::span& subgraphs) { + for (const auto& node : nodes) { if (node.ContainsSubgraph()) { // Let's find this node in the output_graph_proto // The node name is optional, so we may need to check by the output value name // given that they can only assigned once. - auto hit = std::find_if(output_graph_proto.mutable_node()->begin(), - output_graph_proto.mutable_node()->end(), + auto hit = std::find_if(graph_proto.mutable_node()->begin(), + graph_proto.mutable_node()->end(), [&node](const ONNX_NAMESPACE::NodeProto& proto) { const auto& node_name = node.Name(); if (!node_name.empty()) @@ -4372,7 +4382,7 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr return (proto.output_size() > 0 && proto.output(0) == node.OutputDefs()[0]->Name()); }); - ORT_RETURN_IF_NOT(hit != output_graph_proto.mutable_node()->end(), "Node ", node.Name(), + ORT_RETURN_IF_NOT(hit != graph_proto.mutable_node()->end(), "Node ", node.Name(), " not found in output_graph_proto"); auto& result_node = *hit; for (const auto& e : node.GetAttributeNameToSubgraphMap()) { @@ -4387,12 +4397,28 @@ Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_gr ORT_RETURN_IF_NOT(sub_hit != result_node.mutable_attribute()->end() && utils::HasGraph(*sub_hit), "Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ", node.Name(), " while attempting to recurse into it."); - auto& result_subgraph = *sub_hit->mutable_g(); - ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(result_subgraph)); + SubgraphWithMutableProto subgraph_result{sub_hit->mutable_g(), subgraph}; + subgraphs.emplace_back(subgraph_result); } } } + return Status::OK(); +} + +Status Graph::ProcessSubgraphsInMemoryData(ONNX_NAMESPACE::GraphProto& output_graph_proto) const { + // Process subgraphs recursively (bottom-up). + { + std::vector subgraphs; + ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs)); + + for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) { + gsl::not_null subgraph = subgraph_and_proto.subgraph; + gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto; + ORT_RETURN_IF_ERROR(subgraph->ProcessSubgraphsInMemoryData(*subgraph_proto)); + } + } + // Filter in iterators for weights that are present in the name_to_initial_tensor_ map // and preserve the order. This is needed for tests. 
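Editor's note: the lookup that the refactored subgraph helper performs above boils down to matching a graph Node against a serialized NodeProto by node name when one is set, and otherwise by the first output name, which can only be assigned once. A restatement of that rule as a standalone predicate (illustrative helper, not part of the patch):

    // Sketch of the matching rule used when pairing Graph nodes with NodeProto entries.
    static bool MatchesNodeProto(const ONNX_NAMESPACE::NodeProto& proto,
                                 const onnxruntime::Node& node) {
      if (!node.Name().empty()) {
        return proto.name() == node.Name();  // prefer the optional node name
      }
      // fall back to the first output, which is unique within the graph
      return proto.output_size() > 0 && proto.output(0) == node.OutputDefs()[0]->Name();
    }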
InlinedVector initializers_to_process; @@ -4444,44 +4470,19 @@ Status Graph::AddExternalInitializersToGraphProtoImpl( // Process initializers in a subgraph, check their size and // write to an external file. This function also saves pre-packed // blobs for the initializer being saved to disk, if the initializer has any pre-packs. - // This function is invoked by ToGraphProtoWithExternalInitiallizers() and processes subgraphs + // This function is invoked by ToGraphProtoWithExternalInitializers() and processes subgraphs // bottom up. - for (const auto& node : Nodes()) { - if (node.ContainsSubgraph()) { - // Let's find this node in the output_graph_proto - // The node name is optional, so we may need to check by the output value name - // given that they can only assigned once. - auto hit = std::find_if(output_graph_proto.mutable_node()->begin(), - output_graph_proto.mutable_node()->end(), - [&node](const ONNX_NAMESPACE::NodeProto& proto) { - const auto& node_name = node.Name(); - if (!node_name.empty()) - return proto.name() == node_name; - return (proto.output_size() > 0 && - proto.output(0) == node.OutputDefs()[0]->Name()); - }); - ORT_RETURN_IF_NOT(hit != output_graph_proto.mutable_node()->end(), "Node ", node.Name(), - " not found in output_graph_proto"); - auto& result_node = *hit; - for (const auto& e : node.GetAttributeNameToSubgraphMap()) { - const auto& name = e.first; - const auto& subgraph = e.second; - // Lets find this subgraph in the result_node - auto sub_hit = std::find_if(result_node.mutable_attribute()->begin(), - result_node.mutable_attribute()->end(), - [&name](const ONNX_NAMESPACE::AttributeProto& proto) { - return proto.name() == name; - }); - ORT_RETURN_IF_NOT(sub_hit != result_node.mutable_attribute()->end() && utils::HasGraph(*sub_hit), - "Subgraph ", name, " is referred to in GetAttributeNameToSubgraphMap, but not found in node ", - node.Name(), " while attempting to recurse into it."); - auto& result_subgraph = *sub_hit->mutable_g(); - ORT_RETURN_IF_ERROR(subgraph->AddExternalInitializersToGraphProtoImpl( - model_path, external_file_path, - model_external_file_path, model_saving_options, - result_subgraph, - external_stream, external_offset)); - } + { + std::vector subgraphs; + ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs)); + + for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) { + gsl::not_null subgraph = subgraph_and_proto.subgraph; + gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto; + ORT_RETURN_IF_ERROR(subgraph->AddExternalInitializersToGraphProtoImpl( + model_path, external_file_path, + model_external_file_path, model_saving_options, + *subgraph_proto, external_stream, external_offset)); } } @@ -4643,6 +4644,113 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers( return result; } +Status Graph::ToGraphProtoWithCustomInitializerHandlingImpl( + OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& output_graph_proto) const { + // This loop processes subgraphs bottom up. 
+ { + std::vector subgraphs; + ORT_RETURN_IF_ERROR(GetSubgraphsWithMatchingGraphProtos(Nodes(), output_graph_proto, subgraphs)); + + for (SubgraphWithMutableProto& subgraph_and_proto : subgraphs) { + gsl::not_null subgraph = subgraph_and_proto.subgraph; + gsl::not_null subgraph_proto = subgraph_and_proto.subgraph_proto; + ORT_RETURN_IF_ERROR(subgraph->ToGraphProtoWithCustomInitializerHandlingImpl(handle_initializer_func, + state, *subgraph_proto)); + } + } + + // Create a sorted std::vector of initializers so that we always process them in a deterministic order. + InlinedVector initializers; + initializers.reserve(GetAllInitializedTensors().size()); + + for (const auto& [name, initializer_tp] : GetAllInitializedTensors()) { + initializers.push_back(initializer_tp); + } + + std::sort(initializers.begin(), initializers.end(), + [](const ONNX_NAMESPACE::TensorProto* a, const ONNX_NAMESPACE::TensorProto* b) { + return a->name() < b->name(); + }); + + // Call user's handler function for each initializer. We store the initializer externally + // or within the model depending on the result returned by the handler function. + for (gsl::not_null initializer : initializers) { +#if !defined(DISABLE_SPARSE_TENSORS) + if (IsSparseInitializer(initializer->name())) { + // Sparse tensors are added to the ONNX file directly. + auto& sparse_initializer = *output_graph_proto.add_sparse_initializer(); + ORT_RETURN_IF_ERROR(utils::DenseTensorToSparseTensorProto(*initializer, ModelPath(), sparse_initializer)); + } else { +#endif + TensorProto* output_proto = output_graph_proto.add_initializer(); + + output_proto->set_name(initializer->name()); + output_proto->set_data_type(initializer->data_type()); + for (int i = 0; i != initializer->dims_size(); ++i) { + output_proto->add_dims(initializer->dims(i)); + } + output_proto->set_doc_string(initializer->doc_string()); + + OrtValue ort_value; + std::unique_ptr original_ext_data_info = nullptr; + + if (utils::HasExternalDataInFile(*initializer)) { + // Initializer has data in an external file. Load it into OrtValue (potentially via memory mapping). + ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(initializer->external_data(), original_ext_data_info)); + ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(Env::Default(), ModelPath(), *initializer, ort_value)); + } else { + // Initializer is either stored inline within the TensorProto or it is "external data in memory". + // Get an OrtValue (if already loaded by Graph) or copy into an OrtValue otherwise. + bool graph_has_ort_value = GetOrtValueInitializer(initializer->name(), ort_value, /*check_outer_scope*/ false); + if (!graph_has_ort_value) { + assert(!utils::HasExternalData(*initializer)); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), ModelPath(), *initializer, + CPUAllocator::DefaultInstance(), ort_value)); + } + } + + // Call the user's initializer handler function. If the user wants to store the initializer externally, + // the handler function will use OrtApi::CreateExternalInitializerInfo() to create a new + // OrtExternalInitializerInfo instance that indicates the location of the data. 
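Editor's note: to make the callback contract concrete, a handler compatible with the invocation below might look roughly like the following. The parameter and return types are inferred from the call site and should be treated as assumptions; the decision logic is purely illustrative.

    // Editorial sketch, not part of the patch. Types are inferred from the call site below.
    static OrtStatus* ExampleGetInitializerLocation(
        void* state, const char* initializer_name, const OrtValue* initializer_value,
        const OrtExternalInitializerInfo* original_location,
        OrtExternalInitializerInfo** new_location) {
      (void)state; (void)initializer_name; (void)initializer_value; (void)original_location;
      // Leave *new_location as nullptr to keep the initializer embedded in the model.
      // To store it externally instead, write the tensor data out and describe its location
      // via OrtApi::CreateExternalInitializerInfo() (per the comment above), then assign the
      // result to *new_location; ownership transfers to the caller.
      *new_location = nullptr;
      return nullptr;  // success
    }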
+ OrtExternalInitializerInfo* new_external_info = nullptr; + Status status = ToStatusAndRelease(handle_initializer_func(state, initializer->name().c_str(), + &ort_value, + static_cast(original_ext_data_info.get()), + &new_external_info)); + + ORT_RETURN_IF(new_external_info != nullptr && + new_external_info == static_cast(original_ext_data_info.get()), + "User's OrtGetInitializerLocationFunc must not return the external_info parameter.", + "Return a copy instead."); + std::unique_ptr new_external_info_holder(new_external_info); // Take ownership + ORT_RETURN_IF_ERROR(status); + + if (new_external_info != nullptr) { + ExternalDataInfo::SetExternalLocationToProto(new_external_info->GetRelPath(), new_external_info->GetOffset(), + new_external_info->GetLength(), *output_proto); + } else { + const Tensor& tensor = ort_value.Get(); + output_proto->clear_data_location(); + utils::SetRawDataInTensorProto(*output_proto, tensor.DataRaw(), tensor.SizeInBytes()); + } +#if !defined(DISABLE_SPARSE_TENSORS) + } +#endif + } + + return Status::OK(); +} + +Status Graph::ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::GraphProto& graph_proto) const { + ToGraphProtoInternal(graph_proto); + ORT_RETURN_IF_ERROR(ToGraphProtoWithCustomInitializerHandlingImpl(handle_initializer_func, state, graph_proto)); + return Status::OK(); +} + void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const { graph_proto_->clear_node(); graph_proto_->clear_input(); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index eb5e1e89e2f9c..0ffbced51ee35 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -415,6 +415,25 @@ ModelProto Model::ToGraphProtoWithExternalInitializers(const std::filesystem::pa return result; } +common::Status Model::ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) const { + model_proto = model_proto_; + + // Sync current model_metadata_ back to protobuf metadata_props + model_proto.clear_metadata_props(); + for (const auto& metadata : model_metadata_) { + const gsl::not_null prop{model_proto.add_metadata_props()}; + prop->set_key(metadata.first); + prop->set_value(metadata.second); + } + + const auto& graph = *graph_; + ORT_RETURN_IF_ERROR(graph.ToGraphProtoWithCustomInitializerHandling(handle_initializer_func, + state, *model_proto.mutable_graph())); + return Status::OK(); +} + Status Model::Load(std::istream& model_istream, ModelProto* p_model_proto) { if (!model_istream.good()) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid istream object."); diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index e8722f6f5c0b2..c86aac44806bd 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -210,6 +210,18 @@ class Model { const std::filesystem::path& file_path, const ModelSavingOptions& model_saving_options) const; + /// + /// Serialize the Model to a onnx::ModelProto. Caller provides a function that determines where each initializer + /// is stored (i.e., either in an external file or within the model). + /// + /// Function called for every initializer. + /// Opaque user state passed to the handle_initializer_func. + /// Output parameter set to the serialized onnx::ModelProto. + /// A status indicating success or an error. 
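Editor's note: a plausible end-to-end use of the new Model entry point, assuming a handler like the sketch above. Names are illustrative and the call shape follows the declaration below; it is an assumption that this runs inside a function returning onnxruntime::common::Status.

    // Editorial sketch, not part of the patch.
    ONNX_NAMESPACE::ModelProto model_proto;
    ORT_RETURN_IF_ERROR(model.ToGraphProtoWithCustomInitializerHandling(
        ExampleGetInitializerLocation, /*state*/ nullptr, model_proto));
    // Initializers the handler externalized now carry an external-data location in
    // model_proto; all others keep their raw data inline.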
+ common::Status ToGraphProtoWithCustomInitializerHandling(OrtGetInitializerLocationFunc handle_initializer_func, + void* state, + /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) const; + static common::Status Save(Model& model, const PathString& file_path); static common::Status Save(Model& model, int fd); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 3627989609737..fc3c0b6016ced 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -48,6 +48,17 @@ struct MLAS_QNBIT_GEMM_DATA_PARAMS { const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + + /// + /// Address of scale * accumulate(quant - zp), one per block, where `scale`, `quant`, `zp` are respectively + /// an individual block's scale, quantized values, and zero point for the input `B`. + /// When converting the activation input (A) to uint8, we first convert the values to int8 and then + /// add a "bias" of +128 to convert the range of values from [-128, +127] to [0, +255]. + /// This input helps to "de-bias" the output of the +128 bias added to the activation input. + /// This input is to be used only when A is quantized to uint8. + /// + const T* BlkUnsignedQuantAZeroPointCorrection = nullptr; + const T* Bias = nullptr; ///< optional address of Bias, vector size N T* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index a099bcf8438fe..2bbcfd51fe4ba 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1200,7 +1200,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH; const MLAS_QNBIT_GEMM_DISPATCH& GetMlasQNBitGemmDispatchNeon( - bool InitializeWithDotSupport + bool InitializeWithDotSupport, + bool InitializeWithI8MMSupport ); extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; @@ -1297,6 +1298,8 @@ struct MLAS_PLATFORM { // TODO: move to cpuinfo bool Avx2Supported_ = false; bool Avx512Supported_ = false; + bool ArmNeonIsQuantActivationsUnsigned = false; + // Mlas overrides initialisation MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr; MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 3256dadb856d3..c4b8d5e78a491 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -582,7 +582,6 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } - this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions); #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; @@ -593,16 +592,22 @@ Return Value: } #endif -#if defined(__linux__) // // Check if the processor supports ASIMD I8MM instructions. 
// - if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + + const bool HasI8MMInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM(); + if (HasI8MMInstructions) { +#if defined(__linux__) + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; - } #endif + } + + this->ArmNeonIsQuantActivationsUnsigned = HasI8MMInstructions ? false : true; + this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, HasI8MMInstructions); #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.cpp b/onnxruntime/core/mlas/lib/qnbitgemm.cpp index 19d11a60b7376..d806f4b08bfff 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm.cpp @@ -132,7 +132,7 @@ QNBitGemmPerGemmWorkspaceSize( } if (BlkBitWidth == 4 || BlkBitWidth == 8) { - return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType); + return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType, BlkBitWidth); } return 0; @@ -266,7 +266,7 @@ MlasQNBitGemmPackQuantBData( if (BlkBitWidth == 4) { if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen, false); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, @@ -307,7 +307,8 @@ MlasQNBitGemmPackQuantBData( } else if (BlkBitWidth == 8) { if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, + BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum( N, K, @@ -742,6 +743,8 @@ SQ8BitGemm_CompInt8( : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + const float* BlkUnsignedQuantAZeroPointCorrection = + DataParams->BlkUnsignedQuantAZeroPointCorrection ? DataParams->BlkUnsignedQuantAZeroPointCorrection + RangeStartN * k_blks : nullptr; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; @@ -759,6 +762,8 @@ SQ8BitGemm_CompInt8( if (GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { const float* b_blk_sum = QuantBBlkSum + n * k_blks; + const float* blk_unsigned_quant_A_zp_correction = BlkUnsignedQuantAZeroPointCorrection ? 
+ BlkUnsignedQuantAZeroPointCorrection + n * k_blks : nullptr; GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8( BlkLen, QuantA, @@ -774,7 +779,8 @@ SQ8BitGemm_CompInt8( bias, ldc, ABlockSum, - b_blk_sum + b_blk_sum, + blk_unsigned_quant_A_zp_correction ); if (DataParams->PostProcessor != nullptr) { @@ -798,7 +804,8 @@ InitializeWorkspace_CompInt8( const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth ); template <> @@ -812,7 +819,8 @@ InitializeWorkspace_CompInt8( const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth ) { MLAS_UNREFERENCED_PARAMETER(N); @@ -826,7 +834,7 @@ InitializeWorkspace_CompInt8( const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); // TODO: try parallel on BatchN * M threads because BatchN is usually 1. - if (UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) { + if (BlkBitWidth == 4 && UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) { MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { const auto& data = DataParams[gemm_idx]; @@ -834,38 +842,63 @@ InitializeWorkspace_CompInt8( std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; QuantizeA_Packed(BlkLen, ARowPtr, M, K, QuantARowPtr); }); - } else if (QuantizeARow) { - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; - - const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - for (size_t m = 0; m < M; ++m) { - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - - ARowPtr += data.lda; - QuantARowPtr += QuantAStride; - } - }); } else { - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; - const float* ARowPtr = data.A; - - void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); - std::byte* QuantARowPtr = quant_a_data.QuantData; - float* QuantARowScalePtr = quant_a_data.QuantScale; - float* QuantARowBlkSum = quant_a_data.BlockSum; - for (size_t m = 0; m < M; ++m) { - QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); - ARowPtr += data.lda; - QuantARowPtr += BlockCountK * BlkLen; - QuantARowScalePtr += BlockCountK; - QuantARowBlkSum += BlockCountK; + // TODO(hasesh): Clean-up the following logic so that it is clean AND it works as expected on all platforms + if (BlkBitWidth == 4) { + if (QuantizeARow) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); + } else if (QuantizeARow2) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, 
BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); } - }); - } + } else if (BlkBitWidth == 8) { + if (QuantizeARow2) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); + } + } + } } template <> @@ -879,7 +912,8 @@ InitializeWorkspace_CompInt8( const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth ) { MLAS_UNREFERENCED_PARAMETER(M); MLAS_UNREFERENCED_PARAMETER(N); @@ -890,6 +924,7 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(Workspace); MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspaceStride); MLAS_UNREFERENCED_PARAMETER(ThreadPool); + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); } template @@ -902,7 +937,8 @@ using InitializeWorkspaceFn = std::function* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth )>; template @@ -1015,7 +1051,7 @@ MlasQNBitGemmBatch( if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant); InitializeWorkspaceOperation != nullptr) { InitializeWorkspaceOperation( - M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool + M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool, BlkBitWidth ); } @@ -1029,17 +1065,19 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, false); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct 
packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + const_cast*>(Data)->BlkUnsignedQuantAZeroPointCorrection = packed_quant_b.BlkUnsignedQuantAZeroPointCorrection; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { @@ -1107,7 +1145,7 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, false); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; @@ -1115,10 +1153,11 @@ MlasQNBitGemmBatch( PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + const_cast*>(Data)->BlkUnsignedQuantAZeroPointCorrection = packed_quant_b.BlkUnsignedQuantAZeroPointCorrection; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index 4c133103bee04..06e8e49b59e2e 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -48,24 +48,39 @@ MlasAlignAddress(void* addr, const size_t alignment) template struct PackedQuantBDataStruct { - PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) + PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen, bool QuantAUnsigned) : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) { const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); + if constexpr (BlkBitWidth == 8) { + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); + } else { #if defined(MLAS_TARGET_AMD64_IX86) // avx512 requires alignment on a 64-byte boundary PackedQuantBData = 
(std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64); #else PackedQuantBData = (std::byte*)PackedQuantBWorkspace; #endif + } + QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); - PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); + + if (QuantAUnsigned) { + BlkUnsignedQuantAZeroPointCorrection = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); + BlkUnsignedQuantAZeroPointCorrection = (T*)MlasAlignAddress(BlkUnsignedQuantAZeroPointCorrection, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (T*)((std::byte*)BlkUnsignedQuantAZeroPointCorrection + BlkSumSize); + } else { + BlkUnsignedQuantAZeroPointCorrection = nullptr; + PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); + } } + std::byte* PackedQuantBData; T* PackedQuantBScale; T* QuantBBlkSum; + T* BlkUnsignedQuantAZeroPointCorrection; void* QuantBWorkspace_; size_t N_, BlockCountK_, BlkLen_; @@ -178,7 +193,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH { size_t K, size_t BlkLen, bool HasZeroPoint, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t BlkBitWidth ); QNBitGemmPerGemmWorkspaceSize_Fn* QNBitGemmPerGemmWorkspaceSize = nullptr; @@ -373,20 +389,22 @@ struct MLAS_QNBIT_GEMM_DISPATCH { * @brief Multiply quantized 8-bit integer matrix A with quantized 8-bit integer matrix B. * A and B are block quantized and B is column major. * - * @param BlkLen Number of values in a block. - * @param QuantA Supplies the quantized A matrix. - Binary data containing block quantized int8 data and scale values. - * @param QuantBData Supplies the quantized B matrix block data. - * @param QuantBScale Supplies the quantized B matrix block scale values. - * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. - * @param[out] C Supplies the output C matrix. - * @param CountN Number of columns of B and C. - * @param CountK Number of columns of A and rows of B. - * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. - * @param Bias Bias vector of length N. - * @param ldc Number of elements between adjacent rows of C.. - * @param ABlockSum Supplies the blksum of A. - * @param QuantBBlkSum Supplies the blksum of B. + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. + * @param BlkUnsignedQuantAZeroPointCorrection Supplies the optional input to de-bias the Gemm output to account for the +128 bias + addition when the activation input A is quantized to uint8. 
*/ typedef size_t(SQ8BitGemmKernel_BlkSum_CompInt8_Fn)( size_t BlkLen, @@ -403,7 +421,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH { const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* BlkUnsignedQuantAZeroPointCorrection ); SQ8BitGemmKernel_BlkSum_CompInt8_Fn* SQ8BitGemmKernel_BlkSum_CompInt8 = nullptr; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp index 0d06eb04e5245..ba2b68e4fbb07 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp @@ -21,6 +21,7 @@ Module Name: #include #include +#include #include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" @@ -42,8 +43,9 @@ namespace // Quantized B data packing function implementation. // +template size_t -Q4BitGemmPackQuantBDataSize( +QNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, @@ -51,26 +53,49 @@ Q4BitGemmPackQuantBDataSize( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { + if constexpr (BlkBitWidth == 4) { #ifndef USE_KLEIDIAI - MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); - MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType #endif #ifdef USE_KLEIDIAI - if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { - const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = GetKleidiAIGemmUKernel(); - const size_t nr = ukernel.get_nr(); - const size_t kr = ukernel.get_kr(); - const size_t sr = ukernel.get_sr(); - return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); - } else + if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = GetKleidiAIGemmUKernel(); + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); + } else #endif - { - constexpr size_t BlkBitWidth = 4; - + { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } + } else { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + if (ComputeType == SQNBIT_CompInt8) { + const size_t ScaleSize = N * BlockCountK * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // align on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + PackedQuantBDataSize += PackedQuantBDataAlignment - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += BlkSumAlignment - 1; + + if constexpr (QuantAUnsigned) { + // 2 block sum + return PackedQuantBDataSize + ScaleSize + BlkSumSize + BlkSumSize; + } else { + return PackedQuantBDataSize + ScaleSize + BlkSumSize; + } + } else { + return PackedQuantBDataSize; + } } } @@ -199,6 +224,167 @@ SQ4BitGemmPackQuantBDataAndBlkSum( } } +void +Q8PackQuantB( + const std::byte* QuantBDataBegin, + 
std::byte* PackedQuantBDataBegin, + float* BlkUnsignedQuantAZeroPointCorrectionBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t K, + const size_t BlkLen) +{ + constexpr size_t SubBlkLen = 4; + const size_t BlkCountK = MlasDivRoundup(K, BlkLen); + const size_t SubBlkPerBlk = BlkLen / SubBlkLen; + const size_t StrideN = BlkCountK * BlkLen; + const size_t Iterations = N * BlkCountK; + + // 4 rows x 8 columns pack together, then 4 rows x 4 columns, then per column. + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t c = tid / BlkCountK; + const size_t c8 = c & (~7), c8_res = c & 7; + const size_t c4 = c & (~3), c4_res = c & 3; + const size_t r_blk = tid % BlkCountK; + size_t r_subblk = r_blk * SubBlkPerBlk; + + const std::byte* src = QuantBDataBegin + c * StrideN + r_blk * BlkLen; + const uint8_t* src8 = reinterpret_cast(src); + + for (size_t i = 0; i < SubBlkPerBlk; ++i, src += SubBlkLen, ++r_subblk) { + if (c8 + 8 <= N) { // full 8 cols + std::byte* dest = + PackedQuantBDataBegin + c8 * StrideN + r_subblk * SubBlkLen * 8 + c8_res * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } else if (c4 + 4 <= N) { // full 4 cols + std::byte* dest = + PackedQuantBDataBegin + c4 * StrideN + r_subblk * SubBlkLen * 4 + c4_res * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } else { // remainder cols + std::byte* dest = + PackedQuantBDataBegin + c * StrideN + r_subblk * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } + } + + if (BlkUnsignedQuantAZeroPointCorrectionBegin) { + const int accu = std::accumulate(src8, src8 + std::min(BlkLen, K - r_blk * BlkLen), 0); + + // for sgemmc + const size_t dst_offset = ((c / 16) * BlkCountK + r_blk) * 16 + c % 16; + BlkUnsignedQuantAZeroPointCorrectionBegin[dst_offset] = static_cast(accu); + } + } + ); +} + +void +Q8ComputePackBlkSum( + const size_t BlkLen, + const size_t N, + const size_t K, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + float* BlockSum2Begin, + MLAS_THREADPOOL* ThreadPool) +{ + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t n8 = n & (~7), n8_res = n & 7; + const size_t n4 = n & (~3), n4_res = n & 3; + const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 128; + if (QuantBZPBegin) { + const std::byte* QuantBZP = QuantBZPBegin + src_blk_offset; + zp = (uint8_t)(*QuantBZP); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + if (BlockSum2Begin) { + BlockSum2Begin[dst_offset] = QuantBScale * (static_cast(zp) * std::min(BlkLen, K - k_blk * BlkLen) - BlockSum2Begin[dst_offset]); + } + + // re-arrange scale to the same order as packed data + if (n4 + 4 > N) { // remainder cols + *(QuantBScaleBegin + n * BlockCountK + k_blk) = QuantBScale; + } else if (n8 + 8 > N) { // full 4 cols + *(QuantBScaleBegin + n4 * BlockCountK + k_blk * 4 + n4_res) = QuantBScale; + } else { // full 8 cols + *(QuantBScaleBegin + n8 * BlockCountK + k_blk * 8 + n8_res) = QuantBScale; + } + }); +} + +/** + * 4 rows x 8 cols pack 
together, along all K. Then 4 rows x 4 cols, along all K. + * When rows < 4, keep original layout. + * + * dotprod: vdotq_laneq_u32. + * convert quant a from int8 to uint8. zp is 128. + * + * i8mm: vusdotq_laneq_s32. + */ +void +SQ8BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType */, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // Pack the quantized weights + if (QuantBDataBegin) { + Q8PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, PackedQuantB.BlkUnsignedQuantAZeroPointCorrection, ThreadPool, N, K, BlkLen); + } else { + // We ignore the scales and zero points if they are provided when pre-packing the weights as there is + // some "state" associated with 'BlkUnsignedQuantAZeroPointCorrection'. + + // We accumulate the block sum into 'BlkUnsignedQuantAZeroPointCorrection' while packing the weights + // in the previous step. If we were to use 'scales' while pre-packing the weights and if there were no + // zero points, then we would enter 'Q8ComputePackBlkSum' twice - once while pre-packing the weights + // and once while pre-packing the scales which would lead to erroneous 'BlkUnsignedQuantAZeroPointCorrection' + // computation as the buffer is "used" in-place for the "block sum" temporary values (obtained while pre-packing + // the weights) and the actual 'BlkUnsignedQuantAZeroPointCorrection' which will use the scales. + // Hence, to ensure that the piece of logic to calculate 'BlkUnsignedQuantAZeroPointCorrection' is only invoked + // once, we do it while we are pre-packing the scales and ignore any provided 'scales' and 'zero points' while + // pre-packing the weights. + // The flip side is that the user has to ensure that this function is called once each for 'weights', + // 'scales', and 'zero points'. This is a reasonable expectation and hence we go with that design. + + // Pack the block scales + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); + } + + // Pack the blksum (and BlkUnsignedQuantAZeroPointCorrection if applicable) + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + Q8ComputePackBlkSum(BlkLen, N, K, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, PackedQuantB.BlkUnsignedQuantAZeroPointCorrection, ThreadPool); + } + } +} + // // Workspace size calculation function implementation. // @@ -210,19 +396,21 @@ QNBitGemmPerGemmWorkspaceSize( size_t K, size_t BlkLen, bool HasZeroPoint, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t BlkBitWidth ) { MLAS_UNREFERENCED_PARAMETER(N); #ifndef USE_KLEIDIAI MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); #endif switch (ComputeType) { case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 #ifdef USE_KLEIDIAI - if (UseKleidiAI(K, BlkLen, HasZeroPoint)) { + if (BlkBitWidth == 4 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = M == 1? 
GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); @@ -233,8 +421,10 @@ QNBitGemmPerGemmWorkspaceSize( } else #endif { + // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + // QuantData + Scale + BlkSum + const size_t PerGemmWorkspaceSize = M * BlockCountK * (Q8BlkSize(BlkLen) + sizeof(float)); return PerGemmWorkspaceSize; } } @@ -278,6 +468,77 @@ UseKleidiAI(size_t K, size_t BlkLen, bool HasZp) #endif } +template +size_t +SQ8BitGemmKernel_BlkSum_CompInt8( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum, + const float* BlkUnsignedQuantAZeroPointCorrection +) +{ + MlasQ8Int8GemmKernelNeon( + BlkLen, + reinterpret_cast*>(QuantA), + QuantAScale, + reinterpret_cast(QuantBData), + QuantBScale, + C, + CountM, + CountN, + CountK, + Bias, + ldc + ); + + { + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = MlasSgemmKernelAdd(a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + + if constexpr (QuantAUnsigned) { + { + assert(BlkUnsignedQuantAZeroPointCorrection != nullptr); + float* c_blk = C; + const float* b_blk_sum2 = BlkUnsignedQuantAZeroPointCorrection; + + size_t RowsRemaining = CountM; + const float* a_scale_row = QuantAScale; + while (RowsRemaining > 0) { + auto RowsHandled = MlasSgemmKernelAdd(a_scale_row, b_blk_sum2, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 128.f); + + c_blk += ldc * RowsHandled; + a_scale_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + } + + return CountM; +} + } // namespace sqnbitgemm_neon // @@ -286,7 +547,8 @@ UseKleidiAI(size_t K, size_t BlkLen, bool HasZp) const MLAS_QNBIT_GEMM_DISPATCH& GetMlasQNBitGemmDispatchNeon( - bool InitializeWithDotSupport + bool InitializeWithDotSupport, + bool InitializeWithI8MMSupport ) { // Note: The InitializeWithX parameters are only used in the invocation of this method that initializes the static @@ -295,9 +557,11 @@ GetMlasQNBitGemmDispatchNeon( static const MLAS_QNBIT_GEMM_DISPATCH MlasQNBitGemmDispatchNeon = [&]() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<4, false>; + d.Q8BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<8, true>; d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ8BitGemmPackQuantBDataAndBlkSum = sqnbitgemm_neon::SQ8BitGemmPackQuantBDataAndBlkSum; d.QNBitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceSize; d.QNBitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceAlignment; @@ -310,12 +574,21 @@ GetMlasQNBitGemmDispatchNeon( d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; d.UsePacked_CompInt8 = 
sqnbitgemm_neon::UsePacked_CompInt8; + d.QuantizeARowComputeBlkSum_CompInt8 = sqnbitgemm_neon::QuantizeARowComputeBlkSum_CompInt8; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = sqnbitgemm_neon::SQ8BitGemmKernel_BlkSum_CompInt8; + #ifdef USE_KLEIDIAI d.SQ4BitGemmKernel_Packed_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_Packed_CompInt8; d.QuantizeA_Packed_CompInt8 = sqnbitgemm_neon::QuantizeA_Packed_CompInt8; #endif } + if (InitializeWithI8MMSupport) { + d.Q8BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<8, false>; + d.QuantizeARowComputeBlkSum_CompInt8 = sqnbitgemm_neon::QuantizeARowComputeBlkSum_CompInt8; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = sqnbitgemm_neon::SQ8BitGemmKernel_BlkSum_CompInt8; + } + #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16; d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16; diff --git a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h index a254ec9f92596..c8be42b01fbe2 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h @@ -123,6 +123,36 @@ QuantizeARow_CompInt8( std::byte* QuantA ); +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + +template +using QuantAType = typename std::conditional::type; + +template +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const QuantAType* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +); + size_t SQ4BitGemmKernel_CompInt8( size_t BlkLen, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 384b04c807195..f160c9f541238 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -602,7 +602,8 @@ SQ8BitGemmKernel_BlkSum_CompInt8_avx2( const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ ) { if (BlkLen == 16) { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index c1bc00fbffa3e..122086d8ef05b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -264,7 +264,8 @@ SQ8BitGemmKernel_BlkSum_CompInt8_avx512( const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ ) { if (BlkLen == 16) { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index ea5eebd854655..e172308637af1 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -316,7 +316,8 @@ SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni( const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ ) { if (BlkLen == 16) { diff --git 
a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index bb38f37fb0eb8..36c15cd5ac57f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -469,7 +469,8 @@ QNBitGemmPerGemmWorkspaceSize( size_t K, size_t BlkLen, bool /* HasZeroPoint */, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t /* BlkBitWidth */ ) { MLAS_UNREFERENCED_PARAMETER(N); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index 8dbd339468930..b03b8121059f3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -187,6 +187,230 @@ QuantizeARow_CompInt8( } } +MLAS_FORCEINLINE +float32x4_t LoadFloat32x4(const float* src, size_t count) +{ + if (count == 4) { + return vld1q_f32(src); + } else if (count == 3) { + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + v = vld1q_lane_f32(src + 1, v, 1); + v = vld1q_lane_f32(src + 2, v, 2); + return v; + } else if (count == 2) { + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + v = vld1q_lane_f32(src + 1, v, 1); + return v; + } else { + assert(count == 1); + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + return v; + } +} + +template +using I16VecType = typename std::conditional::type; + +template +I16VecType MLAS_FORCEINLINE +PrepareZeroI16() +{ + if constexpr (IsQuantAUnsigned) { + return vdupq_n_u16(0); + } else { + return vdupq_n_s16(0); + } +} + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +) +{ + // First use i8 to quantize A. range [-128, 127] + // If convert to u8, +128. Range [0, 255] + assert(BlkLen % 16 == 0); + assert(BlkLen <= 256); + MLAS_DECLSPEC_ALIGN(static const uint8_t MASK[16], 16) = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + const int16x8_t v128 = vdupq_n_s16(128); + QuantAType* blob = reinterpret_cast*>(QuantA); + float* scale_ptr = QuantAScale; + size_t k = 0; + for (; k + BlkLen <= CountK; k += BlkLen) { + float32x4_t absMax0 = vdupq_n_f32(0.0f); + float32x4_t absMax1 = vdupq_n_f32(0.0f); + float32x4_t absMax2 = vdupq_n_f32(0.0f); + float32x4_t absMax3 = vdupq_n_f32(0.0f); + + for (size_t kk = 0; kk < BlkLen; kk += 16) { + const float32x4x4_t v0 = vld4q_f32(A + k + kk); + absMax0 = vmaxq_f32(absMax0, vabsq_f32(v0.val[0])); + absMax1 = vmaxq_f32(absMax1, vabsq_f32(v0.val[1])); + absMax2 = vmaxq_f32(absMax2, vabsq_f32(v0.val[2])); + absMax3 = vmaxq_f32(absMax3, vabsq_f32(v0.val[3])); + } + + const float32x4_t max01 = vmaxq_f32(absMax0, absMax1); + const float32x4_t max23 = vmaxq_f32(absMax2, absMax3); + const float32x4_t max0123 = vmaxq_f32(max01, max23); + const float maxScalar = vmaxvq_f32(max0123); + + // Quantize these floats + const float scale = maxScalar / 127.f; + *scale_ptr = scale; + scale_ptr++; + + const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; + const float32x4_t mul = vdupq_n_f32(inverse_scale); + + I16VecType sum_8_i16_0 = PrepareZeroI16(); + I16VecType sum_8_i16_1 = PrepareZeroI16(); + + for (size_t kk = 0; kk < BlkLen; kk += 16) { + const float32x4_t vfp32_0 = LoadFloat32x4(A + k + kk, 4); + const float32x4_t vfp32_1 = LoadFloat32x4(A + k + kk + 4, 4); + const float32x4_t vfp32_2 = LoadFloat32x4(A + k + kk + 8, 4); + const float32x4_t vfp32_3 = LoadFloat32x4(A + k + kk + 12, 4); + + const float32x4_t v0 = vmulq_f32(vfp32_0, mul); + const float32x4_t v1 = vmulq_f32(vfp32_1, mul); + const float32x4_t v2 = vmulq_f32(vfp32_2, mul); + const float32x4_t v3 = vmulq_f32(vfp32_3, mul); + + const int32x4_t i0 = vcvtnq_s32_f32(v0); + const int32x4_t i1 = vcvtnq_s32_f32(v1); + const int32x4_t i2 = vcvtnq_s32_f32(v2); + const int32x4_t i3 = vcvtnq_s32_f32(v3); + + const int16x8_t v_8_i16_0 = vcombine_s16(vqmovn_s32(i0), vqmovn_s32(i1)); + const int16x8_t v_8_i16_1 = vcombine_s16(vqmovn_s32(i2), vqmovn_s32(i3)); + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t v_8_u16_0 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16_0, v128)); + const uint16x8_t v_8_u16_1 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16_1, v128)); + const uint8x16_t v_16_u8 = vcombine_u8(vqmovn_u16(v_8_u16_0), vqmovn_u16(v_8_u16_1)); + vst1q_u8(blob + k + kk, v_16_u8); + + // accumulate Sum(a_i) + const uint16x8_t i_8_u16_0 = vmovl_u8(vget_low_u8(v_16_u8)); + const uint16x8_t i_8_u16_1 = vmovl_high_u8(v_16_u8); + sum_8_i16_0 = vaddq_u16(sum_8_i16_0, i_8_u16_0); + sum_8_i16_1 = vaddq_u16(sum_8_i16_1, i_8_u16_1); + } else { + const int8x16_t v_16_i8 = vcombine_s8(vqmovn_s16(v_8_i16_0), vqmovn_s16(v_8_i16_1)); + vst1q_s8(blob + k + kk, v_16_i8); + + // accumulate Sum(a_i) + sum_8_i16_0 = vaddq_s16(sum_8_i16_0, v_8_i16_0); + sum_8_i16_1 = vaddq_s16(sum_8_i16_1, v_8_i16_1); + } + } + + float qsum; + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t sum_8_u16 = vaddq_u16(sum_8_i16_0, sum_8_i16_1); + qsum = static_cast(vaddvq_u16(sum_8_u16)); + } else { + const int16x8_t sum_8_i16 = vaddq_s16(sum_8_i16_0, sum_8_i16_1); + qsum = static_cast(vaddvq_s16(sum_8_i16)); + } + + *AScaledBlkSum = scale * qsum; + AScaledBlkSum++; + } + + if (k < CountK) { + float32x4_t absMax = vdupq_n_f32(0.0f); + + for (size_t kk = k; kk < CountK; kk += 4) { + size_t step = std::min(static_cast(4), CountK - kk); + const float32x4_t v0 = LoadFloat32x4(A + kk, step); + absMax = vmaxq_f32(absMax, vabsq_f32(v0)); + } + + const float maxScalar = vmaxvq_f32(absMax); + const float scale = maxScalar / 127.f; + *scale_ptr = scale; + + const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; + const float32x4_t mul = vdupq_n_f32(inverse_scale); + + I16VecType sum_8_i16 = PrepareZeroI16(); + + for (size_t kk = k; kk < CountK; kk += 4) { + size_t step = std::min(static_cast(4), CountK - kk); + const float32x4_t vfp32 = LoadFloat32x4(A + kk, step); + const float32x4_t v_f32 = vmulq_f32(vfp32, mul); + const int32x4_t v_i32 = vcvtnq_s32_f32(v_f32); + const int16x8_t v_8_i16 = vcombine_s16(vqmovn_s32(v_i32), vdup_n_s16(0)); + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t v_8_u16 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16, v128)); + uint8x8_t v_8_u8 = vqmovn_u16(v_8_u16); + vst1_lane_s32(reinterpret_cast(blob + kk), vreinterpret_s32_u8(v_8_u8), 0); + + // accumulate Sum(a_i) + v_8_u8 = vand_u8(v_8_u8, vld1_u8(MASK + 8 - step)); + const uint16x8_t i_8_u16 = vmovl_u8(v_8_u8); + sum_8_i16 = vaddq_u16(sum_8_i16, i_8_u16); + } else { + const int8x8_t v_8_i8 = vqmovn_s16(v_8_i16); + vst1_lane_s32(reinterpret_cast(blob + kk), vreinterpret_s32_s8(v_8_i8), 0); + + // accumulate Sum(a_i) + sum_8_i16 = vaddq_s16(sum_8_i16, v_8_i16); + } + } + + float qsum; + + if constexpr (IsQuantAUnsigned) { + qsum = static_cast(vaddvq_u16(sum_8_i16)); + } else { + qsum = static_cast(vaddvq_s16(sum_8_i16)); + } + + *AScaledBlkSum = scale * qsum; + + memset(blob + CountK, 0, BlkLen - (CountK % BlkLen)); + } +} + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + namespace { @@ -1439,6 +1663,723 @@ SQ4BitGemmKernel_CompInt8( return CountM; } +MLAS_FORCEINLINE void +Q8Int8GemmR2xC8DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountM % MRows2 == 0); + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + const float32x4_t scaleA1B03 = 
vmulq_n_f32(scaleB03, scaleA1); + const float32x4_t scaleA1B47 = vmulq_n_f32(scaleB47, scaleA1); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc0_47 = vdupq_n_u32(0U); + uint32x4_t acc1_03 = vdupq_n_u32(0U); + uint32x4_t acc1_47 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_0_47, av1_16_i8, 0); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_1_47, av1_16_i8, 1); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_2_47, av1_16_i8, 2); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_3_47, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_u32(acc0_47)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_u32(acc1_03)); + accf1_47 = vfmaq_f32(accf1_47, scaleA1B47, vcvtq_f32_u32(acc1_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32_03); + accf1_47 = vaddq_f32(accf1_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + vst1q_f32(SumPtr + ldc, accf1_03); + vst1q_f32(SumPtr + ldc + 4, accf1_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC8DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc0_47 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_u32(acc0_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountM % MRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB, scaleA1); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc1_03 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_u32(acc1_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + ldc, accf1_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC1DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + assert(CountM % MRows2 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + float32x4_t accf1 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + const float scaleA1B = scaleB * scaleA1; + + uint32x4_t acc0 = vdupq_n_u32(0U); + uint32x4_t acc1 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vdotq_u32(acc0, bv_packed, av0_16_i8); + acc1 = vdotq_u32(acc1, bv_packed, av1_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_u32(acc0), scaleA0B); + accf1 = vfmaq_n_f32(accf1, vcvtq_f32_u32(acc1), scaleA1B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + float32_t accf1v = vaddvq_f32(accf1); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + accf1v += bias; + } + + *SumPtr = accf0v; + *(SumPtr + ldc) = accf1v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 
1 : 0; + ++SumPtr; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + + uint32x4_t acc0 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vdotq_u32(acc0, bv_packed, av0_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_u32(acc0), scaleA0B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + } + + *SumPtr = accf0v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +template <> +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +) { + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols8 = 8; + constexpr size_t NCols4 = 4; + constexpr size_t MRows2 = 2; + const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + + const size_t lda = BlockCountK * BlkLen; + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % MRows2; + size_t multipleRows = CountM - remainingRows; + size_t multipleCols8 = CountN & (~(NCols8 - 1)); + size_t multipleCols4 = CountN & (~(NCols4 - 1)); + size_t remainingCols4 = CountN % NCols4; + + if (multipleRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR2xC8DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols8, + BlockCountK, + Bias, + ldc + ); + } + + if (multipleRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR2xC4DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleCols8, + multipleRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? 
Bias + multipleCols8 : nullptr, + ldc + ); + } + + if (multipleRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR2xC1DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleCols4, + multipleRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR1xC8DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols8, + BlockCountK, + Bias, + ldc); + } + + if (remainingRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR1xC4DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols8, + remainingRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc); + } + + if (remainingRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR1xC1DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols4, + remainingRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc); + } + + return CountM; +} + #ifdef USE_KLEIDIAI void SQ4BitGemmKernel_Packed_CompInt8( diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp new file mode 100644 index 0000000000000..db040dbb9a08c --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp @@ -0,0 +1,743 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon_int8_i8mm.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + input type T1 as float32 and + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompInt8 + using i8mm instructions. 
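    Editor's note (a sketch, not part of this change): these kernels keep the
    packed B operand unsigned (uint8_t) while the quantized A operand is signed
    (int8_t), so the inner products use the mixed-sign USDOT instructions
    (vusdotq_s32 / vusdotq_laneq_s32). Assuming that layout, one
    vusdotq_laneq_s32(acc, bv, av, lane) step is equivalent to the scalar loop

        for (int col = 0; col < 4; ++col)
            for (int j = 0; j < 4; ++j)
                acc[col] += int32_t(uint8_t(bv[col * 4 + j])) *
                            int32_t(int8_t(av[lane * 4 + j]));

    i.e. each 32-bit accumulator lane collects a 4-element dot product of four
    consecutive packed B bytes against the 4 A bytes selected by the lane index.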
+ +--*/ + +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC8I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB03, scaleA1); + const float32x4_t scaleA1B47 = vmulq_n_f32(scaleB47, scaleA1); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc0_47 = vdupq_n_s32(0); + int32x4_t acc1_03 = vdupq_n_s32(0); + int32x4_t acc1_47 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + acc1_47 = vusdotq_laneq_s32(acc1_47, 
bv_packed_0_47, av1_16_i8, 0); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_1_47, av1_16_i8, 1); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_2_47, av1_16_i8, 2); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_3_47, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_s32(acc0_47)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_s32(acc1_03)); + accf1_47 = vfmaq_f32(accf1_47, scaleA1B47, vcvtq_f32_s32(acc1_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32_03); + accf1_47 = vaddq_f32(accf1_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + vst1q_f32(SumPtr + ldc, accf1_03); + vst1q_f32(SumPtr + ldc + 4, accf1_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC8I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc0_47 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, 
av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_s32(acc0_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB, scaleA1); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc1_03 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_1_03, 
av1_16_i8, 1); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_s32(acc1_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + ldc, accf1_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + + int32x4_t acc0_03 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC1I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + assert(CountM % NRows2 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + float32x4_t accf1 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + const float scaleA1B = scaleB * scaleA1; + + int32x4_t acc0 = vdupq_n_s32(0); + int32x4_t acc1 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vusdotq_s32(acc0, bv_packed, av0_16_i8); + acc1 = vusdotq_s32(acc1, bv_packed, av1_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_s32(acc0), scaleA0B); + accf1 = vfmaq_n_f32(accf1, vcvtq_f32_s32(acc1), scaleA1B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + float32_t accf1v = vaddvq_f32(accf1); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + accf1v += bias; + } + + *SumPtr = accf0v; + *(SumPtr + ldc) = accf1v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 
1 : 0; + ++SumPtr; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + + int32x4_t acc0 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vusdotq_s32(acc0, bv_packed, av0_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_s32(acc0), scaleA0B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + } + + *SumPtr = accf0v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +template <> +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +) { + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols8 = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + + const size_t lda = BlockCountK * BlkLen; + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t multipleCols8 = CountN & (~(NCols8 - 1)); + size_t multipleCols4 = CountN & (~(NCols4 - 1)); + size_t remainingCols4 = CountN % NCols4; + + if (multipleRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR2xC8I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols8, + BlockCountK, + Bias, + ldc + ); + } + + if (multipleRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR2xC4I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleCols8, + multipleRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? 
Bias + multipleCols8 : nullptr, + ldc + ); + } + + if (multipleRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR2xC1I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleCols4, + multipleRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR1xC8I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols8, + BlockCountK, + Bias, + ldc); + } + + if (remainingRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR1xC4I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols8, + remainingRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc); + } + + if (remainingRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR1xC1I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols4, + remainingRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc); + } + + return CountM; +} + +} // namespace sqnbitgemm_neon diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 51a8b13cd8261..3b361f155831b 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -586,7 +586,7 @@ struct CudaSyncNotificationImpl : OrtSyncNotificationImpl { Release = ReleaseImpl; } - cudaStream_t& stream_; + cudaStream_t stream_; cudaEvent_t event_; const OrtApi& ort_api; @@ -632,9 +632,9 @@ struct CudaSyncStreamImpl : OrtSyncStreamImpl { *notification_impl = nullptr; std::unique_ptr notification; - cudaStream_t* cuda_stream = static_cast(impl.stream_.GetHandle()); + cudaStream_t cuda_stream = static_cast(impl.stream_.GetHandle()); - RETURN_IF_ERROR(CudaSyncNotificationImpl::Create(*cuda_stream, impl.ort_api, notification)); + RETURN_IF_ERROR(CudaSyncNotificationImpl::Create(cuda_stream, impl.ort_api, notification)); *notification_impl = notification.release(); return nullptr; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 93b673f2df5bd..0ee18cc6799fc 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -3,6 +3,7 @@ // Licensed under the MIT License. 
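Editor's note on the nv_execution_provider.cc hunks that follow: they add data-type
gating (IsSupportedDataType / GetDataTypeName / CheckNodeDataTypes) so that
GetCapability can exclude nodes whose tensor element types TensorRT RTX cannot handle,
extend the tensor binding macros with INT4 and FLOAT8E4M3FN, and introduce an optional
runtime cache directory. A minimal behavioural sketch of the new type gate (editorial
only, assuming <cassert>; the helper itself is defined in the hunk below):

    assert(IsSupportedDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));        // kFLOAT passes
    assert(IsSupportedDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)); // FP8 passes
    assert(!IsSupportedDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE));      // DOUBLE is rejected
    assert(!IsSupportedDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING));      // STRING is rejected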
#include #include +#include #include #include "core/providers/shared_library/provider_api.h" #include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" @@ -84,6 +85,108 @@ struct ShutdownProtobuf { namespace onnxruntime { +// Helper function to check if a data type is supported by NvTensorRTRTX EP +static bool IsSupportedDataType(ONNXTensorElementDataType data_type) { + switch (data_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: // kFLOAT - 32-bit floating point + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: // kHALF - IEEE 16-bit floating-point + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // kBF16 - Brain float 16 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: // kBOOL - 8-bit boolean + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: // kINT4 - 4-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: // kINT8 - 8-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // kUINT8 - 8-bit unsigned integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: // kINT32 - 32-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: // kINT64 - 64-bit signed integer + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: // kFP8 - 8-bit floating point + return true; + default: + return false; + } +} + +// Helper function to get data type name as string +static std::string GetDataTypeName(ONNXTensorElementDataType data_type) { + switch (data_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return "FLOAT"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return "FLOAT16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return "BFLOAT16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return "BOOL"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + return "INT4"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return "INT8"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return "UINT8"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return "INT32"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return "INT64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: + return "FLOAT8E4M3FN"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return "DOUBLE"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + return "STRING"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + return "UINT16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return "UINT32"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return "UINT64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + return "INT16"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: + return "COMPLEX64"; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: + return "COMPLEX128"; + default: + return "UNKNOWN(" + std::to_string(static_cast(data_type)) + ")"; + } +} + +// Helper function to check if a node has supported data types +static bool CheckNodeDataTypes(const Node* node) { + // Check input data types + for (const auto* input_def : node->InputDefs()) { + if (input_def->Exists()) { + const auto* type_proto = input_def->TypeAsProto(); + if (type_proto && type_proto->has_tensor_type()) { + auto data_type = static_cast(type_proto->tensor_type().elem_type()); + if (!IsSupportedDataType(data_type)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Node '" << node->Name() + << "' (OpType: " << node->OpType() + << ") has unsupported input data type: " << GetDataTypeName(data_type) + << " for input '" << input_def->Name() << "'"; + return false; + } + } + } + } + + // Check output data types + for (const auto* output_def : node->OutputDefs()) { + if (output_def->Exists()) { + const auto* type_proto = output_def->TypeAsProto(); + if (type_proto && type_proto->has_tensor_type()) { + auto 
data_type = static_cast(type_proto->tensor_type().elem_type()); + if (!IsSupportedDataType(data_type)) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Node '" << node->Name() + << "' (OpType: " << node->OpType() + << ") has unsupported output data type: " << GetDataTypeName(data_type) + << " for output '" << output_def->Name() << "'"; + return false; + } + } + } + } + + return true; +} + void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr @@ -477,10 +580,12 @@ Status BindContextInput(Ort::KernelContext& ctx, CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); @@ -561,10 +666,12 @@ Status BindContextOutput(Ort::KernelContext& ctx, CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); @@ -623,10 +730,12 @@ Status BindKernelOutput(Ort::KernelContext& ctx, CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, uint16_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, uint8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, uint8_t) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); @@ -654,9 +763,9 @@ void NvExecutionProvider::PerThreadContext::ResetTensorRTContext(std::string fus } } -bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, std::unique_ptr context) { +bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, tensorrt_ptr::unique_pointer_exec_ctx context) { if 
(!context) { - context = std::make_unique(); + context = tensorrt_ptr::unique_pointer_exec_ctx(); } trt_context_map_[fused_node] = std::move(context); @@ -757,11 +866,11 @@ bool NvExecutionProvider::PerThreadContext::IsTensorRTContextInMap(std::string f nvinfer1::IExecutionContext& NvExecutionProvider::PerThreadContext::GetTensorRTContext(std::string fused_node) { auto it = trt_context_map_.find(fused_node); if (it != trt_context_map_.end()) { - return *(it->second); // dereference shared pointer + return *(it->second.get()); // dereference shared pointer } - auto context = std::make_unique(); + auto context = tensorrt_ptr::unique_pointer_exec_ctx(); trt_context_map_[fused_node] = std::move(context); - return *(trt_context_map_[fused_node]); // dereference shared pointer + return *(trt_context_map_[fused_node].get()); // dereference shared pointer } void NvExecutionProvider::ReleasePerThreadContext() const { @@ -870,6 +979,20 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) max_shared_mem_size_ = info.max_shared_mem_size; dump_subgraphs_ = info.dump_subgraphs; weight_stripped_engine_enable_ = info.weight_stripped_engine_enable; + // make runtime cache path absolute and create directory if it doesn't exist + if (!info.runtime_cache_path.empty()) { + std::filesystem::path p(info.runtime_cache_path); + std::filesystem::path abs_path = std::filesystem::absolute(p); + const auto& env = GetDefaultEnv(); + auto status = env.CreateFolder(abs_path.string()); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The runtime cache directory could not be created at: " << abs_path + << ". Runtime cache is disabled."; + } else { + runtime_cache_ = abs_path; + } + } + onnx_model_folder_path_ = info.onnx_model_folder_path; onnx_model_bytestream_ = info.onnx_bytestream; onnx_model_bytestream_size_ = info.onnx_bytestream_size; @@ -1053,7 +1176,13 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) << ", nv_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_ << ", nv_onnx_external_bytestream_size_: " << onnx_external_data_bytestream_size_ << ", nv_use_external_data_initializer_: " << use_external_data_initializer_ - << ", nv_op_types_to_exclude: " << op_types_to_exclude_; + << ", nv_op_types_to_exclude: " << op_types_to_exclude_ + << ", nv_runtime_cache_path: " << runtime_cache_; +} + +Status NvExecutionProvider::Sync() const { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + return Status::OK(); } NvExecutionProvider::~NvExecutionProvider() { @@ -1574,8 +1703,8 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t // the initializer was marked as external data by the ORT graph at load time since it was provided in memory size_t size = 0; const void* ptr = nullptr; - c_api.GetTensorSizeInBytes(&initializer_value, &size); - c_api.GetTensorData(&initializer_value, &ptr); + Ort::ThrowOnError(c_api.GetTensorSizeInBytes(&initializer_value, &size)); + Ort::ThrowOnError(c_api.GetTensorData(&initializer_value, &ptr)); userWeights.emplace_back(tp->name(), ptr, size); } else if (utils::HasExternalDataInMemory(*tp)) { // only copy and take ownership of the data if none of the above conditions are met @@ -1857,6 +1986,7 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, /* Iterate all the nodes and exclude the node if: * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. * 2. It's a DDS op. + * 3. It has unsupported data types. 
*/ for (const auto& index : nodes_vector) { const auto& node = graph.GetNode(node_index[index]); @@ -1896,6 +2026,16 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, supported_node = false; } + // Check data types and print warnings for unsupported types + if (supported_node) { + if (!CheckNodeDataTypes(node)) { + supported_node = false; + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Node '" << node->Name() + << "' (OpType: " << node->OpType() + << ") excluded due to unsupported data types"; + } + } + if (supported_node) { if (new_subgraph) { parser_nodes_vector.emplace_back(); @@ -2394,8 +2534,8 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // the initializer was marked as external data by the ORT graph at load time since it was provided in memory size_t size = 0; const void* ptr = nullptr; - c_api.GetTensorSizeInBytes(&initializer_value, &size); - c_api.GetTensorData(&initializer_value, &ptr); + Ort::ThrowOnError(c_api.GetTensorSizeInBytes(&initializer_value, &size)); + Ort::ThrowOnError(c_api.GetTensorData(&initializer_value, &ptr)); userWeights.emplace_back(tp->name(), ptr, size); } else if (utils::HasExternalDataInMemory(*tp)) { // only copy and take ownership of the data if none of the above conditions are met @@ -2631,8 +2771,10 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // // Otherwise engine will be handled at inference time. std::unique_ptr trt_engine; - std::unique_ptr trt_context; + tensorrt_ptr::unique_pointer_exec_ctx trt_context; + std::unique_ptr trt_runtime_cache; std::unique_ptr trt_runtime_config; + std::string runtime_cache_file = ""; // Generate file name for dumping ep context model if (dump_ep_context_model_ && ctx_model_path_.empty()) { @@ -2661,6 +2803,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER); } trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED); + if (!runtime_cache_.empty()) { + runtime_cache_file = (runtime_cache_ / fused_node.Name()).string(); + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + auto cache_data = file_utils::ReadFile(runtime_cache_file); + if (!trt_runtime_cache->deserialize(cache_data.data(), cache_data.size())) { + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to deserialize the runtime cache, will overwrite with new one" << std::endl; + } + if (!trt_runtime_config->setRuntimeCache(*trt_runtime_cache)) { + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to set the runtime cache" << std::endl; + } + } if (detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); @@ -2721,7 +2875,9 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // Build context // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_context = std::unique_ptr(trt_engine->createExecutionContext(trt_runtime_config.get())); + trt_context = tensorrt_ptr::unique_pointer_exec_ctx( + trt_engine->createExecutionContext(trt_runtime_config.get()), + tensorrt_ptr::IExecutionContextDeleter(runtime_cache_file, std::move(trt_runtime_cache))); if (!trt_context) { return 
ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); @@ -3002,7 +3158,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra std::unordered_map& output_map, std::vector& node_compute_funcs) { std::unique_ptr trt_engine; - std::unique_ptr trt_context; + tensorrt_ptr::unique_pointer_exec_ctx trt_context; std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index std::unordered_map output_types; // TRT engine output name -> ORT output tensor type @@ -3024,11 +3180,33 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } + std::unique_ptr trt_runtime_cache; + auto trt_runtime_config = std::unique_ptr(trt_engine->createRuntimeConfig()); + if (trt_runtime_config && cuda_graph_enable_) { + trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER); + } + trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED); + std::string runtime_cache_file = ""; + if (!runtime_cache_.empty()) { + runtime_cache_file = (runtime_cache_ / graph_body_viewer.GetNode(node_idx)->Name()).string(); + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + auto cache_data = file_utils::ReadFile(runtime_cache_file); + if (!trt_runtime_cache->deserialize(cache_data.data(), cache_data.size())) { + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to deserialize the runtime cache, will overwrite with new one" << std::endl; + } + if (!trt_runtime_config->setRuntimeCache(*trt_runtime_cache)) { + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to set the runtime cache" << std::endl; + } + } + // Build context // // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); + trt_context = tensorrt_ptr::unique_pointer_exec_ctx( + trt_engine->createExecutionContext(trt_runtime_config.get()), + tensorrt_ptr::IExecutionContextDeleter(runtime_cache_file, std::move(trt_runtime_cache))); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 9e5fd03756f02..bb8f687db094f 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -16,6 +16,7 @@ typedef void* cudnnStatus_t; #include #include "core/providers/cuda/cuda_graph.h" #include "nv_execution_provider_info.h" +#include "core/providers/nv_tensorrt_rtx/nv_file_utils.h" namespace onnxruntime { @@ -58,6 +59,26 @@ class TensorrtLogger : public nvinfer1::ILogger { }; namespace tensorrt_ptr { +/* + * custom deleter that will dump the optimized runtime cache when the execution context is destructed + */ +struct IExecutionContextDeleter { + 
IExecutionContextDeleter() = default; + IExecutionContextDeleter(const std::string& runtime_cache_path, std::unique_ptr&& runtime_cache) : runtime_cache_path_(runtime_cache_path), runtime_cache_(std::move(runtime_cache)) {}; + void operator()(nvinfer1::IExecutionContext* context) { + if (context != nullptr) { + if (!runtime_cache_path_.empty()) { + auto serialized_cache_data = std::unique_ptr(runtime_cache_->serialize()); + file_utils::WriteFile(runtime_cache_path_, serialized_cache_data->data(), serialized_cache_data->size()); + } + delete context; + } + } + + private: + std::string runtime_cache_path_; + std::unique_ptr runtime_cache_; +}; struct TensorrtInferDeleter { template @@ -70,6 +91,7 @@ struct TensorrtInferDeleter { template using unique_pointer = std::unique_ptr; +using unique_pointer_exec_ctx = std::unique_ptr; }; // namespace tensorrt_ptr // @@ -196,7 +218,7 @@ struct TensorrtFuncState { std::string fused_node_name; nvinfer1::IBuilder* builder; std::unique_ptr* engine = nullptr; - std::unique_ptr* context = nullptr; + tensorrt_ptr::unique_pointer_exec_ctx* context = nullptr; std::unique_ptr* network = nullptr; std::vector> input_info; std::vector> output_info; @@ -233,7 +255,7 @@ struct TensorrtShortFuncState { AllocatorHandle allocator = nullptr; std::string fused_node_name; std::unique_ptr* engine = nullptr; - std::unique_ptr* context = nullptr; + tensorrt_ptr::unique_pointer_exec_ctx* context = nullptr; std::vector> input_info; std::vector> output_info; std::mutex* tensorrt_mu_ptr = nullptr; @@ -285,6 +307,7 @@ class NvExecutionProvider : public IExecutionProvider { IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return device_id_; } + Status Sync() const; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; @@ -356,6 +379,7 @@ class NvExecutionProvider : public IExecutionProvider { bool detailed_build_log_ = false; bool cuda_graph_enable_ = false; bool multi_profile_enable_ = false; + std::filesystem::path runtime_cache_; std::string cache_prefix_; std::string op_types_to_exclude_; int nv_profile_index_ = 0; @@ -386,7 +410,7 @@ class NvExecutionProvider : public IExecutionProvider { // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. 
std::unordered_map> engines_; - std::unordered_map> contexts_; + std::unordered_map contexts_; std::unordered_map> builders_; std::unordered_map> networks_; std::unordered_map>> input_info_; @@ -424,7 +448,7 @@ class NvExecutionProvider : public IExecutionProvider { bool IsTensorRTContextInMap(std::string fused_node); nvinfer1::IExecutionContext& GetTensorRTContext(std::string fused_node); - bool UpdateTensorRTContext(std::string fused_node, std::unique_ptr context); + bool UpdateTensorRTContext(std::string fused_node, tensorrt_ptr::unique_pointer_exec_ctx context); void ResetTensorRTContext(std::string fused_node); // CUDA Graph management @@ -454,7 +478,7 @@ class NvExecutionProvider : public IExecutionProvider { // See more details here: // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#a63cd95430852038ce864e17c670e0b36 - std::unordered_map> trt_context_map_; + std::unordered_map trt_context_map_; // The profile shape ranges for the engine that the execution context maintained by the PerThreadContext is built with. // TRT EP needs this info to determine whether to rebuild the execution context. diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index 527a37f6c2b57..f25718114891b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -51,6 +51,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi .AddAssignmentToReference(nv::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) .AddAssignmentToReference(nv::provider_option_names::kUseExternalDataInitializer, info.use_external_data_initializer) .AddAssignmentToReference(nv::provider_option_names::kMultiProfileEnable, info.multi_profile_enable) + .AddAssignmentToReference(nv::provider_option_names::kRuntimeCacheFile, info.runtime_cache_path) .Parse(options)); // add new provider option here. 
info.user_compute_stream = user_compute_stream; @@ -105,7 +106,8 @@ ProviderOptions NvExecutionProviderInfo::ToProviderOptions(const NvExecutionProv {nv::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, {nv::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, {nv::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, - {nv::provider_option_names::kUseExternalDataInitializer, MakeStringWithClassicLocale(info.use_external_data_initializer)}}; + {nv::provider_option_names::kUseExternalDataInitializer, MakeStringWithClassicLocale(info.use_external_data_initializer)}, + {nv::provider_option_names::kRuntimeCacheFile, MakeStringWithClassicLocale(info.runtime_cache_path)}}; return options; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h index b826925361b05..372e8196f38c2 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h @@ -37,7 +37,7 @@ struct NvExecutionProviderInfo { bool engine_decryption_enable{false}; std::string engine_decryption_lib_path{""}; bool force_sequential_engine_build{false}; - std::string timing_cache_path{""}; + std::string runtime_cache_path{""}; bool detailed_build_log{false}; bool sparsity_enable{false}; int auxiliary_streams{-1}; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h new file mode 100644 index 0000000000000..159aba0507ffb --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h @@ -0,0 +1,52 @@ +#pragma once +#include +#include +#include +#include +#include +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +namespace file_utils { + +inline std::vector ReadFile(const std::string& path) { + if (!std::filesystem::exists(path)) { + LOGS_DEFAULT(INFO) << "TensorRT RTX could not find the file and will create a new one " << path << std::endl; + return {}; + } + std::ifstream file(path, std::ios::in | std::ios::binary); + if (!file) { + ORT_THROW("Failed to open file: " + path); + } + file.seekg(0, std::ios::end); + std::streamsize size = file.tellg(); + file.seekg(0, std::ios::beg); + std::vector buffer(size); + if (size > 0 && !file.read(buffer.data(), size)) { + ORT_THROW("Failed to read file: " + path); + } + return buffer; +} + +inline void WriteFile(const std::string& path, const void* data, size_t size) { + if (std::filesystem::exists(path)) { + std::ofstream file(path, std::ios::out | std::ios::binary | std::ios::trunc); + if (!file) { + ORT_THROW("Failed to open file for writing: " + path); + } + file.write(static_cast(data), size); + } else { + LOGS_DEFAULT(INFO) << "TensorRT RTX a new file cache was written to " << path << std::endl; + // Create new file + std::ofstream file(path, std::ios::out | std::ios::binary); + if (!file) { + ORT_THROW("Failed to create file: " + path); + } + file.write(static_cast(data), size); + } +} + +inline void WriteFile(const std::string& path, const std::vector& data) { WriteFile(path, data.data(), data.size()); } + +} // namespace file_utils +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc 
b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index d23d50549b2c5..c3fbccef84883 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -431,7 +431,7 @@ struct NvTrtRtxSyncNotificationImpl : OrtSyncNotificationImpl { Release = ReleaseImpl; } - cudaStream_t& stream_; + cudaStream_t stream_; cudaEvent_t event_; const OrtApi& ort_api; @@ -477,9 +477,9 @@ struct NvTrtRtxSyncStreamImpl : OrtSyncStreamImpl { *notification_impl = nullptr; std::unique_ptr notification; - cudaStream_t* cuda_stream = static_cast(impl.stream_.GetHandle()); + cudaStream_t cuda_stream = static_cast(impl.stream_.GetHandle()); - RETURN_IF_ERROR(NvTrtRtxSyncNotificationImpl::Create(*cuda_stream, impl.ort_api, notification)); + RETURN_IF_ERROR(NvTrtRtxSyncNotificationImpl::Create(cuda_stream, impl.ort_api, notification)); *notification_impl = notification.release(); return nullptr; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 1f34a0f25877d..c1626fa4f36ad 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -311,13 +311,19 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const Node& node) { ". Please make sure engine cache is in the same directory or sub-directory of context model."); } - std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in); - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + size_t file_length = 0; + auto path_str = ToPathString(engine_cache_path.string()); + + Env::MappedMemoryPtr engine_buf; + const auto& env = GetDefaultEnv(); + ORT_RETURN_IF_ERROR(env.GetFileLength(path_str.c_str(), file_length)); + if (!file_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "Nv EP could not read engine from cache: " + engine_cache_path.string()); + } + ORT_RETURN_IF_ERROR(env.MapFileIntoMemory(path_str.c_str(), 0, file_length, engine_buf)); + + *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), file_length)); if (!(*trt_engine_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP could not deserialize engine from cache: " + engine_cache_path.string()); diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 59b0992d827e1..b9a54ea7104e1 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -64,7 +64,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModelPath, API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); - std::string model_path = PathToUTF8String(input_model_path); + std::filesystem::path model_path = input_model_path; if (model_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid input model: path string is empty"); @@ -113,7 +113,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath, #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = 
reinterpret_cast(ort_model_compile_options); - std::string model_path = PathToUTF8String(output_model_path); + std::filesystem::path model_path = output_model_path; if (model_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output model path: path is empty"); } @@ -136,17 +136,18 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInf #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); - std::string output_dir = PathToUTF8String(output_directory); - if (output_dir.empty()) { + std::filesystem::path output_directory_path = output_directory; + if (output_directory_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output directory: path is empty"); } - std::string model_name_str = ToUTF8String(model_name); - if (model_name_str.empty()) { + std::filesystem::path model_name_path = model_name; + if (model_name_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid model name: string is empty"); } - ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_dir, model_name_str)); + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_directory_path, + model_name_path)); return nullptr; #else ORT_UNUSED_PARAMETER(ort_model_compile_options); @@ -163,7 +164,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExterna size_t external_initializer_size_threshold) { API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) - std::string initializers_file_path = PathToUTF8String(external_initializers_file_path); + std::filesystem::path initializers_file_path = external_initializers_file_path; if (initializers_file_path.empty()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid external initializer file: path is empty"); } @@ -214,6 +215,50 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ OrtWriteBufferFunc write_func, _In_ void* state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (write_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtWriteBufferFunc function for output model is null"); + } + + model_compile_options->SetOutputModelWriteFunc(write_func, state); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(write_func); + ORT_UNUSED_PARAMETER(state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (get_initializer_location_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "OrtGetInitializerLocationFunc function for output model is null"); + } + + model_compile_options->SetOutputModelGetInitializerLocationFunc(get_initializer_location_func, state); + return nullptr; +#else + 
ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(get_initializer_location_func); + ORT_UNUSED_PARAMETER(state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* ort_model_compile_options, bool embed_ep_context_in_model) { @@ -231,7 +276,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode } ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags, - _In_ OrtModelCompilationOptions* ort_model_compile_options, size_t flags) { + _In_ OrtModelCompilationOptions* ort_model_compile_options, uint32_t flags) { API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); @@ -245,6 +290,22 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ GraphOptimizationLevel graph_optimization_level) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetGraphOptimizationLevel(graph_optimization_level)); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(graph_optimization_level); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* ort_model_compile_options) { API_IMPL_BEGIN @@ -278,6 +339,9 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetFlags, &OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, + &OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel, + &OrtCompileAPI::ModelCompilationOptions_SetOutputModelWriteFunc, + &OrtCompileAPI::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 93cc5dbf20fce..34fa06340a7f9 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -29,8 +29,17 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel bool embed_ep_context_in_model); ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options, - size_t flags); + uint32_t flags); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options, _In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetGraphOptimizationLevel, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ GraphOptimizationLevel graph_optimization_level); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelWriteFunc, + _In_ OrtModelCompilationOptions* 
model_compile_options, + _In_ OrtWriteBufferFunc write_func, _In_ void* state); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index bbb110033f54c..84f41771cb62b 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -7,8 +7,11 @@ #include #include #include +#include +#include "core/common/path_string.h" #include "core/framework/allocator.h" +#include "core/framework/ep_context_options.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/environment.h" @@ -22,14 +25,16 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment& // defaulting to kGenerateModel to support wider usage. session_options_.value.ep_context_gen_options.action_if_no_compiled_nodes = - EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel; + epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel; // Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions. ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK()); ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionsDisableModelCompile, "0").IsOK()); + + session_options_.value.graph_optimization_level = TransformerLevel::Default; // L0: required transformers only } -void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) { +void ModelCompilationOptions::SetInputModelPath(const std::filesystem::path& input_model_path) { ResetInputModelSettings(); input_model_path_ = input_model_path; } @@ -40,17 +45,16 @@ void ModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_da input_model_data_size_ = input_model_data_size; } -Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_model_path) { - ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); - +Status ModelCompilationOptions::SetOutputModelPath(const std::filesystem::path& output_model_path) { ConfigOptions& config_options = session_options_.value.config_options; - EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - ep_context_gen_options.output_model_file_path = output_model_path; + ep_context_gen_options.output_model_location = output_model_path; - if (ep_context_gen_options.output_model_file_path.size() <= ConfigOptions::kMaxValueLength) { - Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, - ep_context_gen_options.output_model_file_path.c_str()); + std::string output_model_path_str = PathToUTF8String(output_model_path); + + if (output_model_path_str.size() <= ConfigOptions::kMaxValueLength) { + Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, output_model_path_str.c_str()); ORT_ENFORCE(status.IsOK()); // Should not fail because both key/value strings are below the min string lengths // required by ConfigOptions::AddConfigEntry(). 
} else { @@ -71,7 +75,7 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod logging::LoggingManager* log_manager = env_.GetLoggingManager(); if (log_manager != nullptr && log_manager->HasDefaultLogger()) { const logging::Logger& logger = log_manager->DefaultLogger(); - LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size() + LOGS(logger, WARNING) << "Output model path length (" << output_model_path_str.size() << ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters." << "ORT will still generate the expected output file, but EPs will see an empty " << "output model path in SessionOption's ConfigOptions."; @@ -80,40 +84,58 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod return Status::OK(); } -void ModelCompilationOptions::SetOutputModelExternalInitializersFile(const std::string& external_initializers_path, - size_t external_initializer_size_threshold) { - session_options_.value.ep_context_gen_options.output_external_initializers_file_path = external_initializers_path; - session_options_.value.ep_context_gen_options.output_external_initializer_size_threshold = - external_initializer_size_threshold; +void ModelCompilationOptions::SetOutputModelExternalInitializersFile( + const std::filesystem::path& external_initializers_path, + size_t external_initializer_size_threshold) { + session_options_.value.ep_context_gen_options.initializers_location = epctx::ExternalInitializerFileInfo{ + external_initializers_path, + external_initializer_size_threshold, + }; } Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) { - ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); + session_options_.value.ep_context_gen_options.output_model_location = epctx::BufferHolder{ + output_model_buffer_ptr, + output_model_buffer_size_ptr, + std::move(allocator), + }; - session_options_.value.ep_context_gen_options.output_model_buffer_ptr = output_model_buffer_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_size_ptr = output_model_buffer_size_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_allocator = std::move(allocator); return Status::OK(); } -Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::string& output_directory, - const std::string& model_name) { +void ModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state) { + session_options_.value.ep_context_gen_options.output_model_location = epctx::BufferWriteFuncHolder{ + write_func, + state, + }; +} + +void ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc( + OrtGetInitializerLocationFunc get_initializer_location_func, void* state) { + session_options_.value.ep_context_gen_options.initializers_location = epctx::InitializerHandler{ + get_initializer_location_func, + state, + }; +} + +Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::filesystem::path& output_directory, + const std::filesystem::path& model_name) { if (output_directory.empty() || model_name.empty()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir or model_name is empty."); } - std::filesystem::path output_dir_path(output_directory); - if (output_dir_path.has_filename() && output_dir_path.extension() == "") { + if (output_directory.has_filename() && output_directory.extension() == "") { return 
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir is not a valid directory."); } - std::filesystem::path ctx_model_path = output_directory / std::filesystem::path(model_name); + std::filesystem::path ctx_model_path = output_directory / model_name; + std::string ctx_model_path_str = PathToUTF8String(ctx_model_path); - if (ctx_model_path.string().size() <= ConfigOptions::kMaxValueLength) { + if (ctx_model_path_str.size() <= ConfigOptions::kMaxValueLength) { ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, - ctx_model_path.string().c_str())); + ctx_model_path_str.c_str())); } else { logging::LoggingManager* log_manager = env_.GetLoggingManager(); if (log_manager != nullptr && log_manager->HasDefaultLogger()) { @@ -135,12 +157,12 @@ Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_m return Status::OK(); } -Status ModelCompilationOptions::SetFlags(size_t flags) { - EpContextModelGenerationOptions& options = session_options_.value.ep_context_gen_options; +Status ModelCompilationOptions::SetFlags(uint32_t flags) { + epctx::ModelGenOptions& options = session_options_.value.ep_context_gen_options; options.error_if_output_file_exists = flags & OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS; options.action_if_no_compiled_nodes = - (flags & OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) ? EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kReturnError - : EpContextModelGenerationOptions::ActionIfNoCompiledNodes::kGenerateModel; + (flags & OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) ? epctx::ModelGenOptions::ActionIfNoCompiledNodes::kReturnError + : epctx::ModelGenOptions::ActionIfNoCompiledNodes::kGenerateModel; return Status::OK(); } @@ -152,7 +174,7 @@ bool ModelCompilationOptions::InputModelComesFromFile() const { return !input_model_path_.empty(); } -const std::string& ModelCompilationOptions::GetInputModelPath() const { +const std::filesystem::path& ModelCompilationOptions::GetInputModelPath() const { return input_model_path_; } @@ -170,77 +192,106 @@ void ModelCompilationOptions::ResetInputModelSettings() { input_model_data_size_ = 0; } -Status ModelCompilationOptions::ResetOutputModelSettings() { - EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - ep_context_gen_options.output_model_file_path.clear(); - ep_context_gen_options.output_model_buffer_ptr = nullptr; - ep_context_gen_options.output_model_buffer_size_ptr = nullptr; - ep_context_gen_options.output_model_buffer_allocator = nullptr; +Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { + switch (graph_optimization_level) { + case ORT_DISABLE_ALL: + // TransformerLevel::Default means that we only run required transformers. 
+ session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Default; + break; + case ORT_ENABLE_BASIC: + session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level1; + break; + case ORT_ENABLE_EXTENDED: + session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level2; + break; + case ORT_ENABLE_LAYOUT: + session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level3; + break; + case ORT_ENABLE_ALL: + session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::MaxLevel; + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "graph_optimization_level with value ", + static_cast(graph_optimization_level), " is invalid. Valid values are: ", + "ORT_DISABLE_ALL (0), ORT_ENABLE_BASIC (1), ORT_ENABLE_EXTENDED (2), ", + "ORT_ENABLE_LAYOUT (3), and ORT_ENABLE_ALL (99)."); + } + return Status::OK(); } -Status ModelCompilationOptions::CheckInputModelSettings() const { - const bool comes_from_file = !input_model_path_.empty(); - const bool comes_from_memory = input_model_data_ != nullptr; +Status ModelCompilationOptions::Check() const { + const ConfigOptions& config_options = session_options_.value.config_options; + + ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable); + ORT_ENFORCE(config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0"); - if (!comes_from_file && !comes_from_memory) { + // Check input model settings. + const bool input_from_file = !input_model_path_.empty(); + const bool input_from_memory = input_model_data_ != nullptr; + + if (!input_from_file && !input_from_memory) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model to compile must be loaded from either a file or a memory buffer"); } - if (comes_from_file && comes_from_memory) { + if (input_from_file && input_from_memory) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model to compile must be loaded from either a file or a memory buffer, ", "but not both."); } - if (comes_from_file && !std::filesystem::exists(input_model_path_)) { + if (input_from_file && !std::filesystem::exists(input_model_path_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model path does not exist: ", input_model_path_); } - if (comes_from_memory && input_model_data_size_ == 0) { + if (input_from_memory && input_model_data_size_ == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0"); } - return Status::OK(); -} + // Check output model settings. + const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + bool has_no_output_model_location = std::holds_alternative( + ep_context_gen_options.output_model_location); -Status ModelCompilationOptions::CheckOutputModelSettings() const { - const EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - - const bool explicit_writes_to_file = !ep_context_gen_options.output_model_file_path.empty(); - const bool writes_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr; - - if (!explicit_writes_to_file && !writes_to_buffer) { - // User did not specify an output file or an output buffer. We default to generating an output file - // with a name based on the input file name, so do not return an error. 
+ if (has_no_output_model_location && input_from_file) { + // User did not specify an output file, an output buffer, or an output write function. We default to generating an + // output file with a name based on the input file name, so do not return an error. return Status::OK(); } - if (explicit_writes_to_file && writes_to_buffer) { + if (has_no_output_model_location) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Output model to compile must be saved either to a file or to a buffer, but not both."); + "Unable to generate an output model path: require an input model path if the location " + "of the output model (e.g., file, buffer, or stream) is not specified."); } - if (writes_to_buffer && ep_context_gen_options.output_model_buffer_size_ptr == nullptr) { + const epctx::BufferHolder* output_buffer_ptr = ep_context_gen_options.TryGetOutputModelBuffer(); + + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_ptr == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid buffer configuration for output model: buffer pointer is null"); + } + + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_size_ptr == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: size pointer is null"); } - if (writes_to_buffer && ep_context_gen_options.output_model_buffer_allocator == nullptr) { + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_allocator == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: allocator is null"); } - return Status::OK(); -} + const epctx::BufferWriteFuncHolder* output_write_func_holder = ep_context_gen_options.TryGetOutputModelWriteFunc(); + + if (output_write_func_holder != nullptr && output_write_func_holder->write_func == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid buffer writing function for output model: function pointer is null"); + } -Status ModelCompilationOptions::Check() const { - ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable); - ORT_ENFORCE(session_options_.value.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0"); - ORT_RETURN_IF_ERROR(CheckInputModelSettings()); - ORT_RETURN_IF_ERROR(CheckOutputModelSettings()); return Status::OK(); } + } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 2824df863013d..45323e6cb13c5 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -4,6 +4,7 @@ #if !defined(ORT_MINIMAL_BUILD) #pragma once +#include #include #include #include "core/common/status.h" @@ -34,7 +35,7 @@ class ModelCompilationOptions { /// Overrides any previous call to SetInputModelPath() or SetInputModelFromBuffer(). /// /// The input model's path - void SetInputModelPath(const std::string& input_model_path); + void SetInputModelPath(const std::filesystem::path& input_model_path); /// /// Sets the buffer that stores the input ONNX model to compile. 
@@ -50,7 +51,7 @@ class ModelCompilationOptions { /// /// /// Status indicating potential error - Status SetOutputModelPath(const std::string& output_model_path); + Status SetOutputModelPath(const std::filesystem::path& output_model_path); /// /// Sets the file path to the file that will store external ONNX initializers for the compiled model. @@ -58,7 +59,7 @@ class ModelCompilationOptions { /// /// Path to the external initializers file to generate /// Initializers that exceed this threshold are external - void SetOutputModelExternalInitializersFile(const std::string& external_initializers_path, + void SetOutputModelExternalInitializersFile(const std::filesystem::path& external_initializers_path, size_t external_initializer_size_threshold); /// @@ -72,6 +73,21 @@ class ModelCompilationOptions { Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); + /// + /// Sets an output stream (write function + state) used to write out the compiled model bytes. + /// + /// Write function + /// The user's state + void SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state); + + /// + /// Sets a user-provided function to handle serialization of ONNX initializers. + /// + /// The user-provided function called for every initializer + /// The user's state. + void SetOutputModelGetInitializerLocationFunc(OrtGetInitializerLocationFunc get_initializer_location_func, + void* state); + /// /// Sets information relate to EP context binary file. /// EP use this information to decide the location and context binary file name. @@ -80,7 +96,8 @@ class ModelCompilationOptions { /// The folder path to the generated context binary file /// Model name used to decide the context binary file name: [model_name]_[ep].bin /// Status indicating potential error - Status SetEpContextBinaryInformation(const std::string& output_directory, const std::string& model_name); + Status SetEpContextBinaryInformation(const std::filesystem::path& output_directory, + const std::filesystem::path& model_name); /// /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext @@ -95,7 +112,7 @@ class ModelCompilationOptions { /// /// unsigned integer set to the bitwise OR of enabled flags. /// Status indicating success or an error - Status SetFlags(size_t flags); + Status SetFlags(uint32_t flags); /// /// Returns a reference to the session options object. @@ -107,7 +124,7 @@ class ModelCompilationOptions { /// Returns the file path to the input ONNX model. /// /// input model's path - const std::string& GetInputModelPath() const; + const std::filesystem::path& GetInputModelPath() const; /// /// Returns true if the input model is read from a file. @@ -129,6 +146,13 @@ class ModelCompilationOptions { /// input model buffer's size in bytes size_t GetInputModelDataSize() const; + /// + /// Sets the graph optimization level for the underlying session that compiles the model. + /// + /// The optimization level + /// + Status SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); + /// /// Checks if the compilation options described by this object are valid. 
/// @@ -137,13 +161,10 @@ class ModelCompilationOptions { private: void ResetInputModelSettings(); - Status ResetOutputModelSettings(); - Status CheckInputModelSettings() const; - Status CheckOutputModelSettings() const; const onnxruntime::Environment& env_; OrtSessionOptions session_options_; - std::string input_model_path_; + std::filesystem::path input_model_path_; const void* input_model_data_ = nullptr; size_t input_model_data_size_ = 0; }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f3e2a8ce7ba7b..36f7f1f60c36e 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2538,6 +2538,23 @@ ORT_API(void, OrtApis::ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExtern delete static_cast(info); } +ORT_API_STATUS_IMPL(OrtApis::CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, + _In_ int64_t file_offset, _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto ext_data_info = std::make_unique(filepath, file_offset, byte_size); + *out = ext_data_info.release(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(filepath); + ORT_UNUSED_PARAMETER(file_offset); + ORT_UNUSED_PARAMETER(byte_size); + ORT_UNUSED_PARAMETER(out); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateExternalInitializerInfo() is not supported in this build."); +#endif + API_IMPL_END +} + ORT_API(const ORTCHAR_T*, OrtApis::ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info) { return info->GetRelPath().c_str(); } @@ -4202,6 +4219,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetModelMetadata, &OrtApis::GetModelCompatibilityForEpDevices, + &OrtApis::CreateExternalInitializerInfo, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
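Editor's note: the compile_api.cc / compile_api.h hunks above append three entries to the OrtCompileApi table (ModelCompilationOptions_SetGraphOptimizationLevel, ModelCompilationOptions_SetOutputModelWriteFunc, ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc), and model_compilation_options.cc now defaults compilation to the required (L0) transformers only, so callers must opt back into graph optimizations. Below is a minimal, uncompiled sketch of that opt-in through the C API. Function names come from this diff and the pre-existing compile API; the callback-based variants are omitted because their callback typedefs (OrtWriteBufferFunc, OrtGetInitializerLocationFunc) are not shown in this patch, and status checking is elided for brevity.

#include <onnxruntime_c_api.h>

// Sketch only: compile model.onnx into an EPContext model with basic graph
// optimizations re-enabled. Every call returns an OrtStatus* that a real
// caller should check and release.
static void compile_with_basic_optimizations(OrtEnv* env, OrtSessionOptions* so) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  const OrtCompileApi* compile = api->GetCompileApi();

  OrtModelCompilationOptions* opts = nullptr;
  compile->CreateModelCompilationOptionsFromSessionOptions(env, so, &opts);

  compile->ModelCompilationOptions_SetInputModelPath(opts, ORT_TSTR("model.onnx"));
  compile->ModelCompilationOptions_SetOutputModelPath(opts, ORT_TSTR("model_ctx.onnx"));

  // New in this change: compilation runs only required (L0) transformers by
  // default, so any further optimization must be requested explicitly.
  compile->ModelCompilationOptions_SetGraphOptimizationLevel(opts, ORT_ENABLE_BASIC);

  compile->CompileModel(env, opts);
  compile->ReleaseModelCompilationOptions(opts);
}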
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 6dc4cf9d195cc..78616c7b3973e 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -700,6 +700,8 @@ ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_may // OrtExternalInitializerInfo ORT_API(void, ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExternalInitializerInfo* info); +ORT_API_STATUS_IMPL(CreateExternalInitializerInfo, _In_ const ORTCHAR_T* filepath, _In_ int64_t file_offset, + _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out); ORT_API(const ORTCHAR_T*, ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info); ORT_API(int64_t, ExternalInitializerInfo_GetFileOffset, _In_ const OrtExternalInitializerInfo* info); ORT_API(size_t, ExternalInitializerInfo_GetByteSize, _In_ const OrtExternalInitializerInfo* info); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc index d6e51a44c1c69..42b65239de92c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc @@ -4,6 +4,8 @@ #include "core/session/plugin_ep/ep_factory_provider_bridge.h" #include "core/providers/shared_library/provider_host_api.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" namespace onnxruntime { OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, @@ -20,6 +22,11 @@ OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_fa auto* ep_device = ep_devices[i]; if (ep_device) { ep_device->ep_factory = &ep_factory; + + // Add library path to EP metadata if available + if (library_path_.has_value()) { + ep_device->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_LibraryPath, library_path_->string()); + } } } diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 437af62dc2c0c..8c5ef526baba1 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" @@ -12,12 +16,14 @@ namespace onnxruntime { class ProviderBridgeEpFactory : public EpFactoryInternalImpl { public: - ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library, + std::optional library_path = std::nullopt) : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), ep_factory.GetVendor(&ep_factory), ep_factory.GetVendorId(&ep_factory)), ep_factory_{ep_factory}, - provider_library_{provider_library} { + provider_library_{provider_library}, + library_path_{std::move(library_path)} { } private: @@ -59,8 +65,9 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); } - OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP - ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP + OrtEpFactory& ep_factory_; + ProviderLibrary& provider_library_; + 
std::optional library_path_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h index 24ab74e1c77fc..af5bc23143e33 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library.h +++ b/onnxruntime/core/session/plugin_ep/ep_library.h @@ -23,6 +23,7 @@ class EpLibrary { virtual Status Load() { return Status::OK(); } virtual const std::vector& GetFactories() = 0; // valid after Load() virtual Status Unload() { return Status::OK(); } + virtual ~EpLibrary() = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(EpLibrary); diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc index 06cf54aea4071..da94a9f12ba9d 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc @@ -4,6 +4,7 @@ #include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/plugin_ep/ep_factory_provider_bridge.h" +#include "core/session/plugin_ep/ep_library_plugin.h" namespace onnxruntime { Status EpLibraryProviderBridge::Load() { @@ -26,8 +27,9 @@ Status EpLibraryProviderBridge::Load() { // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. + for (const auto& factory : ep_library_plugin_->GetFactories()) { - auto factory_impl = std::make_unique(*factory, *provider_library_); + auto factory_impl = std::make_unique(*factory, *provider_library_, library_path_); auto internal_factory = std::make_unique(std::move(factory_impl)); factory_ptrs_.push_back(internal_factory.get()); diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h index c7e8ebefc3785..45277b2828f56 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h @@ -21,9 +21,11 @@ namespace onnxruntime { class EpLibraryProviderBridge : public EpLibrary { public: EpLibraryProviderBridge(std::unique_ptr provider_library, - std::unique_ptr ep_library_plugin) + std::unique_ptr ep_library_plugin, + std::optional library_path = std::nullopt) : provider_library_{std::move(provider_library)}, - ep_library_plugin_{std::move(ep_library_plugin)} { + ep_library_plugin_{std::move(ep_library_plugin)}, + library_path_{std::move(library_path)} { } const char* RegistrationName() const override { @@ -53,6 +55,9 @@ class EpLibraryProviderBridge : public EpLibrary { // implement EpFactoryInternal::CreateIExecutionProvider by calling Provider::CreateIExecutionProvider. 
std::unique_ptr ep_library_plugin_; + // Library path for EP metadata + std::optional library_path_; + std::vector> factories_; std::vector factory_ptrs_; // for convenience std::vector internal_factory_ptrs_; // for convenience diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index d4041dfce5a7a..444027692903c 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -136,13 +136,11 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op // If ep.context_enable is set, then ep.context_file_path is expected, otherwise ORT don't know where to generate the _ctx.onnx file if (options && model_path == nullptr) { - EpContextModelGenerationOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); + epctx::ModelGenOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); // This is checked by the OrtCompileApi's CompileModel() function, but we check again here in case // the user used the older SessionOptions' configuration entries to generate a compiled model. - if (ep_ctx_gen_options.enable && - ep_ctx_gen_options.output_model_file_path.empty() && - ep_ctx_gen_options.output_model_buffer_ptr == nullptr) { + if (ep_ctx_gen_options.enable && !ep_ctx_gen_options.HasOutputModelLocation()) { return OrtApis::CreateStatus(ORT_FAIL, "Inference session was configured with EPContext model generation enabled but " "without a valid location (e.g., file or buffer) for the output model. " @@ -383,7 +381,7 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model const OrtSessionOptions* session_options = &model_compile_options.GetSessionOptions(); if (model_compile_options.InputModelComesFromFile()) { - PathString input_model_path = ToPathString(model_compile_options.GetInputModelPath()); + const std::filesystem::path& input_model_path = model_compile_options.GetInputModelPath(); ORT_RETURN_IF_ERROR(ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, input_model_path.c_str(), nullptr, 0, session))); @@ -421,13 +419,14 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, << (is_provider_bridge ? " as a provider bridge" : " as a plugin"); // create EpLibraryPlugin to ensure CreateEpFactories and ReleaseEpFactory are available - auto ep_library_plugin = std::make_unique(registration_name, std::move(resolved_library_path)); + auto ep_library_plugin = std::make_unique(registration_name, resolved_library_path); ORT_RETURN_IF_ERROR(ep_library_plugin->Load()); if (is_provider_bridge) { // wrap the EpLibraryPlugin with EpLibraryProviderBridge to add to directly create an IExecutionProvider auto ep_library_provider_bridge = std::make_unique(std::move(provider_library), - std::move(ep_library_plugin)); + std::move(ep_library_plugin), + resolved_library_path); ORT_RETURN_IF_ERROR(ep_library_provider_bridge->Load()); internal_factories = ep_library_provider_bridge->GetInternalFactories(); ep_library = std::move(ep_library_provider_bridge); diff --git a/onnxruntime/core/util/shape_checker.h b/onnxruntime/core/util/shape_checker.h index 9c975275c45b9..89c20deb8f649 100644 --- a/onnxruntime/core/util/shape_checker.h +++ b/onnxruntime/core/util/shape_checker.h @@ -27,6 +27,8 @@ TensorShape make_shape(Args... args) { } \ } +#define CHECK_TENSOR_SHAPE ASSERT_TENSOR_DIMS + // This assumes the tensor is optional, and check wether its shape is expected. 
#define ASSERT_TENSOR_SHAPE(tensor, shape) \ if (tensor != nullptr) { \ @@ -60,4 +62,31 @@ TensorShape make_shape(Args... args) { } \ } +#define ASSERT_TENSOR_DIMENSION(tensor, dim) \ + if (tensor != nullptr) { \ + static_assert(std::is_same::value, "tensor must be a pointer to a Tensor"); \ + const auto tensor_dimensions = tensor->Shape().NumDimensions(); \ + if (tensor_dimensions != dim) { \ + return ORT_MAKE_STATUS( \ + ONNXRUNTIME, INVALID_ARGUMENT, "Input '" #tensor "' is expected to have " #dim " dimensions, got ", \ + tensor_dimensions); \ + } \ + } + +#define ASSERT_TENSOR_DIMENSION_2_CHOICES(tensor, choice1, choice2) \ + if ((tensor) != nullptr) { \ + static_assert(std::is_same::value, "tensor must be a pointer to a Tensor"); \ + const auto tensor_dimensions = tensor->Shape().NumDimensions(); \ + if (tensor_dimensions != choice1 && tensor_dimensions != choice2) { \ + return ORT_MAKE_STATUS( \ + ONNXRUNTIME, INVALID_ARGUMENT, \ + "Input '" #tensor "' is expected to have " #choice1 " or ", #choice2, " dimensions, got ", \ + tensor_dimensions); \ + } \ + } + +#define ASSERT_TENSOR_2D(tensor) ASSERT_TENSOR_DIMENSION(tensor, 2) +#define ASSERT_TENSOR_3D(tensor) ASSERT_TENSOR_DIMENSION(tensor, 3) +#define ASSERT_TENSOR_2D_OR_3D(tensor) ASSERT_TENSOR_DIMENSION_2_CHOICES(tensor, 2, 3) + } // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 64c4ada07f28f..35abad5760c32 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -9,7 +9,7 @@ import os import typing import warnings -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any from onnxruntime.capi import _pybind_state as C @@ -620,6 +620,36 @@ def _register_ep_custom_ops(self, session_options, providers, provider_options, C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1]) +def make_get_initializer_location_func_wrapper( + get_initializer_location_func: GetInitializerLocationFunc, +) -> GetInitializerLocationWrapperFunc: + """ + Wraps a user's "get initializer location" function. The returned wrapper function adheres to the + signature expected by ORT. + + Need this wrapper to: + - Convert the `initializer_value` parameter from `C.OrtValue` to `onnxruntime.OrtValue`, which is more + convenient for the user's function to use. + - Allow the user's function to return the original `external_info` parameter (this wrapper makes a copy) + """ + + def get_initializer_location_func_wrapper( + initializer_name: str, + initializer_value: C.OrtValue, + external_info: C.OrtExternalInitializerInfo | None, + ) -> C.OrtExternalInitializerInfo | None: + ret_val: C.OrtExternalInitializerInfo | None = get_initializer_location_func( + initializer_name, OrtValue(initializer_value), external_info + ) + if ret_val is not None and ret_val == external_info: + # User returned `external_info` (const and owned by ORT). ORT expects the returned value to be + # a new instance (that it deletes), so make a copy. + ret_val = C.OrtExternalInitializerInfo(ret_val.filepath, ret_val.file_offset, ret_val.byte_size) + return ret_val + + return get_initializer_location_func_wrapper + + class ModelCompiler: """ This class is used to compile an ONNX model. 
A compiled ONNX model has EPContext nodes that each @@ -647,6 +677,8 @@ def __init__( external_initializers_file_path: str | os.PathLike | None = None, external_initializers_size_threshold: int = 1024, flags: int = C.OrtCompileApiFlags.NONE, + graph_optimization_level: C.GraphOptimizationLevel = C.GraphOptimizationLevel.ORT_DISABLE_ALL, + get_initializer_location_func: GetInitializerLocationFunc | None = None, ): """ Creates a ModelCompiler instance. @@ -663,6 +695,27 @@ def __init__( is None or empty. Initializers larger than this threshold are stored in the external initializers file. :param flags: Additional boolean options to enable. Set this parameter to a bitwise OR of flags in onnxruntime.OrtCompileApiFlags. + :param graph_optimization_level: The graph optimization level. + Defaults to onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL. + :param get_initializer_location_func: Optional function called for every initializer to allow user to specify + whether an initializer should be stored within the model or externally. Example: + ``` + def get_initializer_location( + initializer_name: str, + initializer_value: onnxrt.OrtValue, + external_info: onnxrt.OrtExternalInitializerInfo | None, + ) -> onnxrt.OrtExternalInitializerInfo | None: + byte_size = initializer_value.tensor_size_in_bytes() + + if byte_size < 64: + return None # Store small initializer within compiled model. + + # Else, write initializer to new external file. + value_np = initializer_value.numpy() + file_offset = ext_init_file.tell() + ext_init_file.write(value_np.tobytes()) + return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size) + ``` """ input_model_path: str | os.PathLike | None = None input_model_bytes: bytes | None = None @@ -685,6 +738,18 @@ def __init__( else: external_initializers_file_path = "" + if get_initializer_location_func is not None: + if external_initializers_file_path: + raise ValueError( + "Cannot initialize ModelCompiler with both `external_initializers_file_path` " + "and `get_initializer_location_func`" + ) + self.get_initializer_location_func_wrapper = make_get_initializer_location_func_wrapper( + get_initializer_location_func + ) + else: + self.get_initializer_location_func_wrapper = None + if input_model_path: self._model_compiler = C.ModelCompiler( sess_options, @@ -694,6 +759,8 @@ def __init__( external_initializers_file_path, external_initializers_size_threshold, flags, + graph_optimization_level, + self.get_initializer_location_func_wrapper, ) else: self._model_compiler = C.ModelCompiler( @@ -704,6 +771,8 @@ def __init__( external_initializers_file_path, external_initializers_size_threshold, flags, + graph_optimization_level, + self.get_initializer_location_func_wrapper, ) def compile_to_file(self, output_model_path: str | None = None): @@ -733,6 +802,14 @@ def compile_to_bytes(self) -> bytes: """ return self._model_compiler.compile_to_bytes() + def compile_to_stream(self, write_function: Callable[[bytes], None]): + """ + Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write function. + Raises an 'InvalidArgument' exception if the compilation options are invalid. + :param write_function: A callable that accepts a bytes buffer to write. + """ + self._model_compiler.compile_to_stream(write_function) + class IOBinding: """ @@ -1293,3 +1370,14 @@ def device_name(self) -> str: Returns the name of the device where the SparseTensor data buffers reside e.g. 
cpu, cuda """ return self._tensor.device_name().lower() + + +# Type hint for user-specified function that allows the user to specify initializer locations when compiling a model. +GetInitializerLocationFunc = Callable[ + [str, OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None +] + +# Type hint that adheres to the signature expected by ORT. +GetInitializerLocationWrapperFunc = Callable[ + [str, C.OrtValue, C.OrtExternalInitializerInfo | None], C.OrtExternalInitializerInfo | None +] diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc index e2b069b01f95b..6ff252b5d1353 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) #include "python/onnxruntime_pybind_model_compiler.h" #include @@ -8,11 +9,56 @@ #include #include "core/common/common.h" #include "core/framework/error_code_helper.h" +#include "core/graph/abi_graph_types.h" #include "core/session/utils.h" namespace onnxruntime { namespace python { +/// +/// This function is called by ORT to allow the user to handle where every initializer is stored +/// (i.e., externally or internally). This function wraps (and calls) the actual Python function +/// provided by the user. +/// +/// Opaque state that holds a pointer to the user's Python function. +/// The name of the initializer to handle. +/// The OrtValue with the initializer's data, type, and shape. +/// The original external location of the initializer, if any. May be null. +/// Output parameter set to the initializer's new external location. Function may +/// return NULL if the initializer should be stored within the compiled ONNX model. +/// A status indicating success or an error. +static OrtStatus* ORT_API_CALL PyGetInitializerLocationFuncWrapper( + void* state, + const char* initializer_name, + const OrtValue* initializer_value, + const OrtExternalInitializerInfo* external_info, + /*out*/ OrtExternalInitializerInfo** new_external_info) { + PyGetInitializerLocationFunc* py_func = reinterpret_cast(state); + OrtStatus* status = nullptr; + std::shared_ptr py_new_external_info = nullptr; + + // Call the Python function and convert any exceptions to a status. + ORT_TRY { + py_new_external_info = (*py_func)(initializer_name, *initializer_value, external_info); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what())); + }); + } + + if (py_new_external_info) { + // ORT expects to take ownership of the new external info, so make a copy because other Python code + // may be holding a reference to the `py_new_external_info`. 
+ auto py_result_copy = std::make_unique(*py_new_external_info.get()); + *new_external_info = py_result_copy.release(); + } else { + *new_external_info = nullptr; + } + + return status; +} + onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr& out, onnxruntime::Environment& env, const PySessionOptions& sess_options, @@ -20,8 +66,11 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr(env, sess_options, PrivateConstructorTag{}); + uint32_t flags, + GraphOptimizationLevel graph_optimization_level, + const PyGetInitializerLocationFunc& py_get_initializer_location_func) { + auto model_compiler = std::make_unique(env, sess_options, py_get_initializer_location_func, + PrivateConstructorTag{}); ModelCompilationOptions& compile_options = model_compiler->model_compile_options_; if (input_model_is_path) { @@ -43,6 +92,14 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptrpy_get_initializer_location_func_) { + compile_options.SetOutputModelGetInitializerLocationFunc( + PyGetInitializerLocationFuncWrapper, + reinterpret_cast(&model_compiler->py_get_initializer_location_func_)); + } + out = std::move(model_compiler); return Status::OK(); } @@ -77,9 +134,47 @@ onnxruntime::Status PyModelCompiler::CompileToBytes(std::string& output_buffer) return Status::OK(); } +/// +/// Function called by ORT to allow the user to write out the compiled ONNX model bytes to a custom output stream. +/// This function wraps (and calls) the actual Python function provided by the user. +/// +/// Opaque state that holds a pointer to the user's Python function. +/// The buffer to write out. Contains a portion of the compiled ONNX model's bytes. +/// The number of bytes in the buffer. +/// A status indicating success or an error. +static OrtStatus* ORT_API_CALL PyOutStreamWriteFuncWrapper(void* stream_state, const void* buffer, + size_t buffer_num_bytes) { + PyOutStreamWriteFunc* py_write_func = reinterpret_cast(stream_state); + OrtStatus* status = nullptr; + + // Call the Python write function and convert any exceptions to a status. + ORT_TRY { + pybind11::bytes py_bytes(reinterpret_cast(buffer), buffer_num_bytes); + (*py_write_func)(py_bytes); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what())); + }); + } + + return status; +} + +onnxruntime::Status PyModelCompiler::CompileToOutStream(PyOutStreamWriteFunc& write_func) { + model_compile_options_.SetOutputModelWriteFunc(PyOutStreamWriteFuncWrapper, + reinterpret_cast(&write_func)); + ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(env_, model_compile_options_)); + return Status::OK(); +} + PyModelCompiler::PyModelCompiler(onnxruntime::Environment& env, const PySessionOptions& sess_options, + const PyGetInitializerLocationFunc& py_get_initializer_location_func, PrivateConstructorTag) - : env_(env), model_compile_options_(env, sess_options) { + : env_(env), + model_compile_options_(env, sess_options), + py_get_initializer_location_func_(py_get_initializer_location_func) { } } // namespace python } // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.h b/onnxruntime/python/onnxruntime_pybind_model_compiler.h index e61ae4674210b..957350accdba2 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.h +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.h @@ -3,7 +3,6 @@ // Licensed under the MIT License. 
#pragma once -#if !defined(ORT_MINIMAL_BUILD) #include #include #include "core/common/status.h" @@ -14,11 +13,24 @@ namespace onnxruntime { class Environment; namespace python { +// Type of the function provided by Python code that is called by ORT to write out the compiled model. +using PyOutStreamWriteFunc = std::function; + +// Type of the function provided by Python code that is called by ORT to handle every initializer. +using PyGetInitializerLocationFunc = std::function( + const std::string& initializer_name, + const OrtValue& initializer_value, + const OrtExternalInitializerInfo* external_info)>; + /// /// Class exposed to Python that enables compiling ONNX models. /// Internally wraps a onnxruntime::ModelCompilationOptions that stores and validates settings. /// class PyModelCompiler { +#if defined(ORT_MINIMAL_BUILD) + public: + bool not_defined_in_this_build{}; // Prevent empty class warning. +#else private: // private tag to pass to constructor to ensure that constructor cannot be directly called externally struct PrivateConstructorTag {}; @@ -35,9 +47,12 @@ class PyModelCompiler { /// True to embed compiled binary data into EPContext nodes. /// The file into which to store initializers for non-compiled /// nodes. - /// Flags from OrtCompileApiFlags /// Ignored if 'external_initializers_file_path' is empty. /// Initializers with a size greater than this threshold are dumped into the external file. + /// Flags from OrtCompileApiFlags + /// Optimization level for graph transformations on the model. + /// Defaults to ORT_DISABLE_ALL to allow EP to get the original loaded model. + /// User's function to handle saving of initializers. /// A Status indicating error or success. static onnxruntime::Status Create(/*out*/ std::unique_ptr& out, onnxruntime::Environment& env, @@ -46,11 +61,14 @@ class PyModelCompiler { bool embed_compiled_data_into_model = false, const std::string& external_initializers_file_path = {}, size_t external_initializers_size_threshold = 1024, - size_t flags = 0); + uint32_t flags = 0, + GraphOptimizationLevel graph_opt_level = GraphOptimizationLevel::ORT_DISABLE_ALL, + const PyGetInitializerLocationFunc& py_get_initializer_location_func = nullptr); // Note: Creation should be done via Create(). This constructor is public so that it can be called from // std::make_shared(). PyModelCompiler(onnxruntime::Environment& env, const PySessionOptions& sess_options, + const PyGetInitializerLocationFunc& py_get_initializer_location_func, PrivateConstructorTag); /// @@ -70,11 +88,19 @@ class PyModelCompiler { /// A Status indicating error or success. onnxruntime::Status CompileToBytes(std::string& output_buffer); + /// + /// Compiles the input model and writes the result into the provided output stream (write functor). + /// + /// Write functor that encapsulates the stream's state. + /// A Status indicating error or success. 
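+ /// Illustrative sketch only (hypothetical names, not prescribed by this header): on the Python side the
+ /// write functor can simply be the bound `write` method of an in-memory or file stream, e.g.
+ /// buffer = io.BytesIO() followed by model_compiler.compile_to_stream(buffer.write). ORT may invoke the
+ /// functor multiple times, once per buffer of serialized compiled-model bytes.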
+ onnxruntime::Status CompileToOutStream(PyOutStreamWriteFunc& write_func); + private: onnxruntime::Environment& env_; onnxruntime::ModelCompilationOptions model_compile_options_; std::string input_model_bytes_; + PyGetInitializerLocationFunc py_get_initializer_location_func_; +#endif // defined(ORT_MINIMAL_BUILD) }; } // namespace python } // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index eb06a65ad5330..e370518b1fffb 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -6,11 +6,8 @@ #include #include "python/onnxruntime_pybind_exceptions.h" #include "python/onnxruntime_pybind_mlvalue.h" -#include "python/onnxruntime_pybind_state_common.h" - -#if !defined(ORT_MINIMAL_BUILD) #include "python/onnxruntime_pybind_model_compiler.h" -#endif // !defined(ORT_MINIMAL_BUILD) +#include "python/onnxruntime_pybind_state_common.h" #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API @@ -45,6 +42,7 @@ #include "core/session/lora_adapters.h" #if !defined(ORT_MINIMAL_BUILD) +#include "core/graph/abi_graph_types.h" #include "core/session/abi_devices.h" #include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/provider_policy_context.h" @@ -2730,6 +2728,35 @@ including arg name, arg type (contains both type and shape).)pbdoc") .value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested) .export_values(); + // Must use a std::shared_ptr to hold OrtExternalInitializerInfo because the same instances is passed + // between C++ and Python (and Python cannot transfer ownership to C++). + py::class_> ort_external_initializer_info_binding( + m, "OrtExternalInitializerInfo", + R"pbdoc(Location information for initializer data stored in an external file)pbdoc"); + ort_external_initializer_info_binding + .def(py::init([](const std::basic_string& filepath, int64_t file_offset, size_t byte_size) { +#if !defined(ORT_MINIMAL_BUILD) + return std::make_shared(filepath, file_offset, byte_size); +#else + ORT_UNUSED_PARAMETER(filepath); + ORT_UNUSED_PARAMETER(file_offset); + ORT_UNUSED_PARAMETER(byte_size); + ORT_THROW("OrtExternalInitializerInfo creation is not supported in this build"); +#endif + })) + .def_property_readonly( + "filepath", + [](OrtExternalInitializerInfo* info) -> std::basic_string { return info->GetRelPath(); }, + R"pbdoc(The relative path to the file in which initializer data is stored.)pbdoc") + .def_property_readonly( + "file_offset", + [](OrtExternalInitializerInfo* info) -> int64_t { return info->GetOffset(); }, + R"pbdoc(The file byte offset where the initializer data is stored.)pbdoc") + .def_property_readonly( + "byte_size", + [](OrtExternalInitializerInfo* info) -> size_t { return info->GetLength(); }, + R"pbdoc(The byte size of the initializer data in the file.)pbdoc"); + py::enum_(m, "OrtCompileApiFlags", py::arithmetic()) .value("NONE", OrtCompileApiFlags_NONE) .value("ERROR_IF_NO_NODES_COMPILED", OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED) @@ -2743,7 +2770,9 @@ including arg name, arg type (contains both type and shape).)pbdoc") bool embed_compiled_data_into_model = false, std::string external_initializers_file_path = {}, size_t external_initializers_size_threshold = 1024, - size_t flags = OrtCompileApiFlags_NONE) { + uint32_t flags = OrtCompileApiFlags_NONE, + GraphOptimizationLevel graph_optimization_level = 
GraphOptimizationLevel::ORT_DISABLE_ALL, + const PyGetInitializerLocationFunc& py_get_initializer_location_func = nullptr) { #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr result; OrtPybindThrowIfError(PyModelCompiler::Create(result, GetEnv(), sess_options, @@ -2751,7 +2780,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") embed_compiled_data_into_model, external_initializers_file_path, external_initializers_size_threshold, - flags)); + flags, graph_optimization_level, + py_get_initializer_location_func)); return result; #else ORT_UNUSED_PARAMETER(sess_options); @@ -2761,6 +2791,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_UNUSED_PARAMETER(external_initializers_file_path); ORT_UNUSED_PARAMETER(external_initializers_size_threshold); ORT_UNUSED_PARAMETER(flags); + ORT_UNUSED_PARAMETER(graph_optimization_level); + ORT_UNUSED_PARAMETER(py_get_initializer_location_func); ORT_THROW("Compile API is not supported in this build."); #endif })) @@ -2788,7 +2820,19 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_THROW("Compile API is not supported in this build."); #endif }, - R"pbdoc(Compile an ONNX model into a buffer.)pbdoc"); + R"pbdoc(Compile an ONNX model into a buffer.)pbdoc") + .def( + "compile_to_stream", + [](PyModelCompiler* model_compiler, PyOutStreamWriteFunc& py_stream_write_func) { +#if !defined(ORT_MINIMAL_BUILD) + OrtPybindThrowIfError(model_compiler->CompileToOutStream(py_stream_write_func)); +#else + ORT_UNUSED_PARAMETER(model_compiler); + ORT_UNUSED_PARAMETER(py_stream_write_func); + ORT_THROW("Compile API is not supported in this build."); +#endif + }, + R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc"); } bool InitArray() { diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index 287eba05a0595..e4265713d2d0a 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -33,95 +33,88 @@ struct MulKernel { return iter != float_initializers.end() ? 
&iter->second : nullptr; } - OrtStatus* GetInputDataAndShape(OrtKernelContext* kernel_context, size_t index, - /*out*/ gsl::span& data, - /*out*/ std::vector& shape) const { - const OrtValue* input = nullptr; - RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, index, &input)); - - OrtTensorTypeAndShapeInfo* type_shape = nullptr; - DeferOrtRelease release_type(&type_shape, ort_api.ReleaseTensorTypeAndShapeInfo); - - RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input, &type_shape)); - - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); - RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 inputs"); - - size_t num_elems = 0; - RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(type_shape, &num_elems)); - - size_t num_dims = 0; - RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape, &num_dims)); - - shape.resize(num_dims, 0); - RETURN_IF_ERROR(ort_api.GetDimensions(type_shape, shape.data(), shape.size())); - - const void* raw_data = nullptr; - RETURN_IF_ERROR(ort_api.GetTensorData(input, &raw_data)); - - const float* float_data = static_cast(raw_data); + void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const { + Ort::ConstValue input = kernel_context.GetInput(index); + auto type_shape = input.GetTensorTypeAndShapeInfo(); + + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); + + const float* float_data = input.GetTensorData(); + size_t num_elems = type_shape.GetElementCount(); data = gsl::span(float_data, num_elems); - return nullptr; + shape = type_shape.GetShape(); } - OrtStatus* Compute(OrtKernelContext* kernel_context) { + OrtStatus* Compute(OrtKernelContext* kernel_ctx) { RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); - gsl::span input0; - gsl::span input1; - std::vector shape0; - std::vector shape1; - - size_t num_inputs = 0; - RETURN_IF_ERROR(ort_api.KernelContext_GetInputCount(kernel_context, &num_inputs)); - - if (num_inputs == 2) { - // Both inputs are non-constant. Get them from ORT's KernelContext. - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input0, shape0)); - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 1, input1, shape1)); - } else if (num_inputs == 1) { - // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. - // Get the constant input from the initializers saved by the EP. - // Refer to "NodeFusionOptions_DropConstantInitializers()". - - if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input1, shape1)); + Ort::KernelContext kernel_context(kernel_ctx); + try { + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + + size_t num_inputs = kernel_context.GetInputCount(); + + if (num_inputs == 2) { + // Both inputs are non-constant. Get them from ORT's KernelContext. 
+ GetInputDataAndShape(kernel_context, 0, input0, shape0); + GetInputDataAndShape(kernel_context, 1, input1, shape1); + } else if (num_inputs == 1) { + // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. + // Get the constant input from the initializers saved by the EP. + // Refer to "NodeFusionOptions_DropConstantInitializers()". + + if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input1, shape1); + input0 = gsl::span(const_input0->data); + shape0 = const_input0->shape; + } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input0, shape0); + input1 = gsl::span(const_input1->data); + shape1 = const_input1->shape; + } + } else { + // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) + // are disabled. + const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); + const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); + RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, + "Expected 2 initializer inputs to be saved by EP"); + input0 = gsl::span(const_input0->data); - shape0 = const_input0->shape; - } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { - RETURN_IF_ERROR(GetInputDataAndShape(kernel_context, 0, input0, shape0)); input1 = gsl::span(const_input1->data); + shape0 = const_input0->shape; shape1 = const_input1->shape; } - } else { - // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) - // are disabled. - const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); - const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); - RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, - "Expected 2 initializer inputs to be saved by EP"); - - input0 = gsl::span(const_input0->data); - input1 = gsl::span(const_input1->data); - shape0 = const_input0->shape; - shape1 = const_input1->shape; - } - RETURN_IF(shape0 != shape1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting. 
+ if (shape0 != shape1) { + throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); + } - size_t num_outputs = 0; - RETURN_IF_ERROR(ort_api.KernelContext_GetOutputCount(kernel_context, &num_outputs)); - RETURN_IF(num_outputs != 1, ort_api, "Expected 1 output for MulKernel"); + size_t num_outputs = kernel_context.GetOutputCount(); + if (num_outputs != 1) { + throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); + } - OrtValue* output = nullptr; - float* output_data = nullptr; - RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_context, 0, shape0.data(), shape0.size(), &output)); - RETURN_IF_ERROR(ort_api.GetTensorMutableData(output, reinterpret_cast(&output_data))); + auto output = kernel_context.GetOutput(0, shape0); + float* output_data = output.GetTensorMutableData(); - for (size_t i = 0; i < input0.size(); ++i) { - output_data[i] = input0[i] * input1[i]; + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; @@ -183,178 +176,175 @@ const char* ORT_API_CALL ExampleEp ::GetNameImpl(const OrtEp* this_ptr) noexcept return ep->name_.c_str(); } -OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) { - size_t num_initializers = 0; - RETURN_IF_ERROR(ort_api.Graph_GetNumInitializers(graph, &num_initializers)); - - std::vector initializers(num_initializers); - RETURN_IF_ERROR(ort_api.Graph_GetInitializers(graph, initializers.data(), initializers.size())); - - for (const OrtValueInfo* initializer : initializers) { - bool is_constant = false; - RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(initializer, &is_constant)); - - if (is_constant) { - const char* name = nullptr; - const OrtValue* value = nullptr; - OrtTensorTypeAndShapeInfo* type_shape = nullptr; - DeferOrtRelease release_type(&type_shape, ort_api.ReleaseTensorTypeAndShapeInfo); - size_t num_elems = 0; +OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* ort_graph) { + Ort::ConstGraph graph{ort_graph}; - RETURN_IF_ERROR(ort_api.GetValueInfoName(initializer, &name)); - RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer, &value)); - RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(value, &type_shape)); - RETURN_IF_ERROR(ort_api.GetTensorShapeElementCount(type_shape, &num_elems)); + try { + std::vector initializers = graph.GetInitializers(); - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); - RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 initializers"); + for (const auto& initializer : initializers) { + const bool is_constant = initializer.IsConstantInitializer(); - size_t num_dims = 0; - RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape, &num_dims)); + if (is_constant) { + auto name = initializer.GetName(); + Ort::ConstValue value; + auto status = initializer.GetInitializer(value); + if (!status.IsOK()) + return status.release(); - std::vector dims(num_dims, 0); - RETURN_IF_ERROR(ort_api.GetDimensions(type_shape, dims.data(), dims.size())); + auto type_shape = value.GetTensorTypeAndShapeInfo(); + const size_t num_elems = type_shape.GetElementCount(); + const ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if 
(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + return Ort::Status("Expected float32 initializers", ORT_INVALID_ARGUMENT).release(); - const float* data = nullptr; - RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(value), (void**)&data)); + std::vector dims = type_shape.GetShape(); + const float* data = value.GetTensorData(); - FloatInitializer ep_initializer = {std::move(dims), std::vector(data, data + num_elems)}; - float_initializers_.emplace(name, std::move(ep_initializer)); + FloatInitializer ep_initializer = {std::move(dims), std::vector(data, data + num_elems)}; + float_initializers_.emplace(std::move(name), std::move(ep_initializer)); + } } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; } /*static*/ -OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, +OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, OrtEpGraphSupportInfo* graph_support_info) noexcept { - ExampleEp* ep = static_cast(this_ptr); + try { + ExampleEp* ep = static_cast(this_ptr); - size_t num_nodes = 0; - RETURN_IF_ERROR(ep->ort_api.Graph_GetNumNodes(graph, &num_nodes)); + Ort::ConstGraph graph{ort_graph}; + std::vector nodes = graph.GetNodes(); + if (nodes.empty()) { + return nullptr; // No nodes to process + } - if (num_nodes == 0) { - return nullptr; // No nodes to process - } + std::vector supported_nodes; - std::vector nodes(num_nodes); - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); - - std::vector supported_nodes; - - for (const OrtNode* node : nodes) { - const char* op_type = nullptr; - RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); - - if (std::strncmp(op_type, "Mul", 4) == 0) { - // Check that Mul has inputs/output of type float - size_t num_inputs = 0; - size_t num_outputs = 0; - RETURN_IF_ERROR(ep->ort_api.Node_GetNumInputs(node, &num_inputs)); - RETURN_IF_ERROR(ep->ort_api.Node_GetNumOutputs(node, &num_outputs)); - RETURN_IF(num_inputs != 2 || num_outputs != 1, ep->ort_api, "Mul should have 2 inputs and 1 output"); - - std::vector inputs(num_inputs); - std::vector outputs(num_outputs); - RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(node, inputs.data(), inputs.size())); - RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); - - std::array is_float = {false, false, false}; - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, inputs[0], is_float[0])); - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, inputs[1], is_float[1])); - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, outputs[0], is_float[2])); - if (!is_float[0] || !is_float[1] || !is_float[2]) { - continue; // Input or output is not of type float - } + for (const auto& node : nodes) { + auto op_type = node.GetOperatorType(); - supported_nodes.push_back(node); // Only support a single Mul for now. - break; - } - } + if (op_type != "Mul") { + // Check that Mul has inputs/output of type float + std::vector inputs = node.GetInputs(); + std::vector outputs = node.GetOutputs(); + + RETURN_IF(inputs.size() != 2 || outputs.size() != 1, ep->ort_api, "Mul should have 2 inputs and 1 output"); - // Create (optional) fusion options for the supported nodes to fuse. 
- OrtNodeFusionOptions node_fusion_options = {}; - node_fusion_options.ort_version_supported = ORT_API_VERSION; + std::array is_float = {false, false, false}; + IsFloatTensor(inputs[0], is_float[0]); + IsFloatTensor(inputs[1], is_float[1]); + IsFloatTensor(outputs[0], is_float[2]); + if (!is_float[0] || !is_float[1] || !is_float[2]) { + continue; // Input or output is not of type float + } - // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers - // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. - // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use - // during inference. - node_fusion_options.drop_constant_initializers = true; - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), - supported_nodes.size(), &node_fusion_options)); + supported_nodes.push_back(node); // Only support a single Mul for now. + break; + } + } + + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } return nullptr; } /*static*/ -OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, +OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** ort_graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { - ExampleEp* ep = static_cast(this_ptr); - const OrtApi& ort_api = ep->ort_api; - - if (count != 1) { - return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single graph"); - } - - // In GetCapability(), this EP specified that it doesn't need ORT to provide constant initializers during inference. - // So, this EP saves constant initializers so that they're available during inference, but an actual EP - // implementation could transfer the weights to device memory. 
- ep->SaveConstantInitializers(graphs[0]); - - size_t num_nodes = 0; - RETURN_IF_ERROR(ep->ort_api.Graph_GetNumNodes(graphs[0], &num_nodes)); - - std::vector nodes(num_nodes); - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graphs[0], nodes.data(), nodes.size())); - - if (num_nodes != 1) { - return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); - } + try { + if (count != 1) { + Ort::Status status("Expected to compile a single graph", ORT_EP_FAIL); + return status.release(); + } - const char* node_op_type = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetOperatorType(nodes[0], &node_op_type)); + ExampleEp* ep = static_cast(this_ptr); - if (std::strncmp(node_op_type, "Mul", 4) != 0) { - return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); - } + Ort::ConstGraph graph{ort_graphs[0]}; - // Now we know we're compiling a single Mul node. Create a computation kernel. - std::array node_inputs = {}; - std::array node_input_names = {}; + // In GetCapability(), this EP specified that it doesn't need ORT to provide constant initializers during inference. + // So, this EP saves constant initializers so that they're available during inference, but an actual EP + // implementation could transfer the weights to device memory. + ep->SaveConstantInitializers(graph); - RETURN_IF_ERROR(ort_api.Node_GetInputs(nodes[0], node_inputs.data(), node_inputs.size())); - RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); - RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); + std::vector nodes = graph.GetNodes(); + if (nodes.size() != 1) { + Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL); + return status.release(); + } - const char* ep_name = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetEpName(fused_nodes[0], &ep_name)); - if (std::strncmp(ep_name, "example_ep", 11) != 0) { - return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on"); - } + auto node_op_type = nodes[0].GetOperatorType(); + if (node_op_type != "Mul") { + Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL); + return status.release(); + } - // Associate the name of the fused node with our MulKernel. - const char* fused_node_name = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); + // Now we know we're compiling a single Mul node. Create a computation kernel. + std::vector node_inputs = nodes[0].GetInputs(); + std::array node_input_names; + node_input_names[0] = node_inputs[0].GetName(); + node_input_names[1] = node_inputs[1].GetName(); + + Ort::ConstNode fused_node{fused_nodes[0]}; + auto ep_name = fused_node.GetEpName(); + if (ep_name != "example_ep") { + Ort::Status status("The fused node is expected to assigned to this EP to run on", ORT_EP_FAIL); + return status.release(); + } - ep->kernels_.emplace(std::string(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, + // Associate the name of the fused node with our MulKernel. + auto fused_node_name = fused_node.GetName(); + ep->kernels_.emplace(std::move(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, ep->float_initializers_, node_input_names[0], node_input_names[1])); - // Update the OrtNodeComputeInfo associated with the graph. - auto node_compute_info = std::make_unique(*ep); - node_compute_infos[0] = node_compute_info.release(); + // Update the OrtNodeComputeInfo associated with the graph. 
+ auto node_compute_info = std::make_unique(*ep); + node_compute_infos[0] = node_compute_info.release(); - // Create EpContext nodes for the fused nodes we compiled. - if (ep->config_.enable_ep_context) { - assert(ep_context_nodes != nullptr); - RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), - gsl::span(ep_context_nodes, count))); + // Create EpContext nodes for the fused nodes we compiled. + if (ep->config_.enable_ep_context) { + assert(ep_context_nodes != nullptr); + RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), + gsl::span(ep_context_nodes, count))); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; @@ -375,69 +365,74 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, // cannot currently run the EPContext model. OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes) { - assert(fused_nodes.size() == ep_context_nodes.size()); + try { + assert(fused_nodes.size() == ep_context_nodes.size()); - // Helper to collect input or output names from an array of OrtValueInfo instances. - auto collect_input_output_names = [&](gsl::span value_infos, - std::vector& result) -> OrtStatus* { - size_t num_values = value_infos.size(); - std::vector value_names(num_values); + // Helper to collect input or output names from an array of OrtValueInfo instances. + auto collect_input_output_names = [&](gsl::span value_infos, + std::vector& result) { + std::vector value_names; + value_names.reserve(value_infos.size()); - for (size_t i = 0; i < num_values; ++i) { - const OrtValueInfo* value_info = value_infos[i]; - RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &value_names[i])); - } + for (const auto vi : value_infos) { + value_names.push_back(vi.GetName()); + } - result = std::move(value_names); - return nullptr; - }; - - // Create an "EPContext" node for every fused node. - for (size_t i = 0; i < fused_nodes.size(); ++i) { - const OrtNode* fused_node = fused_nodes[i]; - const char* fused_node_name = nullptr; - - RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &fused_node_name)); - - size_t num_fused_node_inputs = 0; - size_t num_fused_node_outputs = 0; - RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node, &num_fused_node_inputs)); - RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node, &num_fused_node_outputs)); - - std::vector fused_node_inputs(num_fused_node_inputs); - std::vector fused_node_outputs(num_fused_node_outputs); - RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, fused_node_inputs.data(), fused_node_inputs.size())); - RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, fused_node_outputs.data(), fused_node_outputs.size())); - - std::vector input_names; - std::vector output_names; - - RETURN_IF_ERROR(collect_input_output_names(fused_node_inputs, /*out*/ input_names)); - RETURN_IF_ERROR(collect_input_output_names(fused_node_outputs, /*out*/ output_names)); - - int64_t is_main_context = (i == 0); - int64_t embed_mode = 1; - - // Create node attributes. The CreateNode() function copies the attributes, so we have to release them. 
- std::array attributes = {}; - DeferOrtRelease defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); - - std::string ep_ctx = "binary_data"; - RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", ep_ctx.c_str(), static_cast(ep_ctx.length()), - ORT_OP_ATTR_STRING, &attributes[0])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT, &attributes[1])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[2])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING, &attributes[3])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("partition_name", fused_node_name, static_cast(strlen(fused_node_name)), - ORT_OP_ATTR_STRING, &attributes[4])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("source", this->name_.c_str(), static_cast(this->name_.length()), - ORT_OP_ATTR_STRING, &attributes[5])); - - RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, - input_names.data(), input_names.size(), - output_names.data(), output_names.size(), - attributes.data(), attributes.size(), - &ep_context_nodes[i])); + result = std::move(value_names); + }; + + // Create an "EPContext" node for every fused node. + for (size_t i = 0; i < fused_nodes.size(); ++i) { + Ort::ConstNode fused_node{fused_nodes[i]}; + auto fused_node_name = fused_node.GetName(); + + std::vector fused_node_inputs = fused_node.GetInputs(); + std::vector fused_node_outputs = fused_node.GetOutputs(); + + std::vector input_names; + std::vector output_names; + + collect_input_output_names(fused_node_inputs, /*out*/ input_names); + collect_input_output_names(fused_node_outputs, /*out*/ output_names); + + int64_t is_main_context = (i == 0); + int64_t embed_mode = 1; + + // Create node attributes. The CreateNode() function copies the attributes. 
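+ // (Each Ort::OpAttr is an RAII wrapper that releases its underlying OrtOpAttr automatically when it goes
+ // out of scope, so the explicit DeferOrtRelease used by the previous C-API version is no longer needed.)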
+ std::array attributes = {}; + std::string ep_ctx = "binary_data"; + attributes[0] = Ort::OpAttr("ep_cache_context", ep_ctx.data(), static_cast(ep_ctx.size()), + ORT_OP_ATTR_STRING); + + attributes[1] = Ort::OpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT); + attributes[2] = Ort::OpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT); + attributes[3] = Ort::OpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING); + attributes[4] = Ort::OpAttr("partition_name", fused_node_name.data(), static_cast(fused_node_name.size()), + ORT_OP_ATTR_STRING); + + attributes[5] = Ort::OpAttr("source", this->name_.data(), static_cast(this->name_.size()), + ORT_OP_ATTR_STRING); + + std::vector c_input_names; + std::transform(input_names.begin(), input_names.end(), std::back_inserter(c_input_names), + [](const std::string& s) { return s.c_str(); }); + std::vector c_output_names; + std::transform(output_names.begin(), output_names.end(), std::back_inserter(c_output_names), + [](const std::string& s) { return s.c_str(); }); + + OrtOpAttr** op_attrs = reinterpret_cast(attributes.data()); + RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name.c_str(), + c_input_names.data(), c_input_names.size(), + c_output_names.data(), c_output_names.size(), + op_attrs, attributes.size(), + &ep_context_nodes[i])); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } return nullptr; diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index fa6eb24c5cc04..279925a7ec3e1 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -54,7 +54,7 @@ class ExampleEp : public OrtEp, public ApiPtrs { OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); - OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph); + OrtStatus* SaveConstantInitializers(const OrtGraph* graph); ExampleEpFactory& factory_; std::string name_; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep.cc index b6f982a422b6a..c14bdc1b52093 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep.cc @@ -1,6 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + #include "ep_factory.h" // To make symbols visible on macOS/iOS @@ -21,6 +25,9 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const const OrtEpApi* ep_api = ort_api->GetEpApi(); const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); + // Manual init for the C++ API + Ort::InitApi(ort_api); + // Factory could use registration_name or define its own EP name. 
std::unique_ptr factory = std::make_unique(registration_name, ApiPtrs{*ort_api, *ep_api, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc index 549551931c647..263b4d208bd91 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc @@ -5,48 +5,33 @@ #include -OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessionOptions& session_options, +OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& /* ort_api */, const OrtSessionOptions& session_options, const char* config_key, const std::string& default_val, /*out*/ std::string& config_val) { - int has_config = 0; - RETURN_IF_ERROR(ort_api.HasSessionConfigEntry(&session_options, config_key, &has_config)); - - if (has_config != 1) { - config_val = default_val; - return nullptr; + try { + Ort::ConstSessionOptions sess_opt{&session_options}; + config_val = sess_opt.GetConfigEntryOrDefault(config_key, default_val); + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); } - size_t size = 0; - RETURN_IF_ERROR(ort_api.GetSessionConfigEntry(&session_options, config_key, nullptr, &size)); - - config_val.resize(size); - RETURN_IF_ERROR(ort_api.GetSessionConfigEntry(&session_options, config_key, config_val.data(), &size)); - config_val.resize(size - 1); // remove the terminating '\0' - return nullptr; } -OrtStatus* IsFloatTensor(const OrtApi& ort_api, const OrtValueInfo* value_info, bool& result) { +void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result) { result = false; - const OrtTypeInfo* type_info = nullptr; - RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(value_info, &type_info)); - - ONNXType onnx_type = ONNX_TYPE_UNKNOWN; - RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(type_info, &onnx_type)); + auto type_info = value_info.TypeInfo(); + ONNXType onnx_type = type_info.GetONNXType(); if (onnx_type != ONNX_TYPE_TENSOR) { - return nullptr; + return; } - const OrtTensorTypeAndShapeInfo* type_shape = nullptr; - RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(type_info, &type_shape)); - - ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); + auto type_shape = type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - return nullptr; + return; } - result = true; - return nullptr; } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h index 99ebee9ff64de..e8c086d38a7cb 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h @@ -107,4 +107,4 @@ OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessio /*out*/ std::string& config_val); // Returns true (via output parameter) if the given OrtValueInfo represents a float tensor. 
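// Illustrative usage: bool is_float = false; IsFloatTensor(value_info, is_float);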
-OrtStatus* IsFloatTensor(const OrtApi& ort_api, const OrtValueInfo* value_info, bool& result); +void IsFloatTensor(Ort::ConstValueInfo value_info, bool& result); diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc old mode 100755 new mode 100644 index 334be3e03b483..4b586e24c9bd3 --- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc +++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc @@ -10,6 +10,7 @@ #include "core/common/common.h" #include "core/framework/execution_provider.h" +#include "test/common/cuda_op_test_utils.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" @@ -102,6 +103,7 @@ void RunGatherBlockQuantized(const std::vector& data, const std::vector& output_shape, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, bool touch_on_device_data = false) { + (void)touch_on_device_data; CheckDataAndShape(data, data_shape, "data in RunGatherBlockQuantized"); CheckDataAndShape(indices, indices_shape, "indices in RunGatherBlockQuantized"); CheckDataAndShape(scales, scales_shape, "scales in RunGatherBlockQuantized"); @@ -127,12 +129,15 @@ void RunGatherBlockQuantized(const std::vector& data, test.AddOutput("output", output_shape, output); - if (touch_on_device_data) { - // test would need to see data on device - test.Run(expect_result, "", {kWebGpuExecutionProvider}, nullptr); + bool enable_cuda = HasCudaEnvironment(0); + std::vector> eps; + if (enable_cuda) { + eps.push_back(DefaultCudaExecutionProvider()); } else { - test.Run(expect_result, ""); + eps.push_back(DefaultCpuExecutionProvider()); } + + test.Run(expect_result, "", {}, nullptr, &eps); }; run_test(false); @@ -275,6 +280,7 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis, gather_axis, quantize_axis, block_size, bits, output, output_shape, false); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) { Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); @@ -289,6 +295,7 @@ TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) { Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); } +#endif template void Test_Fail_WithoutZeroPoints(int64_t gather_axis, @@ -317,6 +324,7 @@ void Test_Fail_WithoutZeroPoints(int64_t gather_axis, gather_axis, quantize_axis, block_size, bits, output, output_shape, false); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, UnsupportedUInt8DataType) { // Gather on axis other than 0 is not supported with uint8_t Test_Fail_WithoutZeroPoints(1, 2, 16); @@ -349,6 +357,7 @@ TEST(GatherBlockQuantizedOpTest, NotSupportedBits) { Test_Fail_WithZeroPoints(0, 2, 16, 6); Test_Fail_WithZeroPoints(0, 2, 16, 7); } +#endif template void Test_ShapeMismatch_WithZeroPoints() { @@ -377,11 +386,13 @@ void Test_ShapeMismatch_WithZeroPoints() { gather_axis, quantize_axis, block_size, bits, output, output_shape, false); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, ShapeMismatch) { Test_ShapeMismatch_WithZeroPoints(); Test_ShapeMismatch_WithZeroPoints(); Test_ShapeMismatch_WithZeroPoints(); } +#endif template void Test_InvalidIndices_WithZeroPoints() { @@ -410,11 +421,13 @@ void Test_InvalidIndices_WithZeroPoints() { gather_axis, quantize_axis, block_size, bits, output, output_shape, false, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, InvalidIndices) { Test_InvalidIndices_WithZeroPoints(); Test_InvalidIndices_WithZeroPoints(); 
Test_InvalidIndices_WithZeroPoints(); } +#endif template void Test_GatherAxis0_WithZeroPoints(int bits = 4) { @@ -447,6 +460,7 @@ void Test_GatherAxis0_WithZeroPoints(int bits = 4) { -3, -1, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints) { Test_GatherAxis0_WithZeroPoints(); Test_GatherAxis0_WithZeroPoints(); @@ -457,6 +471,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints) { Test_GatherAxis0_WithZeroPoints(); Test_GatherAxis0_WithZeroPoints(); } +#endif template void Test_GatherAxis0_WithZeroPoints_Uint8(int bits = 4) { @@ -490,6 +505,7 @@ void Test_GatherAxis0_WithZeroPoints_Uint8(int bits = 4) { -3, -1, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints_4Bits) { Test_GatherAxis0_WithZeroPoints_Uint8(); Test_GatherAxis0_WithZeroPoints_Uint8(); @@ -499,6 +515,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0WithZeroPoints_8Bits) { Test_GatherAxis0_WithZeroPoints_Uint8(8); Test_GatherAxis0_WithZeroPoints_Uint8(8); } +#endif template void Test_GatherAxis0_NoZeroPoints(int bits = 4) { @@ -533,6 +550,7 @@ void Test_GatherAxis0_NoZeroPoints(int bits = 4) { -3, -1, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints) { Test_GatherAxis0_NoZeroPoints(); Test_GatherAxis0_NoZeroPoints(); @@ -551,6 +569,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_8Bits) { Test_GatherAxis0_NoZeroPoints(8); Test_GatherAxis0_NoZeroPoints(8); } +#endif template void Test_GatherAxis1_WithZeroPoints() { @@ -585,6 +604,7 @@ void Test_GatherAxis1_WithZeroPoints() { -2, -2, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis1) { Test_GatherAxis1_WithZeroPoints(); Test_GatherAxis1_WithZeroPoints(); @@ -595,6 +615,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis1) { Test_GatherAxis1_WithZeroPoints(); Test_GatherAxis1_WithZeroPoints(); } +#endif template void Test_GatherAxis2_WithZeroPoints() { @@ -629,6 +650,7 @@ void Test_GatherAxis2_WithZeroPoints() { -1, -3, block_size, bits, output, output_shape, true); } +#ifndef USE_CUDA TEST(GatherBlockQuantizedOpTest, GatherAxis2) { Test_GatherAxis2_WithZeroPoints(); Test_GatherAxis2_WithZeroPoints(); @@ -639,6 +661,7 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis2) { Test_GatherAxis2_WithZeroPoints(); Test_GatherAxis2_WithZeroPoints(); } +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 4c3f9e8dd4dbd..7213937d0ef11 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -345,7 +345,7 @@ void TestMatMulNBitsTyped(std::optional abs_error = std::nullopt, #if !defined(USE_OPENVINO) -TEST(MatMulNBits, Float32_Accuracy0) { +TEST(MatMulNBits, Float32_4b_Accuracy0) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -372,7 +372,7 @@ TEST(MatMulNBits, Float32_Accuracy0) { TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float32_Accuracy1) { +TEST(MatMulNBits, Float32_4b_Accuracy1) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -383,7 +383,7 @@ TEST(MatMulNBits, Float32_Accuracy1) { TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float32_Accuracy4) { +TEST(MatMulNBits, Float32_4b_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); 
TestMatMulNBitsTyped(); @@ -415,7 +415,7 @@ TEST(MatMulNBits, Float32_Accuracy4) { #if !defined(USE_DML) // Actual and expected difference is over 0.01 with DmlExecutionProvider. // Skip the tests instead of raising the tolerance to make is pass. -TEST(MatMulNBits, Float16_Accuracy2) { +TEST(MatMulNBits, Float16_4b_Accuracy2) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -442,7 +442,7 @@ TEST(MatMulNBits, Float16_Accuracy2) { TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float16_Accuracy0) { +TEST(MatMulNBits, Float16_4b_Accuracy0) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -453,7 +453,7 @@ TEST(MatMulNBits, Float16_Accuracy0) { TestMatMulNBitsTyped(); } -TEST(MatMulNBits, Float16_Accuracy4) { +TEST(MatMulNBits, Float16_4b_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); @@ -483,7 +483,7 @@ TEST(MatMulNBits, Float16_Accuracy4) { TestMatMulNBitsTyped(); } -TEST(MatMulNBits, LegacyShape) { +TEST(MatMulNBits, LegacyShape_4b) { constexpr bool legacy_shape = true; TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index a7df3b7bbec54..c60abbc278962 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -25,7 +25,9 @@ #include "core/session/ort_env.h" #include "core/util/qmath.h" -#if (defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) || defined(USE_WEBGPU) +#if ((defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64)) && \ + !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || \ + defined(USE_CUDA) || defined(USE_WEBGPU) extern std::unique_ptr ort_env; @@ -275,6 +277,7 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4) { GTEST_SKIP() << "Skipping test on Android x86_64 (emulator)."; #endif TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 42f62981cb52b..0690b8894eb7a 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -9,17 +9,19 @@ namespace onnxruntime { namespace test { +// Note: QMoE CPU implementation now always applies softmax normalization to top-k selected experts +// regardless of the normalize_routing_weights parameter value for mathematical correctness. 
+ #ifndef ENABLE_TRAINING static void RunMoETest(const std::vector& input, const std::vector& router_probs, const std::vector& fc1_experts_weights, const std::vector& fc2_experts_weights, const std::vector& fc3_experts_weights, const std::vector& fc1_experts_bias, const std::vector& fc2_experts_bias, const std::vector& output_data, int num_rows, int num_experts, int hidden_size, int inter_size, std::string activation_type, - int normalize_routing_weights = 0, int top_k = 1, bool use_float16 = false) { + int normalize_routing_weights = 1, int top_k = 1, bool use_float16 = false) { constexpr int min_cuda_arch = 700; - constexpr int max_cuda_arch = 900; - bool enable_cuda = HasCudaEnvironment(min_cuda_arch) && !NeedSkipIfCudaArchGreaterEqualThan(max_cuda_arch); + bool enable_cuda = HasCudaEnvironment(min_cuda_arch); if (enable_cuda) { OpTester tester("MoE", 1, onnxruntime::kMSDomain); tester.AddAttribute("k", static_cast(top_k)); @@ -28,8 +30,8 @@ static void RunMoETest(const std::vector& input, const std::vector std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; - std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc1_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size}; std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; std::vector fc1_experts_bias_dims = {num_experts, inter_size}; std::vector fc2_experts_bias_dims = {num_experts, hidden_size}; @@ -91,44 +93,97 @@ static void RunQMoETest(const std::vector& input, const std::vector& fc3_experts_weights, const std::vector& fc1_scales, const std::vector& fc2_scales, const std::vector& fc3_scales, const std::vector& output_data, int num_rows, int num_experts, int hidden_size, - int inter_size, std::string activation_type, int normalize_routing_weights = 0, int top_k = 1) { + int inter_size, std::string activation_type, int normalize_routing_weights = 1, int top_k = 1, int expert_weight_bits = 4) { constexpr int min_cuda_arch = 700; - constexpr int max_cuda_arch = 900; - bool enable_cuda = HasCudaEnvironment(min_cuda_arch) && !NeedSkipIfCudaArchGreaterEqualThan(max_cuda_arch); + // Test CUDA execution provider + bool enable_cuda = HasCudaEnvironment(min_cuda_arch); if (enable_cuda) { - OpTester tester("QMoE", 1, onnxruntime::kMSDomain); - tester.AddAttribute("k", static_cast(top_k)); - tester.AddAttribute("activation_type", activation_type); - tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights)); + OpTester cuda_tester("QMoE", 1, onnxruntime::kMSDomain); + cuda_tester.AddAttribute("k", static_cast(top_k)); + cuda_tester.AddAttribute("activation_type", activation_type); + cuda_tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights)); std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; - std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; - std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2}; + std::vector fc1_experts_weights_dims = {num_experts, hidden_size, expert_weight_bits == 4 ? inter_size / 2 : inter_size}; + std::vector fc2_experts_weights_dims = {num_experts, inter_size, expert_weight_bits == 4 ? 
hidden_size / 2 : hidden_size}; std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; std::vector fc1_scales_dims = {num_experts, inter_size}; std::vector fc2_scales_dims = {num_experts, hidden_size}; std::vector fc3_scales_dims = fc1_scales_dims; std::vector output_dims = {num_rows, hidden_size}; - tester.AddInput("input", input_dims, ToFloat16(input)); - tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cuda_tester.AddInput("input", input_dims, ToFloat16(input)); + cuda_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); - tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); - tester.AddInput("fc1_scales", fc1_scales_dims, ToFloat16(fc1_scales)); - tester.AddOptionalInputEdge(); // fc1_experts_bias - tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); - tester.AddInput("fc2_scales", fc2_scales_dims, ToFloat16(fc2_scales)); - tester.AddOptionalInputEdge(); // fc2_experts_bias - tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights); - tester.AddInput("fc3_scales", fc3_scales_dims, ToFloat16(fc3_scales)); - tester.AddOutput("output", output_dims, ToFloat16(output_data)); - tester.SetOutputTolerance(0.005f); + cuda_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cuda_tester.AddInput("fc1_scales", fc1_scales_dims, ToFloat16(fc1_scales)); + cuda_tester.AddOptionalInputEdge(); // fc1_experts_bias + cuda_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cuda_tester.AddInput("fc2_scales", fc2_scales_dims, ToFloat16(fc2_scales)); + cuda_tester.AddOptionalInputEdge(); // fc2_experts_bias - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + // Only add FC3 inputs if fc3_experts_weights is not empty + if (!fc3_experts_weights.empty()) { + cuda_tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights); + cuda_tester.AddInput("fc3_scales", fc3_scales_dims, ToFloat16(fc3_scales)); + } else { + cuda_tester.AddOptionalInputEdge(); // fc3_experts_weights + cuda_tester.AddOptionalInputEdge(); // fc3_scales + } + cuda_tester.AddOptionalInputEdge(); // fc3_experts_bias + cuda_tester.AddOutput("output", output_dims, ToFloat16(output_data)); + cuda_tester.SetOutputTolerance(0.005f); + + std::vector> cuda_execution_providers; + cuda_execution_providers.push_back(DefaultCudaExecutionProvider()); + cuda_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cuda_execution_providers); + } + + // Test CPU execution provider (always available) + // Skip CPU test if FC3 weights are provided since CPU doesn't support FC3 + if (fc3_experts_weights.empty()) { + // Ensure CPU EP is available before running CPU tests + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + return; // Skip CPU test if CPU EP is not available + } + + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", static_cast(top_k)); + cpu_tester.AddAttribute("activation_type", activation_type); + cpu_tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights)); + cpu_tester.AddAttribute("expert_weight_bits", static_cast(expert_weight_bits)); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + 
std::vector fc1_experts_weights_dims = {num_experts, hidden_size, expert_weight_bits == 4 ? inter_size / 2 : inter_size}; + std::vector fc2_experts_weights_dims = {num_experts, inter_size, expert_weight_bits == 4 ? hidden_size / 2 : hidden_size}; + std::vector fc1_scales_dims = {num_experts, inter_size}; + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); // Use float, not MLFloat16 + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); // Use float, not MLFloat16 + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + + // CPU doesn't support FC3, so always skip it + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights (skip FC3 for CPU - not implemented) + cpu_tester.AddOptionalInputEdge(); // fc3_scales (use float, not MLFloat16) + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + cpu_tester.AddOutput("output", output_dims, ToFloat16(output_data)); + cpu_tester.SetOutputTolerance(0.01f); // Slightly higher tolerance for CPU vs CUDA differences + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); } } @@ -307,7 +362,7 @@ TEST(MoETest, MoETest_Gelu) { 1.3354061f, 0.5049282f, 0.72775036f, 0.90331376f, 1.2945517f, 0.9123066f, 1.1995136f, 0.7708638f}; RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, {}, fc1_experts_bias, fc2_experts_bias, - output, num_rows, num_experts, hidden_size, inter_size, "gelu"); + output, num_rows, num_experts, hidden_size, inter_size, "gelu", 0); } TEST(MoETest, MoETest_Relu) { @@ -485,7 +540,7 @@ TEST(MoETest, MoETest_Relu) { 4.8571277f, 5.649453f, 5.485141f, 5.306299f, 4.767025f, 6.9010167f, 5.3520975f, 6.711155f}; RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, {}, fc1_experts_bias, fc2_experts_bias, - output, num_rows, num_experts, hidden_size, inter_size, "relu"); + output, num_rows, num_experts, hidden_size, inter_size, "relu", 0); } TEST(MoETest, MoETest_Mixtral) { @@ -1268,8 +1323,373 @@ TEST(MoETest, QMoETest_Mixtral_Int4) { RunQMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, fc3_experts_weights, fc1_scales, fc2_scales, fc3_scales, output, num_rows, num_experts, hidden_size, inter_size, "silu", 1, /*normalize_routing_weights*/ - 2 /*top_k*/); + 2, /*top_k*/ + 4 /*expert_weight_bits*/); } + +// CPU-specific QMoE tests +TEST(MoETest, QMoETest_CPU_Int4_MLAS) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + + int num_rows = 2; + int num_experts = 2; + int hidden_size = 32; + int inter_size = 32; + + const std::vector input = { + -0.5f, 0.2f, 1.1f, -0.3f, 0.8f, -0.1f, 0.4f, -0.7f, 0.9f, -0.2f, 0.6f, 0.1f, -0.4f, 0.3f, -0.8f, 0.7f, + 0.2f, -0.5f, 0.1f, 0.9f, -0.3f, 0.6f, -0.1f, 0.4f, -0.7f, 0.8f, 0.3f, -0.2f, 0.5f, 0.1f, -0.6f, 0.9f, + 0.1f, 
0.7f, -0.4f, 0.2f, 0.8f, -0.3f, 0.5f, -0.1f, 0.6f, 0.4f, -0.7f, 0.3f, 0.9f, -0.2f, 0.1f, 0.8f, + -0.5f, 0.6f, 0.3f, -0.1f, 0.4f, 0.7f, -0.8f, 0.2f, 0.9f, 0.1f, -0.3f, 0.5f, 0.6f, -0.4f, 0.8f, 0.2f}; + + const std::vector router_probs = {0.3f, 0.7f, 0.6f, 0.4f}; + + // Generate simple test weights for 4-bit symmetric quantization with SwiGLU + // Use 0x88 which unpacks to 8,8 -> 0,0 in signed form (8-8=0) for zero weights + // For SwiGLU: FC1 outputs 2*inter_size (gate + linear), FC2 takes inter_size input + std::vector fc1_experts_weights(num_experts * hidden_size * inter_size, 0x88); // 2*inter_size for SwiGLU, packed into /2 + std::vector fc2_experts_weights(num_experts * inter_size * hidden_size / 2, 0x88); // 8,8 values to produce zero output + std::vector fc3_experts_weights; // Empty for CPU (FC3 not supported) + + std::vector fc1_scales(num_experts * inter_size * 2, 0.01f); // 2x for SwiGLU (gate + linear) + std::vector fc2_scales(num_experts * hidden_size, 0.01f); // Smaller scale factor + std::vector fc3_scales; + + // With zero weights (0x88 -> 8,8 -> 0,0 signed), the implementation will produce all zero outputs + std::vector output(num_rows * hidden_size, 0.0f); + + // Test CPU execution provider ONLY (don't use RunQMoETest which tests both CUDA and CPU) + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 2); + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Always use 1 - softmax normalization always applied + cpu_tester.AddAttribute("expert_weight_bits", 4); // Test 4-bit quantization + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, 2 * inter_size, hidden_size / 2}; // SwiGLU: 2*inter_size output, 4-bit packed + std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); // Use float for CPU + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); // Use float for CPU + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights (skip FC3 for CPU) + cpu_tester.AddOptionalInputEdge(); // fc3_scales (use float for CPU) + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + + // When using 0x88 for 4-bit quantized weights with the current implementation, + // all dequantized values should be 0.0f (8-8=0), and thus output should be all zeros + std::vector expected_output(num_rows * hidden_size, 0.0f); + + cpu_tester.AddOutput("output", output_dims, ToFloat16(expected_output)); + cpu_tester.SetOutputTolerance(0.05f); // Small tolerance for numerical differences + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + 
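For readers unfamiliar with the 0x88 trick used in the test above, the following sketch decodes one packed byte under the symmetric 4-bit convention the comments describe: each byte holds two unsigned 4-bit values, and subtracting the zero point of 8 before scaling gives the signed weight, so 0x88 dequantizes to a pair of zeros. Which nibble maps to which matrix element is an implementation detail not modeled here (it does not matter when both nibbles are 8), and the function name is illustrative only. The 8-bit test that follows uses the same idea with a zero point of 128, so a stored byte of 128 also dequantizes to 0.

#include <cstdint>

// Decode one packed 4-bit pair: value = scale * (nibble - 8).
inline void DequantizeInt4Pair(uint8_t packed, float scale, float& w_lo, float& w_hi) {
  const int lo = packed & 0x0F;         // low nibble
  const int hi = (packed >> 4) & 0x0F;  // high nibble
  w_lo = scale * static_cast<float>(lo - 8);  // 0x8 -> 8 - 8 == 0
  w_hi = scale * static_cast<float>(hi - 8);
}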
cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif +} + +TEST(MoETest, QMoETest_CPU_Int8_MLAS) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + + // Test CPU implementation with 8-bit quantization - CPU ONLY + int num_rows = 1; + int num_experts = 2; + int hidden_size = 16; + int inter_size = 16; + + const std::vector input = { + 0.1f, -0.2f, 0.3f, -0.4f, 0.5f, -0.6f, 0.7f, -0.8f, 0.9f, -1.0f, 1.1f, -1.2f, 1.3f, -1.4f, 1.5f, -1.6f}; + + const std::vector router_probs = {0.4f, 0.6f}; + + // For 8-bit symmetric quantization with SwiGLU + // Use quantized weights at zero point for zero outputs (128 = 0 in signed) + std::vector fc1_experts_weights(num_experts * 2 * inter_size * hidden_size, 128); // 2*inter_size for SwiGLU, no packing for 8-bit + std::vector fc2_experts_weights(num_experts * inter_size * hidden_size, 128); // 128 = 0 in signed + std::vector fc3_experts_weights; // Empty for CPU + + std::vector fc1_scales(num_experts * inter_size * 2, 0.1f); // 2x for SwiGLU + std::vector fc2_scales(num_experts * hidden_size, 0.1f); + std::vector fc3_scales; + + // Expected output should be zero since we're using zero weights (128-128=0) + std::vector output(num_rows * hidden_size, 0.0f); + + // Test with different attributes for 8-bit + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 1); + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Always use 1 - softmax normalization always applied + cpu_tester.AddAttribute("expert_weight_bits", 8); // Test 8-bit quantization + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, inter_size * 2, hidden_size}; // SwiGLU: 2*inter_size output, 8-bit no packing + std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); // Use float, not MLFloat16 + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); // Use float, not MLFloat16 + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights (skip FC3 for CPU) + cpu_tester.AddOptionalInputEdge(); // fc3_scales (use float, not MLFloat16) + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + cpu_tester.AddOutput("output", output_dims, ToFloat16(output)); + cpu_tester.SetOutputTolerance(0.05f); // Small tolerance since we expect near-zero output + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + 
cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif +} + +TEST(MoETest, QMoETest_CPU_FC3_Error) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + + // Test that CPU throws error when FC3 gating is provided - CPU ONLY + int num_rows = 1; + int num_experts = 2; + int hidden_size = 8; + int inter_size = 8; + + const std::vector input = {0.1f, -0.2f, 0.3f, -0.4f, 0.5f, -0.6f, 0.7f, -0.8f}; + const std::vector router_probs = {0.5f, 0.5f}; + + // Using new layout: fc1 has fused swiglu doubling (2*inter_size) and 4-bit pack_size=2 so hidden_size packed dimension is hidden_size/2 + const int pack_size = 2; // for 4-bit + const int fc1_inter_size = 2 * inter_size; // swiglu fused + std::vector fc1_experts_weights(num_experts * fc1_inter_size * (hidden_size / pack_size), 0x01); + std::vector fc2_experts_weights(num_experts * hidden_size * (inter_size / pack_size), 0x10); + std::vector fc3_experts_weights(num_experts * inter_size * (hidden_size / pack_size), 0x21); // FC3 provided + + std::vector fc1_scales(num_experts * fc1_inter_size, 0.1f); + std::vector fc2_scales(num_experts * hidden_size, 0.05f); + std::vector fc3_scales(num_experts * inter_size, 0.08f); // FC3 scales provided + + // Test CPU execution provider ONLY (designed to test CPU-specific error handling) + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 1); + cpu_tester.AddAttribute("activation_type", "swiglu"); // CPU only supports swiglu + cpu_tester.AddAttribute("normalize_routing_weights", 1); // Use 1 for consistency, though this test focuses on FC3 error + cpu_tester.AddAttribute("expert_weight_bits", 4); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, fc1_inter_size, hidden_size / pack_size}; + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / pack_size}; + std::vector fc3_experts_weights_dims = {num_experts, inter_size, hidden_size / pack_size}; + std::vector fc1_scales_dims = {num_experts, fc1_inter_size}; + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector fc3_scales_dims = {num_experts, inter_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); // Use float for CPU + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); // Use float for CPU + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + cpu_tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights); // FC3 provided! 
+ cpu_tester.AddInput("fc3_scales", fc3_scales_dims, fc3_scales); // Use float for CPU + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + + std::vector dummy_output(num_rows * hidden_size, 0.0f); + cpu_tester.AddOutput("output", output_dims, ToFloat16(dummy_output)); + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + + // Expect this to fail with FC3 not implemented error + cpu_tester.Run(OpTester::ExpectResult::kExpectFailure, "FC3 gating is not yet implemented", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif +} + +TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + + // Test CPU implementation with 4-bit quantization and SwiGLU activation + int num_rows = 2; + int num_experts = 2; + int hidden_size = 16; + int inter_size = 16; + + const std::vector input = { + 0.1f, -0.2f, 0.3f, -0.4f, 0.5f, -0.6f, 0.7f, -0.8f, 0.9f, -1.0f, 1.1f, -1.2f, 1.3f, -1.4f, 1.5f, -1.6f, + 0.2f, -0.3f, 0.4f, -0.5f, 0.6f, -0.7f, 0.8f, -0.9f, 1.0f, -1.1f, 1.2f, -1.3f, 1.4f, -1.5f, 1.6f, -1.7f}; + + const std::vector router_probs = {0.6f, 0.4f, 0.3f, 0.7f}; + + // For SwiGLU, FC1 weights need to be 2x inter_size (concatenated linear + gate weights) + // 4-bit: each uint8 stores 2 weights, so we need (hidden_size * inter_size * 2) / 2 uint8s per expert + const int fc1_weight_size_per_expert = hidden_size * inter_size * 2 / 2; // For 4-bit SwiGLU + const int fc2_weight_size_per_expert = inter_size * hidden_size / 2; // For 4-bit FC2 + + // Generate test weights for symmetric quantization (zero point is 8 for 4-bit) + std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 0x88); // 8,8 -> 0,0 signed weights + std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 0x88); // 8,8 -> 0,0 signed weights + std::vector fc3_experts_weights; // Empty for SwiGLU (gate weights concatenated with FC1) + + // Scales: for SwiGLU, FC1 has 2*inter_size outputs (linear + gate) + std::vector fc1_scales(num_experts * inter_size * 2, 0.05f); // Small scale for reasonable outputs + std::vector fc2_scales(num_experts * hidden_size, 0.05f); + std::vector fc3_scales; + + // For SwiGLU with zero weights (0x88 -> 8,8 -> 0,0 signed): + // Gate output = 0, Linear output = 0 + // SwiGLU = gate * sigmoid(gate) * (linear + 1) = 0 * sigmoid(0) * (0 + 1) = 0 * 0.5 * 1 = 0 + // So output should be zero + std::vector output(num_rows * hidden_size, 0.0f); + + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 2); + cpu_tester.AddAttribute("activation_type", "swiglu"); // Test SwiGLU activation + cpu_tester.AddAttribute("normalize_routing_weights", 1); + cpu_tester.AddAttribute("expert_weight_bits", 4); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, 2 * inter_size, hidden_size / 2}; + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU (linear + gate) + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + 
cpu_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights (empty for SwiGLU) + cpu_tester.AddOptionalInputEdge(); // fc3_scales + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + cpu_tester.AddOutput("output", output_dims, ToFloat16(output)); + cpu_tester.SetOutputTolerance(0.02f); // Higher tolerance for SwiGLU nonlinearity + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif +} + +TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + // Test CPU implementation with 8-bit quantization and SwiGLU activation + int num_rows = 1; + int num_experts = 2; + int hidden_size = 8; + int inter_size = 8; + + const std::vector input = {0.2f, -0.3f, 0.4f, -0.5f, 0.6f, -0.7f, 0.8f, -0.9f}; + const std::vector router_probs = {0.0f, 0.0f}; + + // For SwiGLU with 8-bit symmetric quantization: FC1 weights are 2x inter_size (concatenated linear + gate weights) + const int fc1_weight_size_per_expert = hidden_size * inter_size * 2; // For 8-bit SwiGLU + const int fc2_weight_size_per_expert = inter_size * hidden_size; // For 8-bit FC2 + + // Generate test weights at zero (for symmetric quantization storage format: uint8 with zero point 128) + // Fill with 128 so dequantized value (val - 128) == 0 => zero output + std::vector fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 128); + std::vector fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 128); + std::vector fc3_experts_weights; // Empty for SwiGLU + + // Scales: for SwiGLU, FC1 has 2*inter_size outputs + std::vector fc1_scales(num_experts * inter_size * 2, 0.1f); + std::vector fc2_scales(num_experts * hidden_size, 0.1f); + std::vector fc3_scales; + + std::vector output(num_rows * hidden_size, 0.0f); + + OpTester cpu_tester("QMoE", 1, onnxruntime::kMSDomain); + cpu_tester.AddAttribute("k", 2); + cpu_tester.AddAttribute("activation_type", "swiglu"); // Test SwiGLU activation + cpu_tester.AddAttribute("normalize_routing_weights", 1); + cpu_tester.AddAttribute("expert_weight_bits", 8); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, inter_size * 2, hidden_size}; // 8-bit SwiGLU: explicit 2x + std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size}; + std::vector fc1_scales_dims = {num_experts, inter_size * 2}; // 2x for SwiGLU + std::vector fc2_scales_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + cpu_tester.AddInput("input", input_dims, ToFloat16(input)); + cpu_tester.AddInput("router_probs", router_probs_dims, 
ToFloat16(router_probs)); + cpu_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cpu_tester.AddInput("fc1_scales", fc1_scales_dims, fc1_scales); + cpu_tester.AddOptionalInputEdge(); // fc1_experts_bias + cpu_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cpu_tester.AddInput("fc2_scales", fc2_scales_dims, fc2_scales); + cpu_tester.AddOptionalInputEdge(); // fc2_experts_bias + cpu_tester.AddOptionalInputEdge(); // fc3_experts_weights + cpu_tester.AddOptionalInputEdge(); // fc3_scales + cpu_tester.AddOptionalInputEdge(); // fc3_experts_bias + cpu_tester.AddOutput("output", output_dims, ToFloat16(output)); + cpu_tester.SetOutputTolerance(0.02f); + + std::vector> cpu_execution_providers; + cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); + cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif +} + #endif } // namespace test diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 513097aaf7ade..7e6d157799d86 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -100,30 +100,20 @@ TEST(EpGraphTest, GetAttributeByName) { // 'auto_pad' and 'group'. The other optional attributes (i.e. dilations, kernel_shape, pads, strides) do not // have statically computable default values, so will not be filled in by Graph::Resolve(). const OrtGraph& ort_graph = test_graph->GetOrtGraph(); - const OrtApi& ort_api = Ort::GetApi(); - - size_t num_nodes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); - ASSERT_EQ(num_nodes, 1); + Ort::ConstGraph graph{&ort_graph}; - std::vector nodes(num_nodes); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + auto nodes = graph.GetNodes(); + ASSERT_EQ(nodes.size(), 1); - const OrtNode* conv_node = nodes[0]; - const char* op_type = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetOperatorType(conv_node, &op_type)); - ASSERT_STREQ(op_type, "Conv"); + auto conv_node = nodes[0]; + auto op_type = conv_node.GetOperatorType(); + ASSERT_EQ(op_type, "Conv"); - size_t num_attrs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(conv_node, &num_attrs)); - ASSERT_EQ(num_attrs, 2); + auto attrs = conv_node.GetAttributes(); + ASSERT_EQ(attrs.size(), 2); - std::vector attrs(num_attrs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(conv_node, attrs.data(), attrs.size())); - for (const OrtOpAttr* attr : attrs) { - const char* attr_name_cstr = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(attr, &attr_name_cstr)); - std::string_view attr_name = attr_name_cstr; + for (const auto& attr : attrs) { + auto attr_name = attr.GetName(); ASSERT_TRUE(attr_name == "auto_pad" || attr_name == "group"); // Only 'auto_pad' and 'group' have been set } @@ -131,9 +121,8 @@ TEST(EpGraphTest, GetAttributeByName) { // Test 1: Get optional attribute that is not set (e.g., dilations). Should not get an error. // { - const OrtOpAttr* attr = nullptr; - Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "dilations", &attr)}; - ASSERT_TRUE(status.IsOK()); + Ort::ConstOpAttr attr; + auto status = conv_node.GetAttributeByName("dilations", attr); ASSERT_EQ(attr, nullptr); } @@ -141,8 +130,8 @@ TEST(EpGraphTest, GetAttributeByName) { // Test 2: Get attribute that does not exist in operator schema. Should get a ORT_NOT_FOUND error. 
// { - const OrtOpAttr* attr = nullptr; - Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "_does_not_exist_", &attr)}; + Ort::ConstOpAttr attr; + Ort::Status status = conv_node.GetAttributeByName("_does_not_exist_", attr); ASSERT_FALSE(status.IsOK()); ASSERT_EQ(status.GetErrorCode(), ORT_NOT_FOUND); ASSERT_EQ(attr, nullptr); @@ -152,23 +141,14 @@ TEST(EpGraphTest, GetAttributeByName) { // Test 3: Get attribute that is known to be set. // { - const OrtOpAttr* attr = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(conv_node, "auto_pad", &attr)); + Ort::ConstOpAttr attr; + ASSERT_ORTSTATUS_OK(conv_node.GetAttributeByName("auto_pad", attr)); ASSERT_NE(attr, nullptr); - OrtOpAttrType attr_type = ORT_OP_ATTR_UNDEFINED; - ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetType(attr, &attr_type)); - ASSERT_EQ(attr_type, ORT_OP_ATTR_STRING); - + OrtOpAttrType type = attr.GetType(); + ASSERT_EQ(ORT_OP_ATTR_STRING, type); std::string auto_pad_val; - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - size_t total_attr_bytes = 0; - Ort::Status status2{ort_api.ReadOpAttr(attr, attr_type, nullptr, 0, &total_attr_bytes)}; - auto_pad_val.resize(total_attr_bytes); - - ASSERT_ORTSTATUS_OK(ort_api.ReadOpAttr(attr, attr_type, auto_pad_val.data(), total_attr_bytes, - &total_attr_bytes)); + ASSERT_ORTSTATUS_OK(attr.GetValue(auto_pad_val)); ASSERT_EQ(auto_pad_val, "NOTSET"); } } @@ -229,14 +209,10 @@ TEST(EpGraphTest, SerializeToProto_InputModelHasExternalIni) { std::string ext_ini_file_path = "conv_qdq_ext_ini_serialized.bin"; std::filesystem::remove(ext_ini_file_path); std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); - auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* /* value_info */, const void* data, size_t bytes, bool& is_external, std::string& location, int64_t& offset) -> Ort::Status { - // OrtValueInfo* could be used to query initializer's name, type, shape, - // node consumers, etc. - (void)value_info; - if (bytes <= 127) { is_external = false; // Keep small initializers stored inside the TensorProto. 
return Ort::Status{nullptr}; @@ -442,13 +418,13 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { } for (size_t i = 0; i < api_num_initializers; ++i) { - const OrtValue* ort_value = nullptr; - const void* ort_value_data = nullptr; - const char* value_name = nullptr; + std::string value_name; + Ort::ConstValue ort_value; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_initializers[i], &value_name)); - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_initializers[i], &ort_value)); - ASSERT_ORTSTATUS_OK(ort_api.GetTensorData(ort_value, &ort_value_data)); + Ort::ConstValueInfo vi(api_initializers[i]); + value_name = vi.GetName(); + ASSERT_ORTSTATUS_OK(vi.GetInitializer(ort_value)); + const void* ort_value_data = ort_value.GetTensorRawData(); auto iter = tensor_proto_map.find(value_name); ASSERT_NE(iter, tensor_proto_map.end()); @@ -723,25 +699,21 @@ static void CheckValueInfoConsumers(const GraphViewer& graph_viewer, const OrtVa static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, const ONNX_NAMESPACE::TensorProto* tensor_proto, const GraphViewer& graph_viewer) { - const OrtApi& ort_api = Ort::GetApi(); - - const char* api_initializer_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_value_info, &api_initializer_name)); - ASSERT_NE(api_initializer_name, nullptr); + Ort::ConstValueInfo vi(api_value_info); + std::string api_initializer_name = vi.GetName(); // Check external initializer info (if any). - OrtExternalInitializerInfo* api_ext_info = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetExternalInitializerInfo(api_value_info, &api_ext_info)); - DeferOrtRelease defer_release_info(&api_ext_info, ort_api.ReleaseExternalInitializerInfo); + Ort::ExternalInitializerInfo api_ext_info{nullptr}; + auto external_status = vi.GetExternalInitializerInfo(api_ext_info); std::unique_ptr ext_info = nullptr; bool has_ext_info = graph_viewer.GetGraph().GetExternalInitializerInfo(api_initializer_name, ext_info, true); if (has_ext_info) { ASSERT_NE(api_ext_info, nullptr); - const ORTCHAR_T* api_ext_file_path = ort_api.ExternalInitializerInfo_GetFilePath(api_ext_info); - int64_t api_ext_file_offset = ort_api.ExternalInitializerInfo_GetFileOffset(api_ext_info); - size_t api_ext_byte_size = ort_api.ExternalInitializerInfo_GetByteSize(api_ext_info); + const std::basic_string api_ext_file_path = api_ext_info.GetFilePath(); + int64_t api_ext_file_offset = api_ext_info.GetFileOffset(); + size_t api_ext_byte_size = api_ext_info.GetByteSize(); ASSERT_EQ(PathString(api_ext_file_path), ext_info->GetRelPath()); ASSERT_EQ(api_ext_file_offset, static_cast(ext_info->GetOffset())); @@ -751,61 +723,49 @@ static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, ASSERT_FALSE(utils::HasExternalDataInFile(*tensor_proto)); } - const OrtValue* api_initializer_value = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_value_info, &api_initializer_value)); + Ort::ConstValue api_initializer_value; + ASSERT_ORTSTATUS_OK(vi.GetInitializer(api_initializer_value)); ASSERT_NE(api_initializer_value, nullptr); // Check initializer type. 
const ONNX_NAMESPACE::TypeProto type_proto = utils::TypeProtoFromTensorProto(*tensor_proto); auto type_info = OrtTypeInfo::FromTypeProto(type_proto); - const OrtTypeInfo* api_type_info = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoTypeInfo(api_value_info, &api_type_info)); + Ort::ConstTypeInfo api_type_info = vi.TypeInfo(); CheckTypeInfo(api_type_info, type_info.get()); } -static void CheckInitializerValueInfosCApi(gsl::span initializer_value_infos, +static void CheckInitializerValueInfosCApi(gsl::span initializer_value_infos, const InitializedTensorSet& initializer_tensor_protos, const GraphViewer& graph_viewer) { - const OrtApi& ort_api = Ort::GetApi(); - for (size_t i = 0; i < initializer_value_infos.size(); i++) { - const OrtValueInfo* api_value_info = initializer_value_infos[i]; - - const char* api_initializer_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_value_info, &api_initializer_name)); - ASSERT_NE(api_initializer_name, nullptr); + Ort::ConstValueInfo vi(initializer_value_infos[i]); + std::string api_initializer_name = vi.GetName(); auto tensor_proto_iter = initializer_tensor_protos.find(api_initializer_name); ASSERT_NE(tensor_proto_iter, initializer_tensor_protos.end()); const ONNX_NAMESPACE::TensorProto* tensor_proto = tensor_proto_iter->second; ASSERT_NE(tensor_proto, nullptr); - - CheckInitializerValueInfo(api_value_info, tensor_proto, graph_viewer); + CheckInitializerValueInfo(vi, tensor_proto, graph_viewer); } } // Checks that the OrtValueInfos obtained from the public C API are "equivalent" to the NodeArgs // in the original graph. -static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span value_infos, +static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span value_infos, gsl::span node_args) { ASSERT_EQ(value_infos.size(), node_args.size()); - const OrtApi& ort_api = Ort::GetApi(); const auto& graph_viewer_inputs = graph_viewer.GetInputsIncludingInitializers(); const auto& graph_viewer_outputs = graph_viewer.GetOutputs(); for (size_t i = 0; i < value_infos.size(); i++) { const NodeArg* node_arg = node_args[i]; - const OrtValueInfo* value_info = value_infos[i]; + Ort::ConstValueInfo vi(value_infos[i]); if (node_arg->Exists()) { const auto& value_name = node_arg->Name(); - - ASSERT_NE(value_info, nullptr); - - const char* api_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(value_info, &api_name)); + std::string api_name = vi.GetName(); ASSERT_EQ(std::string(api_name), value_name); bool is_graph_input = std::any_of(graph_viewer_inputs.begin(), graph_viewer_inputs.end(), @@ -825,64 +785,52 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::spanName()); - bool api_is_outer_scope = false; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsFromOuterScope(value_info, &api_is_outer_scope)); + bool api_is_outer_scope = vi.IsFromOuterScope(); ASSERT_EQ(api_is_outer_scope, is_outer_scope); - bool api_is_const_initializer = false; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsConstantInitializer(value_info, &api_is_const_initializer)); + bool api_is_const_initializer = vi.IsConstantInitializer(); ASSERT_EQ(api_is_const_initializer, is_const_initializer); if (is_const_initializer || api_is_opt_graph_input) { - CheckInitializerValueInfo(value_info, initializer, graph_viewer); + CheckInitializerValueInfo(vi, initializer, graph_viewer); } else { auto node_arg_type_info = OrtTypeInfo::FromTypeProto(*node_arg->TypeAsProto()); - const OrtTypeInfo* api_type_info = nullptr; - 
ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoTypeInfo(value_info, &api_type_info)); + Ort::ConstTypeInfo api_type_info = vi.TypeInfo(); CheckTypeInfo(api_type_info, node_arg_type_info.get()); } - CheckValueInfoProducer(graph_viewer, value_info, node_arg); - CheckValueInfoConsumers(graph_viewer, value_info, node_arg); + CheckValueInfoProducer(graph_viewer, vi, node_arg); + CheckValueInfoConsumers(graph_viewer, vi, node_arg); } else { - ASSERT_EQ(value_info, nullptr); // A missing optional input has a null OrtValueInfo. + ASSERT_EQ(vi, nullptr); // A missing optional input has a null OrtValueInfo. } } } // Checks the Graph_GetSubgraph C API static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) { - const OrtApi& ort_api = Ort::GetApi(); - + Ort::ConstGraph ort_graph{&api_graph}; // Get all the nodes - size_t num_nodes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&api_graph, &num_nodes)); - - std::vector nodes(num_nodes); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); + std::vector nodes = ort_graph.GetNodes(); // Select a half of nodes to create a OrtGraph size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); - std::vector selected_nodes(num_selected_nodes); + std::vector selected_nodes(num_selected_nodes); for (size_t i = 0; i < num_selected_nodes; i++) { selected_nodes[i] = nodes[i]; } - OrtGraph* sub_graph; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); + Ort::Graph sub_graph = ort_graph.GetGraphView(selected_nodes); // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. @@ -892,31 +840,25 @@ static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) { GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - const char* graph_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); + auto graph_name = ort_graph.GetName(); std::string name = graph_name; name += "_half.onnx"; // Dump the graph for debugging // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); // model_proto->SerializeToOstream(&dump); - - ort_api.ReleaseGraph(sub_graph); } // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { - const OrtApi& ort_api = Ort::GetApi(); - + auto ort_cxx_graph = Ort::ConstGraph(&api_graph); // Check the path to model. 
const std::filesystem::path& model_path = graph_viewer.ModelPath(); - const ORTCHAR_T* api_model_path = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetModelPath(&api_graph, &api_model_path)); + const auto api_model_path = ort_cxx_graph.GetModelPath(); ASSERT_EQ(PathString(api_model_path), PathString(model_path.c_str())); // Check the model metadata Ort::AllocatorWithDefaultOptions default_allocator; - auto ort_cxx_graph = Ort::ConstGraph(&api_graph); auto ort_cxx_model_metadat = ort_cxx_graph.GetModelMetadata(); auto& model = graph_viewer.GetGraph().GetModel(); ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetProducerNameAllocated(default_allocator).get(), model.ProducerName().c_str()), 0); @@ -933,42 +875,30 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Check graph inputs. const auto& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers(); - size_t api_num_graph_inputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumInputs(&api_graph, &api_num_graph_inputs)); - ASSERT_EQ(api_num_graph_inputs, graph_input_node_args.size()); + std::vector api_graph_inputs = ort_cxx_graph.GetInputs(); + ASSERT_EQ(api_graph_inputs.size(), graph_input_node_args.size()); - std::vector api_graph_inputs(api_num_graph_inputs); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInputs(&api_graph, api_graph_inputs.data(), api_graph_inputs.size())); CheckValueInfosCApi(graph_viewer, api_graph_inputs, graph_input_node_args); // Check graph outputs. const auto& graph_output_node_args = graph_viewer.GetOutputs(); - size_t api_num_graph_outputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumOutputs(&api_graph, &api_num_graph_outputs)); - ASSERT_EQ(api_num_graph_outputs, graph_output_node_args.size()); + std::vector api_graph_outputs = ort_cxx_graph.GetOutputs(); + ASSERT_EQ(api_graph_outputs.size(), graph_output_node_args.size()); - std::vector api_graph_outputs(api_num_graph_outputs); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetOutputs(&api_graph, api_graph_outputs.data(), api_graph_outputs.size())); CheckValueInfosCApi(graph_viewer, api_graph_outputs, graph_output_node_args); // Check graph initializers const auto& graph_initializers = graph_viewer.GetAllInitializedTensors(); - size_t api_num_initializers = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumInitializers(&api_graph, &api_num_initializers)); - ASSERT_EQ(api_num_initializers, graph_initializers.size()); - - std::vector api_initializers(api_num_initializers); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInitializers(&api_graph, api_initializers.data(), api_initializers.size())); + std::vector api_initializers = ort_cxx_graph.GetInitializers(); + ASSERT_EQ(api_initializers.size(), graph_initializers.size()); CheckInitializerValueInfosCApi(api_initializers, graph_initializers, graph_viewer); // Check if it has a parent node. const Node* parent_node = graph_viewer.ParentNode(); const bool has_parent_node = parent_node != nullptr; - const OrtNode* api_parent_node = nullptr; - - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetParentNode(&api_graph, &api_parent_node)); + Ort::ConstNode api_parent_node = ort_cxx_graph.GetParentNode(); const bool api_has_parent_node = api_parent_node != nullptr; ASSERT_EQ(api_has_parent_node, has_parent_node); @@ -977,79 +907,56 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } // Check all nodes. 
- size_t api_num_nodes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&api_graph, &api_num_nodes)); - ASSERT_EQ(api_num_nodes, graph_viewer.NumberOfNodes()); - - std::vector api_nodes(api_num_nodes); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, api_nodes.data(), api_nodes.size())); + std::vector api_nodes = ort_cxx_graph.GetNodes(); + ASSERT_EQ(api_nodes.size(), graph_viewer.NumberOfNodes()); std::vector node_indices = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); - for (size_t node_idx = 0; node_idx < api_num_nodes; node_idx++) { + for (size_t node_idx = 0; node_idx < api_nodes.size(); node_idx++) { // Check basic node properties. const Node* node = graph_viewer.GetNode(node_indices[node_idx]); - const OrtNode* api_node = api_nodes[node_idx]; + Ort::ConstNode api_node = api_nodes[node_idx]; CheckNode(node, api_node); - int api_since_version = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSinceVersion(api_node, &api_since_version)); + const int api_since_version = api_node.GetSinceVersion(); ASSERT_EQ(api_since_version, node->SinceVersion()); // Check node inputs const auto input_node_args = node->InputDefs(); - size_t api_node_num_inputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumInputs(api_node, &api_node_num_inputs)); - ASSERT_EQ(api_node_num_inputs, input_node_args.size()); - - std::vector api_node_inputs(api_node_num_inputs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetInputs(api_node, api_node_inputs.data(), api_node_inputs.size())); + std::vector api_node_inputs = api_node.GetInputs(); + ASSERT_EQ(api_node_inputs.size(), input_node_args.size()); CheckValueInfosCApi(graph_viewer, api_node_inputs, input_node_args); // Check node outputs const auto output_node_args = node->OutputDefs(); - size_t api_node_num_outputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumOutputs(api_node, &api_node_num_outputs)); - ASSERT_EQ(api_node_num_outputs, output_node_args.size()); - - std::vector api_node_outputs(api_node_num_outputs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetOutputs(api_node, api_node_outputs.data(), api_node_outputs.size())); + std::vector api_node_outputs = api_node.GetOutputs(); + ASSERT_EQ(api_node_outputs.size(), output_node_args.size()); CheckValueInfosCApi(graph_viewer, api_node_outputs, output_node_args); // Check node attributes const auto& node_attrs = node->GetAttributes(); if (!node_attrs.empty()) { - size_t api_num_node_attributes = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(api_node, &api_num_node_attributes)); - - std::vector api_node_attributes(api_num_node_attributes); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(api_node, api_node_attributes.data(), api_node_attributes.size())); + std::vector api_node_attributes = api_node.GetAttributes(); size_t attr_idx = 0; for (const auto& node_attr : node_attrs) { - const OrtOpAttr* api_node_attr = api_node_attributes[attr_idx]; + auto api_node_attr = api_node_attributes[attr_idx]; ASSERT_NE(api_node_attr, nullptr); - api_node_attr = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(api_node, node_attr.first.c_str(), &api_node_attr)); + auto status = api_node.GetAttributeByName(node_attr.first, api_node_attr); + ASSERT_TRUE(status.IsOK()); ASSERT_NE(api_node_attr, nullptr); - const char* api_node_attr_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(api_node_attr, &api_node_attr_name)); - ASSERT_STREQ(api_node_attr_name, node_attr.first.c_str()); - - OrtOpAttrType api_node_attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + auto api_node_attr_name = 
api_node_attr.GetName(); + ASSERT_EQ(api_node_attr_name, node_attr.first); + // XXX: Investigate why not // It's possible that the type is defined in ONNX::AttributeProto_AttributeType but not in OrtOpAttrType, since the two are not in a 1:1 mapping. // In such cases, OpAttr_GetType will return a non-null status, and we simply skip the check here. // TODO: Once we add support for ORT_OP_ATTR_TENSOR, we should be able to just fail if OpAttr_GetType // returns an error. - OrtStatusPtr status = ort_api.OpAttr_GetType(api_node_attr, &api_node_attr_type); - if (status != nullptr) { - Ort::GetApi().ReleaseStatus(status); - continue; - } + OrtOpAttrType api_node_attr_type = api_node_attr.GetType(); ONNX_NAMESPACE::AttributeProto_AttributeType node_attr_type = node_attr.second.type(); switch (node_attr_type) { @@ -1091,7 +998,7 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } default: // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail. - ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit.")); + FAIL() << "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit."; } attr_idx++; } @@ -1105,41 +1012,19 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Check node's implicit inputs to its subgraph nodes. const auto implicit_input_node_args = node->ImplicitInputDefs(); - size_t api_num_node_implicit_inputs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumImplicitInputs(api_node, &api_num_node_implicit_inputs)); - ASSERT_EQ(api_num_node_implicit_inputs, implicit_input_node_args.size()); - - std::vector api_node_implicit_inputs(api_num_node_implicit_inputs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetImplicitInputs(api_node, api_node_implicit_inputs.data(), - api_node_implicit_inputs.size())); - + std::vector api_node_implicit_inputs = api_node.GetImplicitInputs(); + ASSERT_EQ(api_node_implicit_inputs.size(), implicit_input_node_args.size()); CheckValueInfosCApi(graph_viewer, api_node_implicit_inputs, implicit_input_node_args); // Recursively check subgraphs. - size_t api_num_node_subgraphs = 0; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumSubgraphs(api_node, &api_num_node_subgraphs)); - ASSERT_EQ(api_num_node_subgraphs, node_subgraphs_map.size()); - - std::vector api_node_subgraphs(api_num_node_subgraphs); - std::vector api_subgraph_attr_names(api_num_node_subgraphs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size(), - api_subgraph_attr_names.data())); - - for (const auto& [attr_name, subgraph] : node_subgraphs_map) { - // find index of this subgraph. 
- size_t api_subgraph_idx = api_num_node_subgraphs; - for (size_t subgraph_idx = 0; subgraph_idx < api_num_node_subgraphs; subgraph_idx++) { - if (api_subgraph_attr_names[subgraph_idx] == attr_name) { - api_subgraph_idx = subgraph_idx; - break; - } - } - ASSERT_NE(api_subgraph_idx, api_num_node_subgraphs); - - // Recursively check the subgraph - auto subgraph_viewer = std::make_unique(*subgraph); - const OrtGraph* api_subgraph = api_node_subgraphs[api_subgraph_idx]; - CheckGraphCApi(*subgraph_viewer, *api_subgraph); + std::vector api_node_subgraphs = api_node.GetSubgraphs(); + ASSERT_EQ(api_node_subgraphs.size(), node_subgraphs_map.size()); + + for (const auto& name_subgraph : api_node_subgraphs) { + auto hit = node_subgraphs_map.find(name_subgraph.attr_name); + ASSERT_NE(node_subgraphs_map.end(), hit); + auto subgraph_viewer = std::make_unique(*hit->second); + CheckGraphCApi(*subgraph_viewer, *name_subgraph.sub_graph); } } } diff --git a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc index 63652d8835e77..2e2bce97f0cb9 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc @@ -56,19 +56,19 @@ static Ort::Status GetNodeInputEdgeCount(const OrtNode* node, size_t& num_input_ // Sum the number of inputs with a producer node. num_input_edges = 0; - for (const OrtValueInfo* input : inputs) { + for (const OrtValueInfo* ort_input : inputs) { + Ort::ConstValueInfo input{ort_input}; if (input == nullptr) continue; // Skip missing optional input - const OrtNode* producer_node = nullptr; - RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueProducer(input, &producer_node, /*output_index*/ nullptr)); - num_input_edges += static_cast(producer_node != nullptr); + auto producer_info = input.GetProducerNode(); + num_input_edges += static_cast(producer_info.node != nullptr); } return Ort::Status{nullptr}; } // Get all output nodes that consume an output from the given node. -static Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result) { +static Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result) { const OrtApi& ort_api = Ort::GetApi(); size_t num_outputs = 0; @@ -77,23 +77,17 @@ static Ort::Status GetOutputNodes(const OrtNode* node, std::vector outputs(num_outputs); RETURN_IF_API_ERROR(ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); - std::vector output_nodes; + std::vector output_nodes; output_nodes.reserve(num_outputs); // May have more than `num_outputs` // Gather the OrtNode consumers of every output. 
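+  // (Each output value can feed several consumer nodes and one entry is appended per consumer,
+  // which is why `output_nodes` may end up larger than `num_outputs`.)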
- for (const OrtValueInfo* output : outputs) { + for (const OrtValueInfo* ort_output : outputs) { + Ort::ConstValueInfo output{ort_output}; if (output == nullptr) continue; // Skip missing optional output - size_t num_consumers = 0; - RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueNumConsumers(output, &num_consumers)); - - std::vector node_consumers(num_consumers, nullptr); - std::vector input_indices(num_consumers, 0); - RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueConsumers(output, node_consumers.data(), - input_indices.data(), num_consumers)); - - for (const OrtNode* consumer : node_consumers) { - output_nodes.push_back(consumer); + auto consumers_info = output.GetConsumers(); + for (const auto& consumer : consumers_info) { + output_nodes.push_back(consumer.node); } } @@ -108,77 +102,85 @@ static Ort::Status KahnsTopologicalSort(const OrtGraph& graph, const std::function& comp) { const OrtApi& ort_api = Ort::GetApi(); - // Get all nodes - size_t num_nodes = 0; - RETURN_IF_API_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); + try { + // Get all nodes + size_t num_nodes = 0; + RETURN_IF_API_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); - if (num_nodes == 0) { - return Ort::Status{nullptr}; // Nothing to sort. - } + if (num_nodes == 0) { + return Ort::Status{nullptr}; // Nothing to sort. + } - std::vector nodes(num_nodes); - RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); + std::vector nodes(num_nodes); + RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); - // Get the maximum node ID. Not really required if we chose to represent the `in_degree` as a map instead of vector. - size_t max_node_id = 0; - for (const OrtNode* node : nodes) { - size_t node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); - max_node_id = std::max(max_node_id, node_id); - } + // Get the maximum node ID. Not really required if we chose to represent the `in_degree` as a map instead of vector. + size_t max_node_id = 0; + for (const OrtNode* node : nodes) { + size_t node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); + max_node_id = std::max(max_node_id, node_id); + } - std::vector in_degree(max_node_id + 1, 0); - std::vector topo_order; - VisitorPriorityQueue to_visit(comp); + std::vector in_degree(max_node_id + 1, 0); + std::vector topo_order; + VisitorPriorityQueue to_visit(comp); - topo_order.reserve(num_nodes); + topo_order.reserve(num_nodes); - // Initialize in_degree and initial nodes to visit first. - for (const OrtNode* node : nodes) { - size_t input_edge_count = 0; - RETURN_IF_API_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); + // Initialize in_degree and initial nodes to visit first. 
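+    // A note on the bookkeeping below: in_degree[node_id] counts how many of a node's inputs are
+    // produced by another node in this graph, so a count of zero means the node has no unresolved
+    // producers and is ready to emit. Emitting a node decrements the counts of its consumers, and
+    // any consumer that drops to zero is pushed onto the priority queue, which orders ready nodes by `comp`.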
+ for (const OrtNode* node : nodes) { + size_t input_edge_count = 0; + RETURN_IF_API_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); - size_t node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); + size_t node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); - in_degree[node_id] = input_edge_count; - if (input_edge_count == 0) { - to_visit.push(node); + in_degree[node_id] = input_edge_count; + if (input_edge_count == 0) { + to_visit.push(node); + } } - } - while (!to_visit.empty()) { - const OrtNode* current_node = to_visit.top(); - to_visit.pop(); + while (!to_visit.empty()) { + const OrtNode* current_node = to_visit.top(); + to_visit.pop(); - if (!current_node) continue; + if (!current_node) continue; - if (enter) { - enter(current_node); - } + if (enter) { + enter(current_node); + } - std::vector output_nodes; - RETURN_IF_API_ERROR(GetOutputNodes(current_node, output_nodes)); + std::vector output_nodes; + RETURN_IF_API_ERROR(GetOutputNodes(current_node, output_nodes)); - for (const OrtNode* output_node : output_nodes) { - size_t output_node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(output_node, &output_node_id)); + for (const auto& output_node : output_nodes) { + size_t output_node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(output_node, &output_node_id)); - auto& node_in_degree = in_degree[output_node_id]; - node_in_degree--; + auto& node_in_degree = in_degree[output_node_id]; + node_in_degree--; - if (node_in_degree == 0) { - to_visit.push(output_node); + if (node_in_degree == 0) { + to_visit.push(output_node); + } } - } - size_t current_node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(current_node, ¤t_node_id)); - topo_order.push_back(current_node_id); - } + size_t current_node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(current_node, ¤t_node_id)); + topo_order.push_back(current_node_id); + } - if (num_nodes != topo_order.size()) { - return Ort::Status("Some nodes are not included in the topological sort: graph has a cycle", ORT_FAIL); + if (num_nodes != topo_order.size()) { + return Ort::Status("Some nodes are not included in the topological sort: graph has a cycle", ORT_FAIL); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status; + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status; } return Ort::Status{nullptr}; diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index bc01135fbbf1e..6131eff92ac78 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -588,93 +588,6 @@ TEST(InferenceSessionTests, RequestLoadCancellation) { } } -#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS -static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) { - if (f_arg.size() != s_arg.size()) { - std::cout << "Sizes differ: f_arg size: " << f_arg.size() << " s_arg size: " << s_arg.size() << std::endl; - return false; - } - - for (size_t i = 0; i < f_arg.size(); ++i) { - const onnxruntime::NodeArg* x = f_arg[i]; - const onnxruntime::NodeArg* y = s_arg[i]; - if ((x->Shape() == nullptr) ^ (y->Shape() == nullptr)) { - return false; - } - if (!x->Shape()) { - continue; - } - auto x_shape = utils::GetTensorShapeFromTensorShapeProto(*x->Shape()); - auto y_shape = utils::GetTensorShapeFromTensorShapeProto(*y->Shape()); - if (x->Name() == y->Name() && x_shape == y_shape && *x->Type() == *y->Type()) { - continue; - } - return 
false; - } - - return true; -} - -TEST(InferenceSessionTests, ModelMetadata) { - SessionOptions so; - - so.session_logid = "InferenceSessionTests.ModelMetadata"; - InferenceSession session_object{so, GetEnvironment()}; - auto model_uri = ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"); - ASSERT_STATUS_OK(session_object.Load(model_uri)); - - std::shared_ptr p_model; - ASSERT_STATUS_OK(onnxruntime::Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger())); - const onnxruntime::Graph& graph = p_model->MainGraph(); - - // 1. first test the model meta - { - auto retval = session_object.GetModelMetadata(); - ASSERT_TRUE(retval.first.IsOK()); - const ModelMetadata* m = retval.second; - ASSERT_TRUE(m->custom_metadata_map == p_model->MetaData() && - m->description == p_model->DocString() && - m->domain == p_model->Domain() && - m->graph_name == graph.Name() && - m->producer_name == p_model->ProducerName() && - m->version == p_model->ModelVersion()); - } - - { - // 2. test inputs - auto& inputs = graph.GetInputs(); - auto weights = graph.GetAllInitializedTensors(); - - // skip the weights - InputDefList inputs_no_weights; - for (auto& elem : inputs) { - if (weights.find(elem->Name()) != weights.end()) { - continue; - } else { - inputs_no_weights.push_back(elem); - } - } - - auto retval = session_object.GetModelInputs(); - std::cout << "weights size: " << weights.size() - << " inputs.size(): " << inputs.size() - << " from session: " << retval.second->size() << std::endl; - ASSERT_TRUE(retval.first.IsOK()); - ASSERT_TRUE(Compare(inputs_no_weights, *retval.second)); - } - - // 3. test outputs - { - auto retval = session_object.GetModelOutputs(); - ASSERT_TRUE(retval.first.IsOK()); - - auto& outputs = graph.GetOutputs(); - retval = session_object.GetModelOutputs(); - ASSERT_TRUE(retval.first.IsOK()); - ASSERT_TRUE(Compare(outputs, *retval.second)); - } -} -#endif TEST(InferenceSessionTests, CheckRunLogger) { if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { GTEST_SKIP() << "Skipping the test"; diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 6ad21fa9f5cf5..a9d6273ae2f20 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -11,6 +11,7 @@ #include "core/framework/kernel_registry.h" #include "core/framework/op_kernel.h" #include "core/framework/bfc_arena.h" +#include "core/framework/ep_context_options.h" #include "core/framework/session_state.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" @@ -504,7 +505,7 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path, ASSERT_STATUS_OK( partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, sess_options.config_options, default_logger, GraphPartitioner::Mode::kNormal, - EpContextModelGenerationOptions{}, + epctx::ModelGenOptions{}, debug_graph_fn)); verifier_fn(graph); diff --git a/onnxruntime/test/ir/onnx_model_test.cc b/onnxruntime/test/ir/onnx_model_test.cc index 9327d86966981..55fc4f42bec64 100644 --- a/onnxruntime/test/ir/onnx_model_test.cc +++ b/onnxruntime/test/ir/onnx_model_test.cc @@ -26,44 +26,6 @@ class ONNXModelsTest : public ::testing::Test { std::unique_ptr logger_; }; -#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS -// Tests that Resolve() properly clears the state of topological sorted nodes, -// inputs, outputs and valueInfo. -// Assumes the graph passed in has been previously resolved. 
-static void TestResolve(onnxruntime::Graph& graph) { - GraphViewer graph_viewer(graph); - auto& nodes_before = graph_viewer.GetNodesInTopologicalOrder(); - auto& inputs_before = graph.GetInputs(); - auto& outputs_before = graph.GetOutputs(); - auto& value_info_before = graph.GetValueInfo(); - - // Touch the graph to force Resolve() to recompute. - graph.SetGraphResolveNeeded(); - graph.SetGraphProtoSyncNeeded(); - ASSERT_STATUS_OK(graph.Resolve()); - - GraphViewer graph_viewer_2(graph); - auto& nodes_after = graph_viewer_2.GetNodesInTopologicalOrder(); - auto& inputs_after = graph.GetInputs(); - auto& outputs_after = graph.GetOutputs(); - auto& value_info_after = graph.GetValueInfo(); - - // Multiple calls to Resolve() should not alter the sorted nodes, - // inputs, outputs and valueInfo. The internal state should be - // cleared. - EXPECT_EQ(nodes_before, nodes_after); - EXPECT_EQ(inputs_before, inputs_after); - EXPECT_EQ(outputs_before, outputs_after); - EXPECT_EQ(value_info_before, value_info_after); -} - -TEST_F(ONNXModelsTest, squeeze_net) { - // NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"), model, nullptr, *logger_)); - TestResolve(model->MainGraph()); -} -#endif TEST_F(ONNXModelsTest, non_existing_model) { // NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located @@ -96,76 +58,6 @@ class ONNXModelsTest1 : public ::testing::TestWithParam { return oss.str(); } }; -#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS -TEST_F(ONNXModelsTest, bvlc_alexnet_1) { - using ::google::protobuf::io::CodedInputStream; - using ::google::protobuf::io::FileInputStream; - using ::google::protobuf::io::ZeroCopyInputStream; - int fd; - ASSERT_STATUS_OK(Env::Default().FileOpenRd(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), fd)); - ASSERT_TRUE(fd > 0); - std::unique_ptr raw_input(new FileInputStream(fd)); - std::unique_ptr coded_input(new CodedInputStream(raw_input.get())); - // Allows protobuf library versions < 3.2.0 to parse messages greater than 64MB. - coded_input->SetTotalBytesLimit(INT_MAX); - ModelProto model_proto; - bool result = model_proto.ParseFromCodedStream(coded_input.get()); - coded_input.reset(); - raw_input.reset(); - EXPECT_TRUE(result); - ASSERT_STATUS_OK(Env::Default().FileClose(fd)); - - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), model, nullptr, - *logger_)); - - // Check the graph input/output/value_info should have the same size as specified in the model file. 
- EXPECT_EQ(static_cast(model_proto.graph().value_info_size()), model->MainGraph().GetValueInfo().size()); - EXPECT_EQ(static_cast(model_proto.graph().input_size()), model->MainGraph().GetInputs().size() + model->MainGraph().GetAllInitializedTensors().size()); - EXPECT_EQ(static_cast(model_proto.graph().output_size()), model->MainGraph().GetOutputs().size()); - TestResolve(model->MainGraph()); -} - -TEST_P(ONNXModelsTest1, LoadFromFile) { - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(GetModelFileName(), model, nullptr, - *logger_)); - TestResolve(model->MainGraph()); -} - -TEST_P(ONNXModelsTest1, LoadFromProtobuf) { - using ::google::protobuf::io::CodedInputStream; - using ::google::protobuf::io::FileInputStream; - using ::google::protobuf::io::ZeroCopyInputStream; - int fd; - ASSERT_STATUS_OK(Env::Default().FileOpenRd(GetModelFileName(), fd)); - ASSERT_TRUE(fd > 0); - std::unique_ptr raw_input(new FileInputStream(fd)); - std::unique_ptr coded_input(new CodedInputStream(raw_input.get())); - coded_input->SetTotalBytesLimit(INT_MAX); - ModelProto model_proto; - bool result = model_proto.ParseFromCodedStream(coded_input.get()); - coded_input.reset(); - raw_input.reset(); - ASSERT_TRUE(result); - ASSERT_STATUS_OK(Env::Default().FileClose(fd)); - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, nullptr, - *logger_)); - TestResolve(model->MainGraph()); -} - -#ifndef DISABLE_CONTRIB_OPS -INSTANTIATE_TEST_SUITE_P(ONNXModelsTests, - ONNXModelsTest1, - ::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("tiny_yolov2"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512"))); -#else -INSTANTIATE_TEST_SUITE_P(ONNXModelsTests, - ONNXModelsTest1, - ::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512"))); -#endif - -#endif // test a model that conforms to ONNX IR v4 where there are initializers that are not graph inputs. // a NodeArg should be created for all initializers in this case. 
diff --git a/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp index fad804f3ce305..3ed283d54f41d 100644 --- a/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp @@ -31,10 +31,156 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { std::uniform_real_distribution distrib_f32_; MatrixGuardBuffer inputB_, inputZp_, refB_, packedBuffer_; MatrixGuardBuffer inputScale_, refScale_; - MatrixGuardBuffer inputBlkSum_, refBlkSum_; + MatrixGuardBuffer inputBlkSum_, refBlkSum_, refBlkUnsignedQuantAZeroPointCorrection_; +#ifdef MLAS_TARGET_ARM64 template - void PrepackB(const uint8_t* src, uint8_t* dst) { + void PrepackB(const uint8_t* src, uint8_t* dst, float* refBlkUnsignedQuantAZeroPointCorrection) { + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n / 8 * 8 * ldb + k / 4 * 4 * 8 + (n % 8) * 4 + k % 4; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n / 4 * 4 * ldb + k / 4 * 4 * 4 + (n % 4) * 4 + k % 4; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n * ldb + k; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + } + + template + void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum, float* refBlkUnsignedQuantAZeroPointCorrection) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n / 8 * 8 * BlkCount + k * 8 + n % 8; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n / 4 * 4 * BlkCount + k * 4 + n % 4; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? 
static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n * BlkCount + k; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + } + + template + void CheckB(const uint8_t* packedB, const uint8_t* refB) { + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n / 8 * 8 * ldb + k / 4 * 4 * 8 + (n % 8) * 4 + k % 4; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n / 4 * 4 * ldb + k / 4 * 4 * 4 + (n % 4) * 4 + k % 4; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n * ldb + k; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + } + + template + void CheckScale(const float* packedScale, const float* refScale) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n / 8 * 8 * BlkCount + k * 8 + n % 8; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n / 4 * 4 * BlkCount + k * 4 + n % 4; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n * BlkCount + k; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + } +#else // not MLAS_TARGET_ARM64 + template + void PrepackB(const uint8_t* src, uint8_t* dst, float* blkUnsignedQuantAZeroPointCorrection) { + MLAS_UNREFERENCED_PARAMETER(blkUnsignedQuantAZeroPointCorrection); + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); size_t n = 0; for (; n + 4 <= N; n += 4) { @@ -65,7 +211,9 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { } template - void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum) { + void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum, float* blkUnsignedQuantAZeroPointCorrection) { + MLAS_UNREFERENCED_PARAMETER(blkUnsignedQuantAZeroPointCorrection); + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; constexpr size_t BlkPerSubBlk = SubBlkLen > BlkLen ? 
SubBlkLen / BlkLen : 1; @@ -174,10 +322,15 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { } } } +#endif // MLAS_TARGET_ARM64 template void CheckBlkSum(const float* packedBlkSum, const float* refBlkSum) { - size_t BlkCount = (K + BlkLen - 1) / BlkLen; + if (refBlkSum == nullptr) { + return; + } + + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; for (size_t n = 0; n < N; ++n) { for (size_t k = 0; k < BlkCount; ++k) { @@ -198,6 +351,7 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { constexpr size_t PackBCount = N * Ldb; constexpr size_t ScaleCount = BlkCount * N; const size_t BufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, Bits, BlkLen, hasZp, SQNBIT_CompInt8); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; const auto* inputB = inputB_.GetFilledBuffer(PackBCount, [this](uint8_t* p, size_t t) { for (size_t i = 0; i < t; i++) { @@ -222,25 +376,36 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { auto* refB = refB_.GetBuffer(PackBCount, true); auto* refScale = refScale_.GetBuffer(ScaleCount, true); auto* refBlkSum = refBlkSum_.GetBuffer(((N + 15) & (~15)) * BlkCount, true); + auto* refBlkUnsignedQuantAZeroPointCorrection = isQuantAUnsigned ? refBlkUnsignedQuantAZeroPointCorrection_.GetBuffer(((N + 15) & (~15)) * BlkCount, true) : nullptr; + + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen, isQuantAUnsigned); + + // Models the packing calls from MatmulNBits operator - we will have 3 separate calls + // for 3 different inputs in the Prepack() function + // The first call prepacks the quantized weights (and accumulates necessary metadata for BlkUnsignedQuantAZeroPointCorrection). + // The second call prepacks the scales. + // The third call prepacks the zero points. + // The inputScale and zero points will be ignored while prepacking the weights (if they are provided). 
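+      // (Argument order in the calls below, as a reading aid and subject to the MLAS header being the
+      // authority: N, K, bit width, block length and compute type come first, followed by the quantized
+      // B data, the shared packed output buffer, the B scales, the has-zero-point flag, the zero points,
+      // and a trailing pointer that appears to be an optional thread pool. A nullptr in the data, scale,
+      // or zero-point slot simply means that input is not handled by that particular call.)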
MlasQNBitGemmPackQuantBData( N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, - inputScale, hasZp, nullptr, nullptr); + inputScale, hasZp, inputZp, nullptr); + MlasQNBitGemmPackQuantBData( N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, inputScale, hasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, nullptr, hasZp, inputZp, nullptr); - PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen); + PrepackB(inputB, refB, refBlkUnsignedQuantAZeroPointCorrection); + PrepackBlkSumAndScale(inputScale, inputZp, refScale, refBlkSum, refBlkUnsignedQuantAZeroPointCorrection); - PrepackB(inputB, refB); - PrepackBlkSumAndScale(inputScale, inputZp, refScale, refBlkSum); - - CheckB(refB, reinterpret_cast(packedQuantB.PackedQuantBData)); - CheckScale(refScale, packedQuantB.PackedQuantBScale); - CheckBlkSum(refBlkSum, packedQuantB.QuantBBlkSum); + CheckB(reinterpret_cast(packedQuantB.PackedQuantBData), refB); + CheckScale(packedQuantB.PackedQuantBScale, refScale); + CheckBlkSum(packedQuantB.QuantBBlkSum, refBlkSum); + CheckBlkSum(packedQuantB.BlkUnsignedQuantAZeroPointCorrection, refBlkUnsignedQuantAZeroPointCorrection); } public: @@ -298,31 +463,203 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { Execute<1, 1, 256, 64>(); Execute<16, 4, 16, 64>(); - Execute<32, 4, 16, 64>(); - Execute<64, 4, 16, 64>(); - Execute<128, 4, 16, 64>(); + Execute<32, 8, 16, 64>(); + Execute<64, 12, 32, 64>(); + Execute<128, 16, 64, 64>(); - Execute<15, 5, 16, 64>(); - Execute<15, 5, 32, 64>(); + Execute<15, 3, 16, 64>(); + Execute<15, 4, 32, 64>(); Execute<15, 5, 64, 64>(); - Execute<15, 5, 128, 64>(); - Execute<15, 5, 256, 64>(); - + Execute<15, 6, 128, 64>(); + Execute<15, 7, 256, 64>(); + Execute<15, 8, 16, 64>(); + Execute<15, 9, 16, 64>(); + + Execute<17, 3, 16, 64>(); + Execute<17, 4, 32, 64>(); + Execute<17, 5, 64, 64>(); + Execute<17, 6, 128, 64>(); + Execute<17, 7, 256, 64>(); Execute<17, 8, 16, 64>(); - Execute<17, 8, 32, 64>(); - Execute<17, 8, 64, 64>(); - Execute<17, 8, 128, 64>(); - Execute<17, 8, 256, 64>(); + Execute<17, 9, 16, 64>(); Execute<159, 16, 16, 64>(); Execute<160, 17, 32, 64>(); Execute<161, 15, 64, 64>(); Execute<160, 17, 128, 64>(); Execute<159, 16, 256, 64>(); + Execute<3072, 128, 16, 64>(); } } }; +class MlasSQ8BitQuantAKernelTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution distrib_u8_; + std::uniform_real_distribution distrib_f32_; + MatrixGuardBuffer workspace_, refQuantA_; + MatrixGuardBuffer inputA_, refScale_, refBlkSum_; + + template + void QuantA(const float* inputA, uint8_t* quantA, float* scalePtr, float* blkSum, bool quantAUnsigned) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t input_lda = K; + + constexpr size_t Bits = 8; + constexpr size_t output_lda = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < BlkCount; ++j) { + float vAbsMax = 0.f; + for (size_t k = 0; k < std::min(BlkLen, K - j * BlkLen); ++k) { + size_t input_idx = i * input_lda + j * BlkLen + k; + vAbsMax = std::max(vAbsMax, fabsf(inputA[input_idx])); + } + + float scale = vAbsMax / 127.f; + float invScale = vAbsMax == 0.f ? 
0.f : 127.f / vAbsMax; + scalePtr[i * BlkCount + j] = scale; + + float vSum = 0.f; + for (size_t k = 0; k < BlkLen; ++k) { + size_t input_idx = i * input_lda + j * BlkLen + k; + size_t output_idx = i * output_lda + j * BlkLen + k; + if (k < std::min(BlkLen, K - j * BlkLen)) { + const auto input_val = inputA[input_idx]; + // Round to nearest, ties away from zero + // float v = std::clamp(std::roundf(input_val * invScale), -128.f, 127.f); + + // Round to nearest, ties to even + float v = std::clamp(std::nearbyint(input_val * invScale), -128.f, 127.f); + + if (quantAUnsigned) { + quantA[output_idx] = static_cast(v + 128.f); + vSum += v + 128.f; + } else { + reinterpret_cast(quantA)[output_idx] = static_cast(v); + vSum += v; + } + } else { + quantA[output_idx] = 0; + } + } + blkSum[i * BlkCount + j] = vSum * scale; + } + } + } + + template + void CheckQuantA(const uint8_t* quantA, const uint8_t* refQuantA) { + constexpr size_t lda = (K + BlkLen - 1) & (~(BlkLen - 1)); + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < lda; ++j) { + size_t idx = i * lda + j; + ASSERT_EQ(quantA[idx], refQuantA[idx]) << " at i=" << i << " j=" << j; + } + } + } + + template + void CheckScale(const float* scale, const float* refScale) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < BlkCount; ++j) { + size_t idx = i * BlkCount + j; + ASSERT_EQ(scale[idx], refScale[idx]) << " at i=" << i << " j=" << j; + } + } + } + + template + void TestQuantA() { + if (!MlasIsQNBitGemmAvailable(8, BlkLen, SQNBIT_CompInt8)) return; + + const auto* dispatch = GetMlasPlatform().QNBitGemmDispatch; + constexpr size_t Bits = 8; + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t Lda = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t PackACount = M * Lda; + constexpr size_t ScaleCount = M * BlkCount; + const size_t BufferSize = MlasQNBitGemmBatchWorkspaceSize(M, 1, K, 1, Bits, BlkLen, true, SQNBIT_CompInt8); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; + + const auto* inputA = inputA_.GetFilledBuffer(M * K, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + auto* workspace = workspace_.GetBuffer(BufferSize, true); + auto* refQuantA = refQuantA_.GetBuffer(PackACount, true); + auto* refScale = refScale_.GetBuffer(ScaleCount, true); + auto* refBlkSum = refBlkSum_.GetBuffer(ScaleCount, true); + + const size_t Alignment = dispatch->QNBitGemmPerGemmWorkspaceAlignment(BlkLen, SQNBIT_CompInt8); + const uintptr_t WorkspaceAddress = reinterpret_cast(workspace); + auto* quantAPtr = reinterpret_cast((WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))); + auto* scaleAPtr = reinterpret_cast(quantAPtr + PackACount); + auto* blkSumAPtr = scaleAPtr + ScaleCount; + + for (size_t i = 0; i < M; ++i) { + dispatch->QuantizeARowComputeBlkSum_CompInt8(BlkLen, inputA + i * K, K, quantAPtr + i * Lda, scaleAPtr + i * BlkCount, blkSumAPtr + i * BlkCount); + } + + QuantA(inputA, refQuantA, refScale, refBlkSum, isQuantAUnsigned); + CheckQuantA(reinterpret_cast(quantAPtr), refQuantA); + CheckScale(scaleAPtr, refScale); + CheckScale(blkSumAPtr, refBlkSum); + } + + public: + MlasSQ8BitQuantAKernelTest() + : seed_(19287), gen_(seed_), distrib_u8_(0, 255), distrib_f32_(-10.f, 10.f) { + } + + static const char* GetTestSuiteName() { + return "SQ8BitQuantA"; + } + + void ExecuteShort(void) override { + TestQuantA<1, 16, 16>(); + 
TestQuantA<1, 1, 32>(); + TestQuantA<1, 1, 64>(); + TestQuantA<1, 1, 128>(); + TestQuantA<1, 1, 256>(); + + TestQuantA<4, 16, 16>(); + TestQuantA<8, 32, 16>(); + TestQuantA<12, 64, 32>(); + TestQuantA<16, 128, 64>(); + + TestQuantA<3, 15, 16>(); + TestQuantA<4, 15, 32>(); + TestQuantA<5, 15, 64>(); + TestQuantA<6, 15, 128>(); + TestQuantA<7, 15, 256>(); + TestQuantA<8, 15, 16>(); + TestQuantA<9, 15, 16>(); + + TestQuantA<3, 17, 16>(); + TestQuantA<4, 17, 32>(); + TestQuantA<5, 17, 64>(); + TestQuantA<6, 17, 128>(); + TestQuantA<7, 17, 256>(); + TestQuantA<8, 17, 16>(); + TestQuantA<9, 17, 16>(); + + TestQuantA<2, 159, 16>(); + TestQuantA<3, 159, 16>(); + TestQuantA<17, 160, 32>(); + TestQuantA<15, 161, 64>(); + TestQuantA<17, 160, 128>(); + TestQuantA<16, 159, 256>(); + + TestQuantA<1, 3072, 16>(); + } +}; + class MlasSQ8BitGemmKernelTest : public MlasTestBase { private: unsigned int seed_; @@ -383,9 +720,6 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { } }); - int q_rows, q_cols; - MlasBlockwiseQuantizedShape((int)BlkLen, true, (int)K, (int)N, q_rows, q_cols); - size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; MlasBlockwiseQuantizedBufferSizes<8>((int)(BlkLen), true, (int)K, (int)N, q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); @@ -420,24 +754,33 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { size_t bufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, 8, BlkLen, HasZp, SQNBIT_CompInt8); auto* packedBuffer = packedBuffer_.GetBuffer(bufferSize, true); + // Models the packing calls from MatmulNBits operator - we will have 3 separate calls + // for 3 different inputs in the Prepack() function + // The first call prepacks the quantized weights (and accumulates necessary metadata for BlkUnsignedQuantAZeroPointCorrection). + // The second call prepacks the scales. + // The third call prepacks the zero points. + + // The inputScale and zero points will be ignored while prepacking the weights (if they are provided). MlasQNBitGemmPackQuantBData( N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, - inputScale, HasZp, nullptr, nullptr); + inputScale, HasZp, inputZp, nullptr); + MlasQNBitGemmPackQuantBData( N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, inputScale, HasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, nullptr, HasZp, inputZp, nullptr); - PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen); + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen, true); auto* C = C_.GetBuffer(M * ldc, true); auto* ref = ref_.GetBuffer(M * ldc, true); - auto* bias = HasBias ? bias_.GetFilledBuffer(N, [this](float* p, size_t t) { + auto* bias = HasBias ? 
bias_.GetFilledBuffer(N, [](float* p, size_t t) { for (size_t i = 0; i < t; i++) { - p[i] = this->distrib_f32_(this->gen_); + p[i] = (float)(5 + i); } }) : nullptr; @@ -473,15 +816,16 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { template void Execute(void) { - TestSQ8BitGemmKernel(); TestSQ8BitGemmKernel(); - TestSQ8BitGemmKernel(); TestSQ8BitGemmKernel(); + + TestSQ8BitGemmKernel(); + TestSQ8BitGemmKernel(); } void ExecuteShort(void) override { - Execute<1, 1, 1, 16>(); - Execute<7, 128, 4, 16>(); + Execute<1, 16, 1, 16>(); + Execute<7, 2, 4, 16>(); Execute<8, 497, 5, 16>(); Execute<1, 3072, 128, 16>(); Execute<2, 3072, 128, 16>(); @@ -515,6 +859,7 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe size_t count = 0; if (is_short_execute) { count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); count += MlasDirectShortExecuteTests::RegisterShortExecute(); } return count; diff --git a/onnxruntime/test/optimizer/resnet50_fusion_test.cc b/onnxruntime/test/optimizer/resnet50_fusion_test.cc index 5cb0206156a84..7e6677c8e1ddf 100644 --- a/onnxruntime/test/optimizer/resnet50_fusion_test.cc +++ b/onnxruntime/test/optimizer/resnet50_fusion_test.cc @@ -16,7 +16,6 @@ namespace onnxruntime { namespace test { -// #define ORT_RUN_EXTERNAL_ONNX_TESTS // #define MLAS_F16VEC_INTRINSICS_SUPPORTED #define MODEL_FOLDER ORT_TSTR("testdata/transform/") @@ -28,54 +27,7 @@ class ResNet50FusionTests : public ::testing::Test { } std::unique_ptr logger; }; -#if defined(ORT_RUN_EXTERNAL_ONNX_TESTS) -TEST_F(ResNet50FusionTests, FuseConvIntegrationTest) { - std::basic_string fp32_model_path = ORT_TSTR("../models/opset10/Resnet50_Fusion_Testing/resnet50.onnx"); - std::shared_ptr fp32_model; - std::basic_string fp16_model_path = ORT_TSTR("../models/opset10/Resnet50_Fusion_Testing_fp16/resnet50.fp16.onnx"); - std::shared_ptr fp16_model; - if (Model::Load(fp32_model_path, fp32_model, nullptr, *logger) != Status::OK()) { - GTEST_SKIP() << "Failed to load model: " << fp32_model_path; - } - if (Model::Load(fp16_model_path, fp16_model, nullptr, *logger) != Status::OK()) { - GTEST_SKIP() << "Failed to load model: " << fp16_model_path; - } - // ASSERT_STATUS_OK(Model::Load(fp32_model_path, fp32_model, nullptr, *logger)); - Graph& fp32_graph = fp32_model->MainGraph(); - for (auto& node : fp32_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCpuExecutionProvider); - } - Graph& fp16_graph = fp16_model->MainGraph(); - for (auto& node : fp16_model->MainGraph().Nodes()) { - node.SetExecutionProviderType(kCpuExecutionProvider); - } - // std::cout << "-------Op Counts Before Fusion---------" << std::endl; - std::map fp32_op_count = CountOpsInGraph(fp32_graph); - std::map fp16_op_count = CountOpsInGraph(fp16_graph); - for (auto& op : fp32_op_count) { - // std::cout << op.first << " " << op.second << std::endl; - ASSERT_EQ(op.second, fp16_op_count[op.first]); - } - onnxruntime::GraphTransformerManager graph_transformation_mgr_32{5}; - ASSERT_STATUS_OK(graph_transformation_mgr_32.Register(std::make_unique(), TransformerLevel::Level3)); - ASSERT_STATUS_OK(graph_transformation_mgr_32.Register(std::make_unique(), TransformerLevel::Level3)); - ASSERT_STATUS_OK(graph_transformation_mgr_32.ApplyTransformers(fp32_graph, TransformerLevel::Level3, *logger)); - ASSERT_STATUS_OK(Model::Save(*fp32_model, ORT_TSTR("resnet50_fused.onnx"))); - - onnxruntime::GraphTransformerManager graph_transformation_mgr_16{5}; - 
ASSERT_STATUS_OK(graph_transformation_mgr_16.Register(std::make_unique(), TransformerLevel::Level3)); - ASSERT_STATUS_OK(graph_transformation_mgr_16.Register(std::make_unique(), TransformerLevel::Level3)); - ASSERT_STATUS_OK(graph_transformation_mgr_16.ApplyTransformers(fp16_graph, TransformerLevel::Level3, *logger)); - ASSERT_STATUS_OK(Model::Save(*fp16_model, ORT_TSTR("resnet50_fp16_fused.onnx"))); - // std::cout << "-------Op Counts After Fusion---------" << std::endl; - fp32_op_count = CountOpsInGraph(fp32_graph); - fp16_op_count = CountOpsInGraph(fp16_graph); - // for (auto& op : fp32_op_count) { - // ASSERT_EQ(op.second, fp16_op_count[op.first]); - // } -} -#endif // defined(ORT_RUN_EXTERNAL_ONNX_TESTS) TEST_F(ResNet50FusionTests, FuseConvAddReluUnitTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu_fp16.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index a17982ecb5eab..cf49601e6c671 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "core/session/onnxruntime_c_api.h" @@ -64,6 +65,157 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace test { + +// Models verified to exist in both VM and Zoo with identical checksums +// These 20 unique models have been confirmed as public (33 instances across opsets) +static const std::unordered_set VERIFIED_PUBLIC_MODELS = { + "AlexNet", + "BERT-Squad", + "CaffeNet", + "DenseNet-121", + "Emotion FERPlus", + "Faster R-CNN R-50-FPN", + "GoogleNet", + "Inception-1", + "Inception-2", + "Mask R-CNN R-50-FPN", + "MNIST", + "MobileNet v2-7", + "R-CNN ILSVRC13", + "ShuffleNet-v1", + "SqueezeNet 1.0", + "SqueezeNet 1.1", + "SSD", + "VGG 19-caffe2", + "YOLOv3", + "ZFNet-512"}; + +// All ONNX Model Zoo models (always safe as they're public) +// Total: 158 models from https://github.com/onnx/models +static const std::unordered_set ONNX_ZOO_MODELS = { + // Verified models (20 unique) + "AlexNet", + "BERT-Squad", + "CaffeNet", + "DenseNet-121", + "Emotion FERPlus", + "Faster R-CNN R-50-FPN", + "GoogleNet", + "Inception-1", + "Inception-2", + "Mask R-CNN R-50-FPN", + "MNIST", + "MobileNet v2-7", + "R-CNN ILSVRC13", + "ShuffleNet-v1", + "SqueezeNet 1.0", + "SqueezeNet 1.1", + "SSD", + "VGG 19-caffe2", + "YOLOv3", + "ZFNet-512", + // Additional Zoo-only models (138) + "AlexNet-int8", + "BERT-Squad-int8", + "BiDAF", + "BiDAF-int8", + "CaffeNet-int8", + "CaffeNet-qdq", + "Candy", + "DenseNet-121-12", + "DenseNet-121-12-int8", + "EfficientNet-Lite4", + "EfficientNet-Lite4-int8", + "EfficientNet-Lite4-qdq", + "Emotion FERPlus int8", + "FCN ResNet-50", + "FCN ResNet-50-int8", + "FCN ResNet-50-qdq", + "FCN ResNet-101", + "Faster R-CNN R-50-FPN-fp32", + "Faster R-CNN R-50-FPN-int8", + "Faster R-CNN R-50-FPN-qdq", + "GoogleNet-int8", + "GoogleNet-qdq", + "GPT-2", + "GPT-2-LM-HEAD", + "Inception-1-int8", + "Inception-1-qdq", + "LResNet100E-IR", + "LResNet100E-IR-int8", + "Mask R-CNN R-50-FPN-fp32", + "Mask R-CNN R-50-FPN-int8", + "Mask R-CNN R-50-FPN-qdq", + "MNIST-12", + "MNIST-12-int8", + "MobileNet v2-1.0", + "MobileNet v2-1.0-fp32", + "MobileNet v2-1.0-int8", + "MobileNet v2-1.0-qdq", + "Mosaic", + "Pointilism", + "Rain Princess", + "ResNet18", + "ResNet18-v2", + "ResNet34", + "ResNet34-v2", + "ResNet50", + "ResNet50-caffe2", + "ResNet50-fp32", + "ResNet50-int8", + "ResNet50-qdq", + 
"ResNet50-v2", + "ResNet101", + "ResNet101-v2", + "ResNet101_DUC_HDC", + "ResNet101_DUC_HDC-12", + "ResNet101_DUC_HDC-12-int8", + "ResNet152", + "ResNet152-v2", + "ResNet-preproc", + "RetinaNet (ResNet101 backbone)", + "RoBERTa-BASE", + "RoBERTa-SequenceClassification", + "ShuffleNet-v2", + "ShuffleNet-v2-fp32", + "ShuffleNet-v2-int8", + "ShuffleNet-v2-qdq", + "SqueezeNet 1.0-int8", + "SqueezeNet 1.0-qdq", + "SSD-int8", + "SSD-qdq", + "SSD-MobilenetV1", + "SSD-MobilenetV1-12", + "SSD-MobilenetV1-12-int8", + "SSD-MobilenetV1-12-qdq", + "Super_Resolution", + "T5-decoder-with-lm-head", + "T5-encoder", + "Tiny YOLOv2", + "Tiny YOLOv3", + "Udnie", + "VGG 16", + "VGG 16-bn", + "VGG 16-fp32", + "VGG 16-int8", + "VGG 16-qdq", + "VGG 19", + "VGG 19-bn", + "version-RFB-320", + "version-RFB-320-int8", + "version-RFB-640", + "YOLOv2", + "YOLOv3-12", + "YOLOv3-12-int8", + "YOLOv4", + "ZFNet-512-int8", + "ZFNet-512-qdq"}; + +// Helper function to check if a model is allowed +inline bool IsModelAllowed(const std::string& model_name) { + return ONNX_ZOO_MODELS.count(model_name) > 0; +} + // parameter is provider_name + "_" + model_path class ModelTest : public testing::TestWithParam> {}; @@ -656,15 +808,12 @@ ::std::vector<::std::basic_string> GetParameterStrings() { // Same as the above, except this one is for large models #if defined(NDEBUG) || defined(RUN_MODELTEST_IN_DEBUG_MODE) #ifdef _WIN32 - ORT_STRING_VIEW model_test_root_path = ORT_TSTR("..\\models"); - // thus, only the root path should be mounted. ORT_STRING_VIEW model_zoo_path = ORT_TSTR("..\\models\\zoo"); #else - ORT_STRING_VIEW model_test_root_path = ORT_TSTR("../models"); ORT_STRING_VIEW model_zoo_path = ORT_TSTR("../models/zoo"); #endif for (auto p : kvp.second) { - paths.push_back(ConcatPathComponent(model_test_root_path, p)); + // ONLY use Model Zoo path - guaranteed public models with public test data paths.push_back(ConcatPathComponent(model_zoo_path, p)); } #endif @@ -750,6 +899,13 @@ ::std::vector<::std::basic_string> GetParameterStrings() { std::basic_string test_case_name = path.parent_path().filename().native(); if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0) test_case_name = test_case_name.substr(5); + + // Check if model is in the public whitelist + std::string model_name_str = ToUTF8String(test_case_name); + if (!IsModelAllowed(model_name_str)) { + continue; // Skip models not in whitelist + } + if (all_disabled_tests.find(test_case_name) != all_disabled_tests.end()) continue; diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_options_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_options_test.cc new file mode 100644 index 0000000000000..d415548876153 --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_options_test.cc @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. 
+#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "test/providers/provider_test_utils.h" +#include "test/framework/test_utils.h" + +#include "test/util/include/scoped_env_vars.h" +#include "test/common/trt_op_test_utils.h" +#include "test/common/random_generator.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" + +#include +#include + +using namespace std; +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::logging; +extern std::unique_ptr ort_env; +namespace onnxruntime { + +namespace test { +size_t countFilesInDirectory(const std::string& dir_path) { + return std::distance(std::filesystem::directory_iterator(dir_path), std::filesystem::directory_iterator{}); +} + +TEST(NvExecutionProviderTest, RuntimeCaching) { + PathString model_name = ORT_TSTR("nv_execution_provider_runtime_caching.onnx"); + PathString model_name_ctx = ORT_TSTR("nv_execution_provider_runtime_caching_ctx.onnx"); + auto model_name_ctx_str = PathToUTF8(model_name_ctx); + clearFileIfExists(model_name_ctx); + std::string graph_name = "test"; + std::vector dims = {1, 3, 2}; + std::string runtime_cache_name = "./runtime_cache/"; + if (std::filesystem::exists(runtime_cache_name)) { + std::filesystem::remove_all(runtime_cache_name); + } + CreateBaseModel(model_name, graph_name, dims); + // AOT time + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {{"nv_runtime_cache_path", runtime_cache_name.c_str()}}); + Ort::Session session_object(*ort_env, model_name.c_str(), so); + + auto io_binding = generate_io_binding(session_object); + session_object.Run(run_options, io_binding); + } + // the cache will be dumped to disk upon session destruction + ASSERT_TRUE(std::filesystem::exists(runtime_cache_name.c_str())); + ASSERT_TRUE(1 == countFilesInDirectory(runtime_cache_name)); + + // use existing cache + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {{"nv_runtime_cache_path", runtime_cache_name.c_str()}}); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); + } + ASSERT_TRUE(1 == countFilesInDirectory(runtime_cache_name)); + + // create new cache + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + std::string new_cache_name = "/tmp/runtime_cache_new/"; + if (std::filesystem::exists(new_cache_name)) { + std::filesystem::remove_all(new_cache_name); + } + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {{"nv_runtime_cache_path", new_cache_name.c_str()}}); + { + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); + } + // the cache will be dumped to disk upon session destruction + ASSERT_TRUE(std::filesystem::exists(new_cache_name.c_str())); + ASSERT_TRUE(1 == countFilesInDirectory(new_cache_name)); + } +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 739e39a6975e2..1c8cc6f78fe63 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -317,6 +317,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_DisableEpCompile_ThenCompileExplicitly) { Ort::ModelCompilationOptions compile_options(*ort_env, so); 
compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -355,6 +356,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelFromPath) { Ort::ModelCompilationOptions compile_options(*ort_env, so); compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -393,6 +395,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelAsBuffer_Embe compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelPath(output_model_file); compile_options.SetEpContextEmbedMode(true); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -427,6 +430,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer) { // Create model compilation options from the session options. Output model is stored in a buffer. Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); compile_options.SetInputModelPath(input_model_file); Ort::AllocatorWithDefaultOptions allocator; @@ -482,6 +486,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size); compile_options.SetEpContextEmbedMode(true); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -515,6 +520,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB std::string bin_file_name = model_name.substr(0, pos) + "_qnn.bin"; compile_options.SetEpContextBinaryInformation(ToWideString(target_dir).c_str(), ToWideString(model_name).c_str()); compile_options.SetEpContextEmbedMode(false); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -573,6 +579,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size); compile_options.SetOutputModelExternalInitializersFile(output_initializers_file, 0); compile_options.SetEpContextEmbedMode(true); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); @@ -2070,6 +2077,278 @@ TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); } } + +// Implementation of OrtOutStreamWriteFunc that writes the compiled model to a file. 
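+// Both callbacks below follow the output-stream write callback contract exercised by these tests: ORT calls
+// the function repeatedly with (stream_state, buffer, buffer_num_bytes) while serializing the compiled model,
+// where stream_state is the opaque pointer that was handed to SetOutputModelWriteFunc. Returning nullptr
+// signals success; returning an OrtStatus makes compilation fail with that error, as the second callback and
+// its test demonstrate.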
+static OrtStatus* ORT_API_CALL TestWriteToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) { + std::ofstream* outfile = reinterpret_cast(stream_state); + outfile->write(reinterpret_cast(buffer), buffer_num_bytes); + return nullptr; // No error +} + +// Implementation of OrtOutStreamWriteFunc that directly returns an OrtStatus indicating an error. +static OrtStatus* ORT_API_CALL ReturnStatusFromStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) { + ORT_UNUSED_PARAMETER(stream_state); + ORT_UNUSED_PARAMETER(buffer); + ORT_UNUSED_PARAMETER(buffer_num_bytes); + return Ort::GetApi().CreateStatus(ORT_FAIL, "Error from OrtOutStreamWriteFunc callback"); +} + +// Test using the CompileModel() API with settings: +// - input model comes from a file +// - write output model to custom write stream +TEST_F(QnnHTPBackendTests, CompileApi_InputFile_WriteOutputModelBytes) { + const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_inputfile_writeoutputmodelbytes.onnx"); + std::filesystem::remove(input_model_file); + + // Create a test model and save it to a file. + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + ASSERT_STATUS_OK(test_model.Save(input_model_file)); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + const ORTCHAR_T* output_model_file = ORT_TSTR("compileapi_inputfile_writeoutputmodelbytes_ctx.onnx"); + std::filesystem::remove(output_model_file); + + // Open an output file. Test will incrementally write the output model to file + // via calls to our OrtOutStreamWriteFunc callback. + ASSERT_FALSE(std::filesystem::exists(output_model_file)); + std::ofstream outfile(output_model_file, std::ios::binary); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelWriteFunc(TestWriteToStream, reinterpret_cast(&outfile)); + compile_options.SetEpContextEmbedMode(true); + compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); + + // Compile the model. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + outfile.flush(); + outfile.close(); + + // Check that the compiled model has the expected number of EPContext nodes. + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + CheckEpContextNodeCounts(output_model_file, 2, 2); +} + +// Tests using an OrtOutStreamFunc function that returns an error. +TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) { + // Create a test model (in memory). + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + std::string model_data = test_model.Serialize(); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + // Create model compilation options from the session options. 
+ Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); + compile_options.SetOutputModelWriteFunc(ReturnStatusFromStream, nullptr); // Set output stream that returns error + compile_options.SetEpContextEmbedMode(true); + + // Compile the model. Expect a specific error status. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_EQ(status.GetErrorMessage(), "Error from OrtOutStreamWriteFunc callback"); +} + +struct CustomInitializerHandlerState { + const ORTCHAR_T* external_file_path = nullptr; + std::ofstream* outfile = nullptr; +}; + +static OrtStatus* ORT_API_CALL TestHandleInitializerDataFunc(void* state, + const char* initializer_name, + const OrtValue* c_initializer_value, + const OrtExternalInitializerInfo* /*c_external_info*/, + OrtExternalInitializerInfo** c_new_external_info) { + Ort::Status final_status{nullptr}; + + ORT_TRY { + CustomInitializerHandlerState* custom_state = reinterpret_cast(state); + + if (std::string("constant") == initializer_name) { + // Keep a specific initializer in the model just to test both scenarios. + // A real implementation may check the byte size and keep small initializers in the model. + *c_new_external_info = nullptr; + return nullptr; + } + + // + // Store other initializers in an external file. + // + Ort::ConstValue value{c_initializer_value}; + size_t byte_size = value.GetTensorSizeInBytes(); + int64_t offset = custom_state->outfile->tellp(); + const ORTCHAR_T* location = custom_state->external_file_path; + + custom_state->outfile->write(static_cast(value.GetTensorRawData()), byte_size); + custom_state->outfile->flush(); + + // Provide caller (ORT) with the new external info. + Ort::ExternalInitializerInfo new_external_info{nullptr}; + if (Ort::Status status = Ort::ExternalInitializerInfo::Create(location, offset, byte_size, new_external_info); + !status.IsOK()) { + return status.release(); + } + + *c_new_external_info = new_external_info.release(); + } + ORT_CATCH(const Ort::Exception& ex) { + ORT_HANDLE_EXCEPTION(([&ex, &final_status]() { + final_status = Ort::Status{ex}; + })); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION(([&ex, &final_status]() { + final_status = Ort::Status(ex.what(), ORT_FAIL); + })); + } + + return final_status.release(); +} + +// Test using the CompileModel() API with settings: +// - input model comes from a file +// - write output model to a file +// - Use callback to specify where each initializer is stored (i.e., external file or within model). +TEST_F(QnnHTPBackendTests, CompileApi_InputFile_OutputFile_InitializerHandler) { + const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler_ctx.onnx"); + const ORTCHAR_T* initializer_file = ORT_TSTR("./compileapi_inputfile_outputfile_initializerhandler.bin"); + std::filesystem::remove(input_model_file); + std::filesystem::remove(output_model_file); + std::filesystem::remove(initializer_file); + + // Create a test model and save it to a file. 
+  TestModel test_model;
+  CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model);
+  ASSERT_STATUS_OK(test_model.Save(input_model_file));
+
+  // Initialize session options with QNN EP
+  Ort::SessionOptions so;
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+  so.AppendExecutionProvider("QNN", provider_options);
+
+  // Open a file to store external initializers. ORT will call our handler function for every initializer.
+  ASSERT_FALSE(std::filesystem::exists(initializer_file));
+  std::ofstream outfile(initializer_file, std::ios::binary);
+  CustomInitializerHandlerState custom_state = {initializer_file, &outfile};
+
+  // Create model compilation options from the session options.
+  Ort::ModelCompilationOptions compile_options(*ort_env, so);
+  compile_options.SetInputModelPath(input_model_file);
+  compile_options.SetOutputModelPath(output_model_file);
+  compile_options.SetOutputModelGetInitializerLocationFunc(TestHandleInitializerDataFunc,
+                                                           reinterpret_cast<void*>(&custom_state));
+  compile_options.SetEpContextEmbedMode(true);
+  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
+
+  // Compile the model.
+  Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+  ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
+  outfile.flush();
+  outfile.close();
+
+  ASSERT_TRUE(std::filesystem::exists(initializer_file));
+  ASSERT_TRUE(std::filesystem::exists(output_model_file));
+  CheckEpContextNodeCounts(output_model_file, 2, 2);
+}
+
+static OrtStatus* ORT_API_CALL ReuseExternalInitializers(void* state,
+                                                         const char* /*initializer_name*/,
+                                                         const OrtValue* /*initializer_value*/,
+                                                         const OrtExternalInitializerInfo* external_info,
+                                                         OrtExternalInitializerInfo** new_external_info) {
+  Ort::Status final_status{nullptr};
+
+  ORT_TRY {
+    // If the original initializer was stored in an external file, keep it there (just for testing).
+    if (external_info != nullptr) {
+      Ort::ConstExternalInitializerInfo info(external_info);
+      auto location = info.GetFilePath();
+      int64_t offset = info.GetFileOffset();
+      size_t byte_size = info.GetByteSize();
+
+      Ort::ExternalInitializerInfo new_info(nullptr);
+      Ort::Status status = Ort::ExternalInitializerInfo::Create(location.c_str(), offset, byte_size, new_info);
+      if (!status.IsOK()) {
+        return status.release();
+      }
+
+      *new_external_info = new_info.release();
+
+      // Keep track of the number of reused external initializers so that we can assert
+      // that we reused the expected number of initializers.
+      // THIS IS TEST CODE. An application would not do this.
+      size_t* num_reused_ext_initializers = reinterpret_cast<size_t*>(state);
+      *num_reused_ext_initializers += 1;
+
+      return nullptr;
+    }
+
+    // If not originally external, save it within the generated compiled model
+    *new_external_info = nullptr;
+  }
+  ORT_CATCH(const Ort::Exception& ex) {
+    ORT_HANDLE_EXCEPTION(([&ex, &final_status]() {
+      final_status = Ort::Status{ex};
+    }));
+  }
+  ORT_CATCH(const std::exception& ex) {
+    ORT_HANDLE_EXCEPTION(([&ex, &final_status]() {
+      final_status = Ort::Status(ex.what(), ORT_FAIL);
+    }));
+  }
+
+  return final_status.release();
+}
+
+// Test using the CompileModel() API with settings:
+//   - input model comes from a file
+//   - write output model to a file
+//   - Use callback to specify where each initializer is stored. We'll reuse external initializers
+//     from original model!
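+// The ReuseExternalInitializers callback above re-creates an OrtExternalInitializerInfo with the original
+// file path, offset, and byte size, so the compiled model keeps referencing the existing external data file.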
+TEST_F(QnnHTPBackendTests, CompileApi_InitializerHandler_ReuseExternalInitializers) {
+  const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/conv_qdq_external_ini.onnx");
+  const ORTCHAR_T* output_model_file = ORT_TSTR("testdata/conv_qdq_external_ini_reuse_ctx.onnx");
+  std::filesystem::remove(output_model_file);
+
+  size_t num_reused_ext_initializers = 0;
+
+  // Create model compilation options from the session options.
+  Ort::SessionOptions so;
+  Ort::ModelCompilationOptions compile_options(*ort_env, so);
+  compile_options.SetInputModelPath(input_model_file);
+  compile_options.SetOutputModelPath(output_model_file);
+  compile_options.SetOutputModelGetInitializerLocationFunc(ReuseExternalInitializers,
+                                                           reinterpret_cast<void*>(&num_reused_ext_initializers));
+  compile_options.SetEpContextEmbedMode(true);
+
+  // Compile the model.
+  Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
+  ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();
+  ASSERT_TRUE(std::filesystem::exists(output_model_file));
+  std::filesystem::remove(output_model_file);
+
+  ASSERT_EQ(num_reused_ext_initializers, 2);  // Reused external conv weight and bias.
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
 }  // namespace test
diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
index b102676860444..e46cdb4f98850 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
@@ -225,6 +225,199 @@ def test_compile_from_buffer_to_buffer(self):
         self.assertTrue(isinstance(output_model_bytes, bytes))
         self.assertGreater(len(output_model_bytes), 0)
 
+    def test_compile_graph_optimization_level(self):
+        """
+        Tests compiling a model with no optimizations (default) vs all optimizations.
+        """
+        input_model_path = get_name("test_cast_back_to_back_non_const_mixed_types_origin.onnx")
+        output_model_path_0 = os.path.join(self._tmp_dir_path, "cast.disable_all.compiled.onnx")
+        output_model_path_1 = os.path.join(self._tmp_dir_path, "cast.enable_all.compiled.onnx")
+
+        # Local function that compiles a model with a given graph optimization level and returns
+        # the count of operator types in the compiled model.
+        def compile_and_get_op_counts(
+            output_model_path: str,
+            graph_opt_level: onnxrt.GraphOptimizationLevel | None,
+        ) -> dict[str, int]:
+            session_options = onnxrt.SessionOptions()
+            if graph_opt_level is not None:
+                model_compiler = onnxrt.ModelCompiler(
+                    session_options,
+                    input_model_path,
+                    graph_optimization_level=graph_opt_level,
+                )
+            else:
+                # graph optimization level defaults to ORT_DISABLE_ALL if not provided.
+                model_compiler = onnxrt.ModelCompiler(session_options, input_model_path)
+
+            model_compiler.compile_to_file(output_model_path)
+            self.assertTrue(os.path.exists(output_model_path))
+
+            model: onnx.ModelProto = onnx.load(get_name(output_model_path))
+            op_counts = {}
+            for node in model.graph.node:
+                if node.op_type not in op_counts:
+                    op_counts[node.op_type] = 1
+                else:
+                    op_counts[node.op_type] += 1
+
+            return op_counts
+
+        # Compile model on CPU with no graph optimizations (default).
+        # Model should have 9 Casts
+        op_counts_0 = compile_and_get_op_counts(output_model_path_0, graph_opt_level=None)
+        self.assertEqual(op_counts_0["Cast"], 9)
+
+        # Compile model on CPU with ALL graph optimizations.
+ # Model should have less casts (optimized out) + op_counts_1 = compile_and_get_op_counts( + output_model_path_1, graph_opt_level=onnxrt.GraphOptimizationLevel.ORT_ENABLE_BASIC + ) + self.assertEqual(op_counts_1["Cast"], 8) + + def test_compile_from_file_to_stream(self): + """ + Tests compiling a model (from files) to an output stream using a custom write functor. + """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled.stream.onnx") + + with open(output_model_path, "wb") as output_fd: + # User's custom write functor. Writes the model to a file. + def my_write_func(buffer: bytes): + self.assertGreater(len(buffer), 0) + output_fd.write(buffer) + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_stream(my_write_func) + + self.assertTrue(os.path.exists(output_model_path)) + + def test_compile_to_stream_that_raises_exception(self): + """ + Tests compiling a model to an output stream that always raises an exception. + """ + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + + # User's custom write functor that raises an exception. + test_py_error_message = "My Python Error" + + def my_write_func(buffer: bytes): + self.assertGreater(len(buffer), 0) + raise ValueError(test_py_error_message) + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + + # Try to compile and expect ORT to raise a Fail exception that contains our message. + with self.assertRaises(Fail) as context: + model_compiler.compile_to_stream(my_write_func) + self.assertIn(test_py_error_message, str(context.exception)) + + def test_compile_with_basic_initializer_location_func(self): + """ + Tests compiling a model using a custom initializer handler that stores initializers + in an external file. + """ + input_model_path = get_name("conv_qdq_external_ini.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler.onnx") + initializer_file_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler.bin") + + if os.path.exists(output_model_path): + os.remove(output_model_path) + + if os.path.exists(initializer_file_path): + os.remove(initializer_file_path) + + with open(initializer_file_path, "wb") as ext_init_file: + + def store_large_initializer_externally( + initializer_name: str, + initializer_value: onnxrt.OrtValue, + external_info: onnxrt.OrtExternalInitializerInfo | None, + ) -> onnxrt.OrtExternalInitializerInfo | None: + self.assertTrue(initializer_name) # Should have valid name + byte_size = initializer_value.tensor_size_in_bytes() + + if byte_size < 64: + return None # Store small initializer within compiled model. + + # Else, write initializer to new external file. 
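+                # The OrtExternalInitializerInfo returned below records the file path, byte offset, and byte
+                # size of the data just written, so the compiled model can reference it instead of embedding a copy.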
+ value_np = initializer_value.numpy() + file_offset = ext_init_file.tell() + ext_init_file.write(value_np.tobytes()) + return onnxrt.OrtExternalInitializerInfo(initializer_file_path, file_offset, byte_size) + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + get_initializer_location_func=store_large_initializer_externally, + ) + model_compiler.compile_to_file(output_model_path) + + self.assertTrue(os.path.exists(output_model_path)) + self.assertTrue(os.path.exists(initializer_file_path)) + + def test_compile_with_initializer_func_that_reuses(self): + """ + Tests compiling a model using a custom initializer handler that reuses external initializer files. + """ + input_model_path = get_name("conv_qdq_external_ini.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "conv_qdq.init_handler_reuse.onnx") + + if os.path.exists(output_model_path): + os.remove(output_model_path) + + # Function that reuses external initializer files for the compiled model. + def reuse_external_initializers( + initializer_name: str, + initializer_value: onnxrt.OrtValue, + external_info: onnxrt.OrtExternalInitializerInfo | None, + ) -> onnxrt.OrtExternalInitializerInfo | None: + self.assertTrue(initializer_name) # Should have valid name + self.assertNotEqual(initializer_value.data_ptr(), 0) + self.assertGreater(initializer_value.tensor_size_in_bytes(), 0) + if external_info is not None: + # Original initializer is stored externally. + # Make the initializer in the compiled model use the same external file + return external_info + + return None # Otherwise, make a copy of the initializer and store it within compiled model. + + session_options = onnxrt.SessionOptions() + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + get_initializer_location_func=reuse_external_initializers, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + def test_fail_load_uncompiled_model_and_then_compile(self): """ Tests compiling scenario: diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_moe_cuda.py similarity index 53% rename from onnxruntime/test/python/transformers/test_parity_moe.py rename to onnxruntime/test/python/transformers/test_moe_cuda.py index 252d89a2257fc..c09d8bacf1fa2 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_moe_cuda.py @@ -9,6 +9,8 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import itertools +import os import unittest from collections import OrderedDict @@ -21,38 +23,54 @@ import onnxruntime +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" + +onnxruntime.preload_dlls() + +# Determine the execution provider and device based on CUDA availability. 
+use_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers() and torch.cuda.is_available() +device = torch.device("cuda:0" if use_cuda else "cpu") +ort_provider = ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"] + torch.manual_seed(42) numpy.random.seed(42) -USE_QUANT = False -ORT_DTYPE = TensorProto.FLOAT16 if USE_QUANT else TensorProto.FLOAT -NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 -THRESHOLD = 5e-1 if USE_QUANT else 1e-2 +onnx_to_torch_type_map = { + TensorProto.FLOAT16: torch.float16, + TensorProto.FLOAT: torch.float, + TensorProto.BFLOAT16: torch.bfloat16, + TensorProto.UINT8: torch.uint8, +} +ort_to_numpy_type_map = { + TensorProto.FLOAT16: numpy.float16, + TensorProto.FLOAT: numpy.float32, + TensorProto.UINT8: numpy.uint8, +} -def value_string_of(numpy_array): - arr = numpy_array.flatten() - lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] - return "{\n " + "f,\n ".join(lines) + "f}" +ort_dtype_name_map = { + TensorProto.FLOAT16: "FP16", + TensorProto.FLOAT: "FP32", + TensorProto.BFLOAT16: "BF16", +} -def print_tensor(name, numpy_array): - print(f"const std::vector {name} = {value_string_of(numpy_array)};") +def quant_dequant(weights, is_4_bit_quantization: bool = True): + type = torch.quint4x2 if is_4_bit_quantization else torch.int8 + import tensorrt_llm # noqa: PLC0415 -def quant_dequant(weights, quant_mode: bool = True): - # use the test version `_symmetric_...` to get the non-interleaved weights - type = torch.quint4x2 if quant_mode else torch.int8 - # This import is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix() - # Comment out this line for passing the lintrunner check in the CI. - # import tensorrt_llm + # Avoid lint false alert that the package is not used. Note that this function will not be called in pipeline. 
+ if pipeline_mode: + print("Tensorrt LLM version", tensorrt_llm.__version__) quant_weights, processed_q_weight, torch_weight_scales = ( torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type) ) # Unpack the int4s int int8s - if quant_mode: + if is_4_bit_quantization: upper = quant_weights >> 4 lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape) @@ -71,6 +89,7 @@ def create_moe_onnx_graph( fc1_experts_bias, fc2_experts_weights, fc2_experts_bias, + onnx_dtype, ): nodes = [ helper.make_node( @@ -94,21 +113,21 @@ def create_moe_onnx_graph( fc1_shape = [num_experts, hidden_size, inter_size] fc2_shape = [num_experts, inter_size, hidden_size] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE, + onnx_dtype, fc1_shape, - fc1_experts_weights.to(torch_type).flatten().tolist(), + fc1_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE, + onnx_dtype, fc2_shape, - fc2_experts_weights.to(torch_type).flatten().tolist(), + fc2_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), ] @@ -119,35 +138,35 @@ def create_moe_onnx_graph( [ helper.make_tensor( "fc1_experts_bias", - ORT_DTYPE, + onnx_dtype, fc1_bias_shape, - fc1_experts_bias.to(torch_type).flatten().tolist(), + fc1_experts_bias.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_bias", - ORT_DTYPE, + onnx_dtype, fc2_bias_shape, - fc2_experts_bias.to(torch_type).flatten().tolist(), + fc2_experts_bias.to(torch_dtype).flatten().tolist(), raw=False, ), ] ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -171,6 +190,7 @@ def create_mixtral_moe_onnx_graph( fc2_experts_weights, fc3_experts_weights, topk, + onnx_dtype, ): nodes = [ helper.make_node( @@ -197,46 +217,46 @@ def create_mixtral_moe_onnx_graph( fc2_shape = [num_experts, inter_size, hidden_size] fc3_shape = [num_experts, hidden_size, inter_size] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE, + onnx_dtype, fc1_shape, - fc1_experts_weights.to(torch_type).flatten().tolist(), + fc1_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE, + onnx_dtype, fc2_shape, - fc2_experts_weights.to(torch_type).flatten().tolist(), + fc2_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", - ORT_DTYPE, + onnx_dtype, fc3_shape, - fc3_experts_weights.to(torch_type).flatten().tolist(), + fc3_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), ] graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, 
hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -259,12 +279,14 @@ def create_phi_moe_onnx_graph( fc1_experts_weights, fc2_experts_weights, fc3_experts_weights, - fc1_scales, - fc2_scales, - fc3_scales, topk, + onnx_dtype, + quant_bits=0, + fc1_scales=None, + fc2_scales=None, + fc3_scales=None, ): - use_quant = USE_QUANT + use_quant = quant_bits > 0 if use_quant: assert fc1_experts_weights.dtype == torch.int8 assert fc2_experts_weights.dtype == torch.int8 @@ -276,34 +298,37 @@ def create_phi_moe_onnx_graph( assert fc2_scales.dtype == torch.float16 assert fc3_scales.dtype == torch.float16 + op_name = "QMoE" if use_quant else "MoE" + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + "fc3_experts_weights", + "fc3_scales", + "", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "", + "fc2_experts_weights", + "", + "fc3_experts_weights", + ] + ) + nodes = [ helper.make_node( - "MoE" if not use_quant else "QMoE", - ( - [ - "input", - "router_probs", - "fc1_experts_weights", - "", - "fc2_experts_weights", - "", - "fc3_experts_weights", - ] - if not use_quant - else [ - "input", - "router_probs", - "fc1_experts_weights", - "fc1_scales", - "", - "fc2_experts_weights", - "fc2_scales", - "", - "fc3_experts_weights", - "fc3_scales", - "", - ] - ), + op_name, + inputs, ["output"], "MoE_0", k=topk, @@ -315,37 +340,38 @@ def create_phi_moe_onnx_graph( ] if use_quant: - nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", 8)]) + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) - fc1_shape = [num_experts, hidden_size, inter_size] - fc2_shape = [num_experts, inter_size, hidden_size] - fc3_shape = [num_experts, hidden_size, inter_size] + components = 2 if quant_bits == 4 else 1 + fc1_shape = [num_experts, hidden_size, inter_size // components] + fc2_shape = [num_experts, inter_size, hidden_size // components] + fc3_shape = [num_experts, hidden_size, inter_size // components] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 - numpy_type = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 - if use_quant: - numpy_type = numpy.uint8 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + + weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype] + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc1_shape, - fc1_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc2_shape, - fc2_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( 
"fc3_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc3_shape, - fc3_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc3_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), ] @@ -358,42 +384,42 @@ def create_phi_moe_onnx_graph( [ helper.make_tensor( "fc1_scales", - ORT_DTYPE, + onnx_dtype, fc1_scale_shape, - fc1_scales.to(torch_type).flatten().tolist(), + fc1_scales.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_scales", - ORT_DTYPE, + onnx_dtype, fc2_scale_shape, - fc2_scales.to(torch_type).flatten().tolist(), + fc2_scales.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_scales", - ORT_DTYPE, + onnx_dtype, fc3_scale_shape, - fc3_scales.to(torch_type).flatten().tolist(), + fc3_scales.to(torch_dtype).flatten().tolist(), raw=False, ), ] ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -546,126 +572,127 @@ def __init__(self, config: PhiMoEConfig): class SparseMoeBlockORTHelper(nn.Module): - def __init__(self): + def __init__(self, quant_bits=0, onnx_dtype=None): super().__init__() + self.quant_bits = quant_bits + if onnx_dtype is None: + self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + else: + self.onnx_dtype = onnx_dtype + self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 def create_ort_session(self, moe_onnx_graph): from onnxruntime import InferenceSession, SessionOptions # noqa: PLC0415 sess_options = SessionOptions() + sess_options.log_severity_level = 2 - cuda_providers = ["CUDAExecutionProvider"] - if cuda_providers[0] not in onnxruntime.get_available_providers(): + try: + ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + except Exception as e: + print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") + print("Skipping ONNX Runtime execution for this test case.") return None - sess_options.log_severity_level = 2 - ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) - return ort_session def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: pass - def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: + def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: + if self.ort_sess is None: + return None + batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) + hidden_states_flat = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - ort_inputs = { - "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), - "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), - } + router_logits = self.gate(hidden_states_flat) - ort_output = None - if self.ort_sess is not None: - if not 
iobinding: - ort_output = self.ort_sess.run(None, ort_inputs) - return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits - else: - self.ort_run_with_iobinding(ort_inputs) - return None + # Determine the correct torch dtype from the onnx_dtype + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] - # print_tensor("input", ort_inputs["input"]) - # print_tensor("router_probs", ort_inputs["router_probs"]) - # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) - # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) - # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) - # print_tensor("output", ort_output[0]) - - return None + # Prepare tensors on the correct device for ORT inference with the CORRECT dtype + tensors = { + "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + } - def ort_run_with_iobinding(self, ort_inputs, repeat=1000): + # Bind inputs and outputs to torch tensors directly. iobinding = self.ort_sess.io_binding() - device_id = torch.cuda.current_device() - - iobinding.bind_input( - name="input", - device_type="cuda", - device_id=device_id, - element_type=NP_TYPE, - shape=ort_inputs["input"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), - ) - iobinding.bind_input( - name="router_probs", - device_type="cuda", - device_id=device_id, - element_type=NP_TYPE, - shape=ort_inputs["router_probs"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( - ort_inputs["router_probs"], "cuda", device_id - ).data_ptr(), - ) - - iobinding.bind_output( - name="output", - device_type="cuda", - device_id=device_id, - element_type=NP_TYPE, - shape=ort_inputs["input"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( - numpy.zeros(ort_inputs["input"].shape), "cuda", device_id - ).data_ptr(), - ) - - # warm up - for _ in range(5): - iobinding.synchronize_inputs() - self.ort_sess.run_with_iobinding(iobinding) - iobinding.synchronize_outputs() - - import time # noqa: PLC0415 - - s = time.time() - for _ in range(repeat): - iobinding.synchronize_inputs() - self.ort_sess.run_with_iobinding(iobinding) - iobinding.synchronize_outputs() - e = time.time() - print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + for name, tensor in tensors.items(): + # Ensure tensor is on the globally defined device + if name == "output": + iobinding.bind_output( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + else: + iobinding.bind_input( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + + if enable_performance_test: + import time # noqa: PLC0415 + + repeat = 1000 + s = time.time() + for _ in range(repeat): + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + e = time.time() + print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + + # The output tensor is on `device`. Reshape and return it. 
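+        # (When enable_performance_test is set, the block above has already re-run the bound session
+        # 1000 times and printed the average per-run kernel time in milliseconds.)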
+ return tensors["output"].reshape(batch_size, sequence_length, hidden_dim) def parity_check(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) torch_output = self.forward(hidden_state) ort_output = self.ort_forward(hidden_state) + + dtype_str = ort_dtype_name_map[self.onnx_dtype] + + # Maps "ort_type:quant_bits" to (atol, rtol) + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + "FP16:0": (5e-2, 1e-3), + "FP16:4": (3.0, 1e-2), + "FP16:8": (2.0, 1e-2), + "BF16:0": (1.0, 1e-2), + "BF16:4": (30.0, 1e-1), + "BF16:8": (20.0, 1e-1), + } + + atol, rtol = ort_dtype_quant_bits_tolerance_map[f"{dtype_str}:{self.quant_bits}"] if ort_output is not None: print( - "name:", - self.__class__.__name__, - " batch_size:", - self.batch_size, - " sequence_length:", - self.sequence_length, - " max_diff:", - (torch_output - ort_output).abs().max(), + f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," + f" batch: {self.batch_size}, seq_len: {self.sequence_length}," + f" max_diff: {(torch_output.cpu() - ort_output.cpu()).abs().max()}" + ) + torch.testing.assert_close( + ort_output.cpu().to(torch.float32), torch_output.cpu().to(torch.float32), rtol=rtol, atol=atol ) - torch.testing.assert_close(ort_output.to(torch.float32), torch_output, rtol=THRESHOLD, atol=THRESHOLD) def benchmark_ort(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) - self.ort_forward(hidden_state, iobinding=True) + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) class SwitchMoE(SparseMoeBlockORTHelper): @@ -680,7 +707,7 @@ def __init__( eval_capacity=-1, activation="gelu", ): - super().__init__() + super().__init__(quant_bits=0) # SwitchMoE is not quantized self.batch_size = batch_size self.sequence_length = sequence_length self.num_experts = num_experts @@ -709,6 +736,7 @@ def __init__( self.moe_experts.bias1, self.moe_experts.weight2.transpose(1, 2), self.moe_experts.bias2, + self.onnx_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -744,7 +772,7 @@ class MixtralSparseMoeBlock(SparseMoeBlockORTHelper): """ def __init__(self, config, batch_size, sequence_length): - super().__init__() + super().__init__(quant_bits=0) # Mixtral test is not quantized self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts @@ -778,6 +806,7 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2, self.moe_experts_weight3, self.top_k, + self.onnx_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -874,40 +903,41 @@ class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): and memory on padding. 
""" - def __init__(self, config, batch_size, sequence_length): - super().__init__() + def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype=None): + super().__init__(quant_bits, onnx_dtype) self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok self.router_jitter_noise = config.router_jitter_noise + use_quant = self.quant_bits > 0 # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.experts = nn.ModuleList([PhiMoEBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) - w1_list = [] - w2_list = [] - w3_list = [] - w1_scale_list = [] - w2_scale_list = [] - w3_scale_list = [] - if not USE_QUANT: + w1_list, w2_list, w3_list = [], [], [] + w1_scale_list, w2_scale_list, w3_scale_list = [], [], [] + + if not use_quant: for i in range(self.num_experts): w1_list.append(self.experts[i].w1.weight) w2_list.append(self.experts[i].w2.weight) w3_list.append(self.experts[i].w3.weight) else: + is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, False) - w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, False) - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, False) + # Corrected quantization logic for per-output-channel quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) self.experts[i].w1.weight.data = w1_qdq self.experts[i].w2.weight.data = w2_qdq self.experts[i].w3.weight.data = w3_qdq + # Transpose quantized weights to match the expected ONNX layout w1_list.append(pre_qweight1) w2_list.append(pre_qweight2) w3_list.append(pre_qweight3) @@ -919,9 +949,9 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2 = torch.stack(w2_list, dim=0) self.moe_experts_weight3 = torch.stack(w3_list, dim=0) - moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if USE_QUANT else None - moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if USE_QUANT else None - moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if USE_QUANT else None + moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if use_quant else None + moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if use_quant else None self.batch_size = batch_size self.sequence_length = sequence_length @@ -933,10 +963,12 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight1, self.moe_experts_weight2, self.moe_experts_weight3, + self.top_k, + self.onnx_dtype, + self.quant_bits, moe_experts_weight_scale1, moe_experts_weight_scale2, moe_experts_weight_scale3, - self.top_k, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -995,18 +1027,10 @@ def small_test_cases(): yield batch_size, sequence_length -def phi3_test_cases(): - # TODO: phi3 moe failed in long sequence lengths (max diff 0.22 > threshold 0.01), need investigation. 
- for batch_size in [1, 4, 16]: - for sequence_length in [128]: - yield batch_size, sequence_length - - +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestSwitchMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) def test_switch_moe_parity(self, batch_size, sequence_length): - # if platform.system() == "Windows": - # pytest.skip("Skip on Windows") switch_moe = SwitchMoE( batch_size=batch_size, sequence_length=sequence_length, @@ -1015,26 +1039,412 @@ def test_switch_moe_parity(self, batch_size, sequence_length): hidden_features=1024, out_features=256, ) + switch_moe.to(device) switch_moe.parity_check() - # switch_moe.benchmark_ort() +# quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) +# since qMoE test requires tensorrt_llm for quant_dequant. We disable it in CI pipeline to avoid extra dependency. +quant_bits_list = [0] if pipeline_mode else [0, 8, 4] + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestMixtralMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) def test_mixtral_moe_parity(self, batch_size, sequence_length): config = MixtralConfig(hidden_size=256, intermediate_size=1024) mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) + mixtral_moe.to(device) mixtral_moe.parity_check() - # mixtral_moe.benchmark_ort() +phi3_test_cases = list( + itertools.product( + [1, 4], # batch_size + [1, 32], # sequence_length + quant_bits_list, + ) +) + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestPhiMoE(unittest.TestCase): - @parameterized.expand(phi3_test_cases()) - def test_phi3_moe_parity(self, batch_size, sequence_length): + @parameterized.expand(phi3_test_cases) + def test_phi3_moe_parity(self, batch_size, sequence_length, quant_bits): config = PhiMoEConfig(hidden_size=256, intermediate_size=1024) - phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length) + phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits) + phi3_moe.to(device) phi3_moe.parity_check() - # phi3_moe.benchmark_ort() + + +# --------------------------------------------- +# The following test are for swiglu activation +# --------------------------------------------- +class SwigluMoeConfig: + def __init__( + self, + hidden_size=2048, + intermediate_size=2048, + num_experts_per_token=2, + num_local_experts=8, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts_per_token = num_experts_per_token + self.num_local_experts = num_local_experts + + +def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): + dim = x.shape[-1] + x = x.view(-1, dim // 2, 2) + x_glu, x_linear = x[..., 0], x[..., 1] + + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) + return y + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +# Note that the weight shape might not match the tensor shape in legacy operator spec. 
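+# For BFLOAT16, the helper below stores the tensor as raw uint16 bytes (raw=True); other dtypes are passed
+# to helper.make_tensor as Python lists.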
+def make_onnx_intializer(name: str, tensor: torch.Tensor, shape, onnx_dtype): + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + if torch_dtype == torch.bfloat16: + numpy_vals_uint16 = tensor.to(torch.bfloat16).cpu().view(torch.uint16).numpy() + initializer = helper.make_tensor( + name=name, + data_type=TensorProto.BFLOAT16, + dims=shape, + vals=numpy_vals_uint16.tobytes(), + raw=True, + ) + else: + initializer = helper.make_tensor( + name=name, + data_type=onnx_dtype, + dims=shape, + vals=tensor.flatten().detach().cpu().numpy().astype(numpy.uint8).tolist() + if onnx_dtype == TensorProto.UINT8 + else tensor.detach().to(torch_dtype).flatten().tolist(), + raw=False, + ) + return initializer + + +def create_swiglu_moe_onnx_graph( + num_tokens: int, + num_experts: int, + hidden_size: int, + inter_size: int, + topk: int, + onnx_dtype: int, + quant_bits: int, + fc1_experts_weights: torch.Tensor, + fc1_experts_bias: torch.Tensor, + fc2_experts_weights: torch.Tensor, + fc2_experts_bias: torch.Tensor, + fc1_experts_weight_scale: torch.Tensor = None, + fc2_experts_weight_scale: torch.Tensor = None, +): + use_quant = quant_bits > 0 + op_name = "QMoE" if use_quant else "MoE" + + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_weight_scale", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_weight_scale", + "fc2_experts_bias", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_bias", + ] + ) + + nodes = [ + helper.make_node( + op_name, + inputs, + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="swiglu", + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + + components = 2 if quant_bits == 4 else 1 + fc1_weight_shape = [num_experts, 2 * inter_size, hidden_size // components] + fc1_bias_shape = [num_experts, 2 * inter_size] + fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size] + + fc2_weight_shape = [num_experts, hidden_size, inter_size // components] + fc2_bias_shape = [num_experts, hidden_size] + fc2_experts_weight_scale_shape = [num_experts, hidden_size] + + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + weight_torch_dtype = onnx_to_torch_type_map[weight_onnx_type] + + initializers = [ + make_onnx_intializer( + "fc1_experts_weights", fc1_experts_weights.to(weight_torch_dtype), fc1_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc1_experts_bias", fc1_experts_bias.to(torch_dtype), fc1_bias_shape, onnx_dtype), + make_onnx_intializer( + "fc2_experts_weights", fc2_experts_weights.to(weight_torch_dtype), fc2_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc2_experts_bias", fc2_experts_bias.to(torch_dtype), fc2_bias_shape, onnx_dtype), + ] + + if use_quant: + initializers.extend( + [ + make_onnx_intializer( + "fc1_experts_weight_scale", + fc1_experts_weight_scale.to(torch_dtype), + fc1_experts_weight_scale_shape, + onnx_dtype, + ), + make_onnx_intializer( + "fc2_experts_weight_scale", + fc2_experts_weight_scale.to(torch_dtype), + fc2_experts_weight_scale_shape, + onnx_dtype, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + onnx_dtype, + [num_tokens, num_experts], + ) + ) + + graph_outputs = 
[ + helper.make_tensor_value_info("output", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + # For the ONNX MoE operator, weights must be transposed to [In, Out] format. + # Biases do not require transposition. + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + # quant_dequant expects [Out, In] format, matching nn.Linear.weight + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + # Update the expert's weight with the dequantized version for the PyTorch reference. + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + # Stack the prepared tensors for the graph builder + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + + # Build the ONNX graph with the correctly shaped tensors + self.moe_onnx_graph = create_swiglu_moe_onnx_graph( + num_tokens=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + hidden_size=self.hidden_dim, + inter_size=self.ffn_dim, + topk=self.top_k, + onnx_dtype=self.onnx_dtype, + quant_bits=self.quant_bits, + fc1_experts_weights=fc1_experts_weights, + fc1_experts_bias=fc1_experts_bias, + fc2_experts_weights=fc2_experts_weights, + fc2_experts_bias=fc2_experts_bias, + fc1_experts_weight_scale=moe_experts_weight_scale1, + fc2_experts_weight_scale=moe_experts_weight_scale2, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + This is the robust PyTorch reference implementation. It directly uses the + nn.Module experts, which is cleaner and less error-prone than manual matmul. 
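+        Routing in this reference: the top-k router logits are selected first and then softmax-normalized,
+        and each selected expert's output is scaled by its normalized routing weight.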
+ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float) + + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +swiglu_test_cases = list( + itertools.product( + [1, 2], # batch_size + [1, 3], # sequence_length + quant_bits_list, # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") +class TestSwigluMoE(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=64, intermediate_size=256, num_experts_per_token=2, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.parity_check() + + +def has_bf16_moe(): + if "CUDAExecutionProvider" not in onnxruntime.get_available_providers() or not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 8 + + +@unittest.skipIf(not has_bf16_moe(), "skipping bf16 moe tests.") +class TestSwigluMoeBf16(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=64, intermediate_size=128, num_experts_per_token=2, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits, onnx_dtype=TensorProto.BFLOAT16) + moe.to(device) + moe.parity_check() + + +perf_test_cases = list( + itertools.product( + [1], # batch_size + [128, 512, 1024, 2048, 4096], # sequence_length + [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) + + +@unittest.skipIf(pipeline_mode or not use_cuda, "skipping performance test in CI pipeline.") +class TestSwigluMoEPerf(unittest.TestCase): + @parameterized.expand(perf_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + hidden_size = 2880 + intermediate_size = 2880 + num_experts_per_token = 8 + num_local_experts = 128 + config = SwigluMoeConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts_per_token=num_experts_per_token, + num_local_experts=num_local_experts, + ) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.benchmark_ort() if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py 
b/onnxruntime/test/python/transformers/test_qmoe_cpu.py new file mode 100644 index 0000000000000..efaaca29a01b6 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -0,0 +1,1118 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# +# QMoE quantization implementation notes: +# +# Both CPU and CUDA implementations use symmetric quantization centered around 0: +# - 4-bit: range [-8, 7] with no zero-point (symmetric around 0) +# - 8-bit: range [-128, 127] with no zero-point (symmetric around 0) +# +# This follows the _symmetric_quantize_last_axis_of_batched_matrix pattern. +# Tolerance values account for numerical differences between implementations. +# +# Routing Logic: CPU implementation uses top-k selection first, then softmax +# normalization on the selected experts. This provides proper weight distribution +# while maintaining computational efficiency. +# -------------------------------------------------------------------------- +import time +import unittest +from collections import OrderedDict + +import numpy +import torch +import torch.nn.functional as F +from onnx import helper +from parameterized import parameterized +from torch import nn + +import onnxruntime + +try: + from onnx import TensorProto + + has_onnx = True +except ImportError: + has_onnx = False + + class TensorProtoPlaceholder: + FLOAT16 = 10 + FLOAT = 1 + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "silu": nn.SiLU, + "gelu": nn.GELU, +} +ACT2FN = ClassInstantier(ACT2CLS) + +if not has_onnx: + + class TensorProtoPlaceholder: + FLOAT16 = 10 + FLOAT = 1 + UINT8 = 2 + + TensorProto = TensorProtoPlaceholder + +onnxruntime.preload_dlls() + +device = torch.device("cpu") + +ort_provider = ["CPUExecutionProvider"] + +torch.manual_seed(42) +numpy.random.seed(42) + +onnx_to_torch_type_map = { + TensorProto.FLOAT16: torch.float16, + TensorProto.FLOAT: torch.float, + TensorProto.UINT8: torch.uint8, +} + +ort_to_numpy_type_map = { + TensorProto.FLOAT16: numpy.float16, + TensorProto.FLOAT: numpy.float32, + TensorProto.UINT8: numpy.uint8, +} + +ort_dtype_name_map = { + TensorProto.FLOAT16: "FP16", + TensorProto.FLOAT: "FP32", +} + + +def quant_dequant(weights, is_4_bit_quantization: bool = True): + """ + Quantize and dequantize weights for testing purposes. + This function uses symmetric quantization centered around 0 (no zero-point). 
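+
+    For example (illustrative numbers, not taken from the tests): with abs_max = 0.7, the 4-bit scale is
+    0.7 / 7.0 = 0.1, so a weight of 0.33 quantizes to round(0.33 / 0.1) = 3 (clamped to [-8, 7]), is stored
+    as 3 + 8 = 11 before packing, and dequantizes back to 3 * 0.1 = 0.3.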
+ + This uses symmetric quantization similar to _symmetric_quantize_last_axis_of_batched_matrix: + - 4-bit: range = [-8, 7], no zero-point (symmetric around 0) + - 8-bit: range = [-128, 127], no zero-point (symmetric around 0) + """ + # Handle edge case of all-zero weights tensor + if torch.all(weights == 0): + if is_4_bit_quantization: + packed_size = (weights.shape[-1] + 1) // 2 + return ( + torch.zeros_like(weights[..., 0:1]), + torch.zeros( + (weights.shape[0], weights.shape[1], packed_size), + dtype=torch.uint8, + device=weights.device, + ), + torch.zeros_like(weights), + ) + else: + return ( + torch.zeros_like(weights[..., 0:1]), + torch.zeros_like(weights, dtype=torch.uint8), + torch.zeros_like(weights), + ) + + # Calculate scale like C++ implementation + abs_max = weights.abs().max(dim=-1, keepdim=True)[0] + abs_max = torch.clamp(abs_max, min=1e-8) # More conservative clamping for better precision + + if is_4_bit_quantization: + # 4-bit: scale = abs_max / 7.0 (using 7.0 as max positive value for symmetric range) + # Use higher precision computation for better accuracy + scale = (abs_max.double() / 7.0).float() + 1e-12 + + # Handle potential edge cases for zero or very small weights + if torch.max(abs_max) < 1e-8: + packed_size = (weights.shape[-1] + 1) // 2 + return ( + torch.ones_like(weights[..., 0:1]) * 1e-8, + torch.zeros( + (weights.shape[0], weights.shape[1], packed_size), + dtype=torch.uint8, + device=weights.device, + ), + torch.zeros_like(weights), + ) + + # Quantize: round(weight / scale) then clamp to [-8, 7] + # Use higher precision for the division to reduce accumulated errors + scaled_weights = weights.double() / scale.double() + quantized_weights = torch.round(scaled_weights).clamp(-8, 7).float() + + # For symmetric quantization, we use signed int4 representation + # Convert to uint8 storage for packing: shift [-8,7] -> [0,15] for storage only + storage_weights = (quantized_weights + 8).to(torch.uint8) + + # Pack 4-bit values into uint8 (every two elements) + even_indices = torch.arange(0, weights.shape[-1], 2) + odd_indices = torch.arange(1, weights.shape[-1], 2) + + # Handle odd length by padding with zero (which is 8 in storage representation) + if odd_indices.shape[0] < even_indices.shape[0]: + padding = torch.full( + (storage_weights.shape[0], storage_weights.shape[1], 1), + fill_value=8, # 0 in symmetric quantization, stored as 8 + dtype=torch.uint8, + device=storage_weights.device, + ) + storage_weights = torch.cat([storage_weights, padding], dim=-1) + odd_indices = torch.arange(1, storage_weights.shape[-1], 2) + + even_weights = storage_weights[..., even_indices] + odd_weights = storage_weights[..., odd_indices] + + # Pack: low nibble = even, high nibble = odd + packed_weights = (even_weights & 0xF) | ((odd_weights & 0xF) << 4) + + # Dequantize: scale * quantized_value (no zero-point subtraction) + # Unpack for dequantization + lower = packed_weights & 0xF + upper = (packed_weights >> 4) & 0xF + + # Restore original shape and convert back to signed representation + unpacked_weights = torch.zeros_like(weights, dtype=torch.uint8) + unpacked_weights[..., even_indices] = lower + + valid_odd_length = min(odd_indices.shape[0], weights.shape[-1] - even_indices.shape[0]) + if valid_odd_length > 0: + valid_odd_indices = odd_indices[:valid_odd_length] + unpacked_weights[..., valid_odd_indices] = upper[..., :valid_odd_length] + + # Convert back to signed values: [0,15] -> [-8,7] and apply scale + signed_weights = unpacked_weights.float() - 8.0 # Convert storage back to 
signed + dequant_scale = scale.float() # Ensure FP32 precision for computation + result = dequant_scale * signed_weights # No zero-point in symmetric quantization + + return scale.to(torch.float16), packed_weights, result.to(weights.dtype) + else: + # 8-bit: scale = abs_max / 127.0 (using 127.0 as max positive value for symmetric range) + # Use higher precision computation for better accuracy + scale = (abs_max.double() / 127.0).float() + 1e-12 + + # Handle potential edge cases for zero or very small weights + if torch.max(abs_max) < 1e-8: + return ( + torch.ones_like(weights[..., 0:1]) * 1e-8, + torch.zeros_like(weights, dtype=torch.uint8), + torch.zeros_like(weights), + ) + + # Quantize: round(weight / scale) then clamp to [-128, 127] + # Use higher precision for the division to reduce accumulated errors + scaled_weights = weights.double() / scale.double() + quantized_weights = torch.round(scaled_weights).clamp(-128, 127).float() + + # For symmetric quantization, we use signed int8 representation + # Convert to uint8 storage: shift [-128,127] -> [0,255] for storage only + storage_weights = (quantized_weights + 128).to(torch.uint8) + + # Dequantize: scale * quantized_value (no zero-point subtraction) + # Convert back to signed values: [0,255] -> [-128,127] and apply scale + signed_weights = storage_weights.float() - 128.0 # Convert storage back to signed + dequant_scale = scale.float() # Ensure FP32 precision for computation + result = dequant_scale * signed_weights # No zero-point in symmetric quantization + + return scale.to(torch.float16), storage_weights, result.to(weights.dtype) + + +def create_cpu_moe_onnx_graph( + hidden_size, + sequence_length, + num_experts, + top_k, + intermediate_size, + torch_dtype, + onnx_dtype, + fc1_experts_weights, + fc2_experts_weights, + fc1_bias=None, + fc2_bias=None, + fc1_scales=None, + fc2_scales=None, + use_swiglu=False, + use_quant=False, + quant_bits=4, + swiglu_interleaved=False, +): + if not has_onnx: + return None + + inter_size = intermediate_size + topk = top_k + + use_quant = True + + if fc1_scales is None and use_quant: + return None + if fc2_scales is None and use_quant: + return None + if not has_onnx: + return None + + assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" + assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" + assert fc1_scales is not None, "FC1 scales must be provided for QMoE" + assert fc2_scales is not None, "FC2 scales must be provided for QMoE" + assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" + assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" + + if not has_onnx: + return None + + op_name = "QMoE" + inputs = [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + ] + + activation = "swiglu" if use_swiglu else "silu" + + nodes = [ + helper.make_node( + op_name, + inputs, + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, # Use proper routing normalization to match PyTorch behavior + activation_type=activation, + # Add new attributes with backwards-compatible default values + swiglu_fusion=1 if (use_swiglu and swiglu_interleaved) else 0, # 1 = fused and interleaved + swiglu_limit=7.0, + activation_alpha=1.702, + activation_beta=1.0, + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + + # Weights are store in column 
major order. Need pack 2 int4 values into uint8. + # Use the actual tensor shapes instead of calculating them to avoid size mismatches + fc1_shape = list(fc1_experts_weights.shape) + fc2_shape = list(fc2_experts_weights.shape) + + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + + weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype] + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + weight_onnx_type, + fc1_shape, + fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + ), + helper.make_tensor( + "fc2_experts_weights", + weight_onnx_type, + fc2_shape, + fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + ), + ] + + fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] + fc2_scale_shape = [num_experts, hidden_size] + + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) + fc2_scale_size = num_experts * hidden_size + + # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions + # Handle different possible scale tensor structures for fc1_scales + if len(fc1_scales.shape) == 4: + # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output + if use_swiglu: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy() + else: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc1_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size: + # For SwiGLU, duplicate the scales to cover both gate and value components + fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten() + elif fc1_scale_tensor.size > fc1_scale_size: + # Truncate to expected size + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc1_scale_tensor.size > fc1_scale_size: + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + elif fc1_scale_tensor.size < fc1_scale_size: + # Pad with ones if too small + pad_size = fc1_scale_size - fc1_scale_tensor.size + fc1_scale_tensor = numpy.concatenate([fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)]) + + # Process scale tensor for proper shape + fc1_scale_data_list = fc1_scale_tensor.tolist() + fc1_scale_data = fc1_scale_data_list + + # Handle different possible scale tensor structures for fc2_scales + if len(fc2_scales.shape) == 4: + # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output + fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc2_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + # Truncate to expected size + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if 
fc2_scale_tensor.size > fc2_scale_size: + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + elif fc2_scale_tensor.size < fc2_scale_size: + # Pad with ones if too small + pad_size = fc2_scale_size - fc2_scale_tensor.size + fc2_scale_tensor = numpy.concatenate([fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)]) + + # Process scale tensor for proper shape + fc2_scale_data_list = fc2_scale_tensor.tolist() + fc2_scale_data = fc2_scale_data_list + + initializers.extend( + [ + helper.make_tensor( + "fc1_scales", + onnx_dtype, + fc1_scale_shape, + fc1_scale_data, + raw=False, + ), + helper.make_tensor( + "fc2_scales", + onnx_dtype, + fc2_scale_shape, + fc2_scale_data, + raw=False, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + onnx_dtype, + [sequence_length, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +class PhiMoEConfig: + def __init__( + self, + hidden_size=4096, + intermediate_size=14336, + hidden_act="silu", + num_experts_per_tok=2, + num_local_experts=8, + router_jitter_noise=0.01, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_jitter_noise = router_jitter_noise + + +class SwigluMoeConfig: + def __init__( + self, + hidden_size=4096, + intermediate_size=14336, + num_local_experts=8, + num_experts_per_token=2, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_local_experts = num_local_experts + self.num_experts_per_token = num_experts_per_token + + +def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): + dim = x.shape[-1] + x = x.view(-1, dim // 2, 2) + x_glu, x_linear = x[..., 0], x[..., 1] + + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) + return y + + +class MoEBlockSparseTop2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class PhiMoEBlockSparseTop2MLP(MoEBlockSparseTop2MLP): + def __init__(self, config: PhiMoEConfig): + super().__init__(config) + + +class PhiMoESwiGLUMLP(nn.Module): + """ + Phi3 MoE expert converted to 2-weight SwiGLU structure for CPU compatibility. + This converts the traditional 3-weight Phi3 structure to SwiGLU format. 
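+    The single w1 projection emits 2 * intermediate_size features that swiglu()
+    consumes as interleaved (gate, linear) pairs, matching the swiglu_interleaved=True
+    layout used by the CPU QMoE tests below.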
+ """ + + def __init__(self, config: PhiMoEConfig): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +def masked_sampling_omp_inference(scores, top_k, jitter_eps, training): + """ + Updated to match the CUDA implementation's routing logic for fair comparison. + This now uses the same complex jitter-based masking approach as the CUDA tests. + """ + assert top_k == 2 + assert not training + + mask_logits_threshold, selected_experts = torch.topk(scores, 2) + + mask_logits_threshold_1 = mask_logits_threshold[:, 0].unsqueeze(-1) + + factor = scores.abs().clamp(min=mask_logits_threshold_1) + logits_mask = ((mask_logits_threshold_1 - scores) / factor) > (2 * jitter_eps) + + multiplier_1 = torch.softmax(scores.masked_fill(logits_mask, float("-inf")), dim=-1).gather( + dim=-1, index=selected_experts[:, 0].unsqueeze(-1) + ) + + mask_logits_threshold_2 = mask_logits_threshold[:, 1].unsqueeze(-1) + + factor = scores.abs().clamp(min=mask_logits_threshold_2) + logits_mask = ((mask_logits_threshold_2 - scores) / factor) > (2 * jitter_eps) + + multiplier_2 = torch.softmax( + torch.scatter(scores, -1, selected_experts[:, 0].unsqueeze(-1), float("-inf")).masked_fill( + logits_mask, float("-inf") + ), + dim=-1, + ).gather(dim=-1, index=selected_experts[:, 1].unsqueeze(-1)) + + multiplier = torch.concat((multiplier_1, multiplier_2), dim=-1) + + return ( + multiplier, + selected_experts, + ) + + +class SparseMoeBlockORTHelper(nn.Module): + def __init__(self, quant_bits=0, onnx_dtype=None): + super().__init__() + self.quant_bits = quant_bits + if onnx_dtype is None: + self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + else: + self.onnx_dtype = onnx_dtype + self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 + + def create_ort_session(self, moe_onnx_graph): + if moe_onnx_graph is None: + return None + + sess_options = onnxruntime.SessionOptions() + sess_options.log_severity_level = 2 + + try: + ort_session = onnxruntime.InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + except Exception: + return None + + return ort_session + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + pass + + def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: + if self.ort_sess is None: + return None + + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states_flat) + + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] + + tensors = { + "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros_like(hidden_states_flat, device=device, 
dtype=torch_dtype), + } + + try: + iobinding = self.ort_sess.io_binding() + + for name, tensor in tensors.items(): + if name == "output": + iobinding.bind_output( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + else: + iobinding.bind_input( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + + if enable_performance_test: + repeat = 100 + s = time.time() + for _ in range(repeat): + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + e = time.time() + time_ms = (e - s) / repeat * 1000 + is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu + is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" + print(f"ORT Performance - {act_type} {self.quant_bits}-bit: {time_ms:.3f} ms/inference") + + return tensors["output"].reshape(batch_size, sequence_length, hidden_dim) + + except Exception as e: + raise + + def recreate_onnx_model(self): + """Recreate the ONNX model with the current weights to reflect any changes to the quantization code.""" + + w1_list, w2_list = [], [] + w1_scale_list, w2_scale_list = [], [] + + is_4_bit = self.quant_bits == 4 + for i in range(self.num_experts): + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) + + if self.use_swiglu: + if self.swiglu_interleaved: + pass + else: + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) + + gate_weights = pre_qweight1 + value_weights = pre_qweight3 + gate_scales = w1_scale + value_scales = w3_scale + + pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) + w1_scale = torch.cat([gate_scales, value_scales], dim=0) + + if self.swiglu_interleaved: + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) + + else: + intermediate_size = self.experts[i].w1.weight.shape[0] + gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() + value_dequant = w1_qdq[intermediate_size:].contiguous().clone() + self.experts[i].w1.weight.data = gate_dequant + self.experts[i].w3.weight.data = value_dequant + else: + self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() + + self.experts[i].w2.weight.data = w2_qdq.contiguous().clone() + + w1_list.append(pre_qweight1) + w2_list.append(pre_qweight2) + w1_scale_list.append(w1_scale) + w2_scale_list.append(w2_scale) + + self.moe_experts_weight1 = torch.stack(w1_list, dim=0) + self.moe_experts_weight2 = torch.stack(w2_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) + moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) + + if moe_experts_weight_scale1.dim() == 3: + moe_experts_weight_scale1 = moe_experts_weight_scale1.squeeze(-1) + if moe_experts_weight_scale2.dim() == 3: + moe_experts_weight_scale2 = moe_experts_weight_scale2.squeeze(-1) + + try: + self.moe_onnx_graph = create_cpu_moe_onnx_graph( + hidden_size=self.hidden_dim, + sequence_length=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + top_k=self.top_k, + 
intermediate_size=self.ffn_dim, + torch_dtype=torch.float32, + onnx_dtype=self.onnx_dtype, + fc1_experts_weights=self.moe_experts_weight1, + fc2_experts_weights=self.moe_experts_weight2, + # Biases are not used in QMoE + fc1_bias=None, + fc2_bias=None, + # Scales are used for dequantization + fc1_scales=moe_experts_weight_scale1, + fc2_scales=moe_experts_weight_scale2, + use_swiglu=self.use_swiglu, + use_quant=True, # Always use QMoE + quant_bits=self.quant_bits, + swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, + ) + except Exception: + self.moe_onnx_graph = None + return False + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None + return self.ort_sess is not None + + def parity_check(self): + model_updated = self.recreate_onnx_model() + if not model_updated: + return + + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + torch_output = self.forward(hidden_state) + ort_output = self.ort_forward(hidden_state) + + if ort_output is None: + return + + torch_has_nan = torch.isnan(torch_output).any() + ort_has_nan = torch.isnan(ort_output).any() + torch_has_inf = torch.isinf(torch_output).any() + ort_has_inf = torch.isinf(ort_output).any() + + if torch_has_nan or ort_has_nan or torch_has_inf or ort_has_inf: + torch_output_clean = torch.where( + torch.isnan(torch_output) | torch.isinf(torch_output), torch.zeros_like(torch_output), torch_output + ) + ort_output_clean = torch.where( + torch.isnan(ort_output) | torch.isinf(ort_output), torch.zeros_like(ort_output), ort_output + ) + max_diff = (torch_output_clean.cpu() - ort_output_clean.cpu()).abs().max() + + if (torch_has_nan and ort_has_nan) or (torch_has_inf and ort_has_inf): + problematic_torch = torch.isnan(torch_output) | torch.isinf(torch_output) + problematic_ort = torch.isnan(ort_output) | torch.isinf(ort_output) + if torch.equal(problematic_torch, problematic_ort): + max_diff = 0.0 + else: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max() + + is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu + is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" + + print(f"Parity check - {act_type} {self.quant_bits}-bit: max_diff = {max_diff:.6f}") + + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + "FP16:0": (5e-2, 1e-3), + "FP16:4": (0.05, 0.01), + "FP16:8": (0.02, 0.01), + "FP32:4": (0.11, 0.01), + "FP32:8": (0.11, 0.01), + } + + dtype_str = ort_dtype_name_map[self.onnx_dtype] + tolerance_key = f"{dtype_str}:{self.quant_bits}" + if tolerance_key in ort_dtype_quant_bits_tolerance_map: + base_atol, rtol = ort_dtype_quant_bits_tolerance_map[tolerance_key] + + if max_diff > base_atol: + raise AssertionError( + f"QMoE parity check failed: max difference {max_diff:.6f} exceeds " + f"tolerance {base_atol:.6f} for {tolerance_key}" + ) + else: + fallback_atol = 0.1 + if max_diff > fallback_atol: + raise AssertionError( + f"QMoE parity check failed: max difference {max_diff:.6f} exceeds " + f"fallback tolerance {fallback_atol:.6f} for unknown config {tolerance_key}" + ) + + def benchmark_ort(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) + + +def small_test_cases(): + for batch_size in [1, 4]: + for sequence_length in [32, 128]: + yield batch_size, sequence_length 
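+
+
+# Illustrative sketch only (not used by the tests in this file): a minimal,
+# self-contained demonstration of the 4-bit symmetric packing scheme that
+# quant_dequant() implements above. Assumptions: per-row scale = abs_max / 7,
+# values clamped to [-8, 7] and stored as value + 8, with the even-indexed
+# element in the low nibble and the odd-indexed element in the high nibble of
+# each uint8. The helper name is hypothetical and the column count is assumed even.
+def _demo_int4_symmetric_roundtrip():
+    weights = torch.randn(2, 3, 8)  # [experts, rows, cols], cols assumed even
+    scale = weights.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-8) / 7.0
+    quantized = torch.round(weights / scale).clamp(-8, 7)
+    storage = (quantized + 8).to(torch.uint8)  # shift [-8, 7] -> [0, 15] for storage
+    packed = (storage[..., 0::2] & 0xF) | ((storage[..., 1::2] & 0xF) << 4)
+    # Unpack: low nibble holds even columns, high nibble holds odd columns.
+    unpacked = torch.zeros_like(storage)
+    unpacked[..., 0::2] = packed & 0xF
+    unpacked[..., 1::2] = (packed >> 4) & 0xF
+    dequantized = (unpacked.float() - 8.0) * scale
+    # The pack/unpack round trip reproduces the quantized values exactly.
+    assert torch.equal(dequantized, quantized * scale)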
+ + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + self.use_swiglu = True + self.swiglu_interleaved = True + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + self.batch_size = batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = None + self.ort_sess = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float) + + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: PhiMoEConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + self.router_jitter_noise = config.router_jitter_noise + self.use_swiglu = True + self.swiglu_interleaved = True + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([PhiMoESwiGLUMLP(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, 
scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = create_cpu_moe_onnx_graph( + hidden_size=self.hidden_dim, + sequence_length=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + top_k=self.top_k, + intermediate_size=self.ffn_dim, + torch_dtype=torch.float32, + onnx_dtype=self.onnx_dtype, + fc1_experts_weights=fc1_experts_weights, + fc2_experts_weights=fc2_experts_weights, + fc1_bias=fc1_experts_bias, + fc2_bias=fc2_experts_bias, + fc1_scales=moe_experts_weight_scale1, + fc2_scales=moe_experts_weight_scale2, + use_swiglu=self.use_swiglu, + use_quant=use_quant, + quant_bits=self.quant_bits, + swiglu_interleaved=self.swiglu_interleaved, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """PyTorch reference forward pass using SwiGLU-style routing""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +disable_cpu_qmoe_tests = False + +# Define test cases for different MoE types +phi3_test_cases = [ + (1, 32, 4), + (1, 32, 8), + (2, 16, 4), + (2, 16, 8), +] + + +@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") +class TestPhiQMoECPU(unittest.TestCase): + @parameterized.expand(phi3_test_cases) + def 
test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + print(f"Running Phi3 QMoE test: {test_config}") + + config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) + + phi3_moe = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = phi3_moe.forward(hidden_states) + + # Verify output shape and basic properties + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + phi3_moe.parity_check() + + +disable_cpu_qmoe_tests = False + +swiglu_test_cases = [ + (1, 32, 4), + (1, 32, 8), + (2, 16, 4), + (2, 16, 8), +] + + +@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") +class TestSwigluQMoECPU(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + print(f"Running SwiGLU test: {test_config}") + + config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2) + + swiglu_moe = SwigluMoEBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = swiglu_moe.forward(hidden_states) + + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + swiglu_moe.parity_check() + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 786c0ba713b85..b7a9da8e1b658 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1901,14 +1901,6 @@ TEST(CApiTest, test_pyop_kwarg) { } #endif -#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS -TEST(CApiTest, create_session_without_session_option) { - constexpr PATH_TYPE model_uri = TSTR("../models/opset8/test_squeezenet/model.onnx"); - Ort::Session ret(*ort_env, model_uri, Ort::SessionOptions{nullptr}); - ASSERT_NE(nullptr, ret); -} -#endif - #ifdef REDUCED_OPS_BUILD TEST(ReducedOpsBuildTest, test_excluded_ops) { // In reduced ops build, test a model containing ops not included in required_ops.config cannot be loaded. diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index 0fe747cdd84e5..cffa0efc39d45 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -420,7 +420,7 @@ TEST(ModelEditorAPITest, BasicModelEdit_CxxApi) { // typically this isn't needed. 
we replace this input but need to read info from it later on in the test // validation so we save the info locally to keep it accessible. - auto orig_input_name = graph_inputs[0].Name(); + auto orig_input_name = graph_inputs[0].GetName(); auto input_shape = graph_inputs[0].TypeInfo().GetTensorTypeAndShapeInfo().GetShape(); const std::string new_input_name = "Int64Input"; @@ -589,7 +589,7 @@ TEST(ModelEditorAPITest, InvalidModelEdit) { Node node("Cast", domain, "NewInputNode", {new_input_name}, // the existing node will now consume the output from the Cast instead of a graph input - {graph_inputs[0].Name()}, + {graph_inputs[0].GetName()}, attributes); graph.AddNode(node);