Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public class OrtModelCompilationOptions : SafeHandle
/// <summary>
/// Create a new OrtModelCompilationOptions object from SessionOptions.
/// </summary>
/// <remarks>By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use SetGraphOptimizationLevel()
/// to enable graph optimizations.</remarks>
/// <param name="sessionOptions">SessionOptions instance to read settings from.</param>
public OrtModelCompilationOptions(SessionOptions sessionOptions)
: base(IntPtr.Zero, true)
Expand Down Expand Up @@ -130,6 +132,33 @@ public void SetFlags(OrtCompileApiFlags flags)
NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags));
}

/// <summary>
/// Supplies the EP context binary file information used by the EP. When both the input model and
/// the output model live in memory buffers, the EP relies on these values to choose where the
/// context binary file is written and what it is called.
/// </summary>
/// <param name="outputDirectory">Path to the model directory.</param>
/// <param name="modelName">The name of the model.</param>
public void SetEpContextBinaryInformation(string outputDirectory, string modelName)
{
    // Both strings are converted to the platform's serialized form expected by the native API.
    IntPtr status = NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextBinaryInformation(
        handle,
        NativeOnnxValueHelper.GetPlatformSerializedString(outputDirectory),
        NativeOnnxValueHelper.GetPlatformSerializedString(modelName));

    NativeApiStatus.VerifySuccess(status);
}

/// <summary>
/// Chooses the graph optimization level used for compilation. If this method is never called,
/// the level stays at ORT_DISABLE_ALL.
/// </summary>
/// <param name="graphOptimizationLevel">The graph optimization level to apply.</param>
public void SetGraphOptimizationLevel(GraphOptimizationLevel graphOptimizationLevel)
{
    IntPtr status = NativeMethods.CompileApi.OrtModelCompilationOptions_SetGraphOptimizationLevel(
        handle, graphOptimizationLevel);

    NativeApiStatus.VerifySuccess(status);
}

internal IntPtr Handle => handle;


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public struct OrtCompileApi
public IntPtr ModelCompilationOptions_SetEpContextEmbedMode;
public IntPtr CompileModel;
public IntPtr ModelCompilationOptions_SetFlags;
public IntPtr ModelCompilationOptions_SetEpContextBinaryInformation;
public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel;
}

internal class NativeMethods
Expand Down Expand Up @@ -101,6 +103,21 @@ public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile
uint flags);
public DOrtModelCompilationOptions_SetFlags OrtModelCompilationOptions_SetFlags;

/// <summary>
/// Managed signature of the native OrtCompileApi entry ModelCompilationOptions_SetEpContextBinaryInformation.
/// </summary>
/// <param name="options">Native OrtModelCompilationOptions* handle.</param>
/// <param name="outputDirectory">Output directory as a platform-serialized string (const ORTCHAR_T*).</param>
/// <param name="modelName">Model name as a platform-serialized string (const ORTCHAR_T*).</param>
/// <returns>Native OrtStatus* describing the outcome of the call.</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextBinaryInformation(
IntPtr /* OrtModelCompilationOptions* */ options,
byte[] /* const ORTCHAR_T* */ outputDirectory,
byte[] /* const ORTCHAR_T* */ modelName);
// Delegate instance; bound in the NativeMethods constructor to the function pointer from the compile API table.
public DOrtModelCompilationOptions_SetEpContextBinaryInformation
OrtModelCompilationOptions_SetEpContextBinaryInformation;

/// <summary>
/// Managed signature of the native OrtCompileApi entry ModelCompilationOptions_SetGraphOptimizationLevel.
/// </summary>
/// <param name="options">Native OrtModelCompilationOptions* handle.</param>
/// <param name="graphOptimizationLevel">Graph optimization level passed through to the native call.</param>
/// <returns>Native OrtStatus* describing the outcome of the call.</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetGraphOptimizationLevel(
IntPtr /* OrtModelCompilationOptions* */ options,
GraphOptimizationLevel graphOptimizationLevel);
// Delegate instance; bound in the NativeMethods constructor to the function pointer from the compile API table.
public DOrtModelCompilationOptions_SetGraphOptimizationLevel
OrtModelCompilationOptions_SetGraphOptimizationLevel;

internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi)
{

Expand Down Expand Up @@ -161,6 +178,16 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi
_compileApi.ModelCompilationOptions_SetFlags,
typeof(DOrtModelCompilationOptions_SetFlags));

OrtModelCompilationOptions_SetEpContextBinaryInformation =
(DOrtModelCompilationOptions_SetEpContextBinaryInformation)Marshal.GetDelegateForFunctionPointer(
_compileApi.ModelCompilationOptions_SetEpContextBinaryInformation,
typeof(DOrtModelCompilationOptions_SetEpContextBinaryInformation));

OrtModelCompilationOptions_SetGraphOptimizationLevel =
(DOrtModelCompilationOptions_SetGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer(
_compileApi.ModelCompilationOptions_SetGraphOptimizationLevel,
typeof(DOrtModelCompilationOptions_SetGraphOptimizationLevel));

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public void BasicUsage()

compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512);
compileOptions.SetEpContextEmbedMode(true);
compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC);

}

Expand All @@ -45,6 +46,7 @@ public void BasicUsage()
UIntPtr bytesSize = new UIntPtr();
var allocator = OrtAllocator.DefaultInstance;
compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize);
compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx");

compileOptions.CompileModel();

Expand Down
18 changes: 17 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7074,6 +7074,9 @@ struct OrtCompileApi {
* ReleaseModelCompilationOptions must be called to free the OrtModelCompilationOptions after calling
* CompileModel.
*
* \note By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use
* ModelCompilationOptions_SetGraphOptimizationLevel to enable graph optimizations.
*
* \param[in] env OrtEnv object.
* \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions.
* \param[out] out The created OrtModelCompilationOptions instance.
Expand Down Expand Up @@ -7230,7 +7233,7 @@ struct OrtCompileApi {
* \since Version 1.23.
*/
ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options,
size_t flags);
uint32_t flags);

/** Sets information related to EP context binary file.
*
Expand All @@ -7249,6 +7252,19 @@ struct OrtCompileApi {
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ const ORTCHAR_T* output_directory,
_In_ const ORTCHAR_T* model_name);

/** Set the graph optimization level.
*
* If this function is never called, the graph optimization level stays at its default of ORT_DISABLE_ALL.
* Accepted values are ORT_DISABLE_ALL, ORT_ENABLE_BASIC, ORT_ENABLE_EXTENDED, ORT_ENABLE_LAYOUT, and
* ORT_ENABLE_ALL. Any other value results in an error status. Builds without Compile API support
* (minimal builds) return a "not implemented" error status.
*
* \param[in] model_compile_options The OrtModelCompilationOptions instance.
* \param[in] graph_optimization_level The graph optimization level.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(ModelCompilationOptions_SetGraphOptimizationLevel,
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ GraphOptimizationLevel graph_optimization_level);
};

/*
Expand Down
4 changes: 3 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1424,7 +1424,9 @@ struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory,
const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation
ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags

ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel
};

/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtCompileApi::CompileModel.
Expand Down
9 changes: 8 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1019,11 +1019,18 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
return *this;
}

inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(size_t flags) {
inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(uint32_t flags) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetFlags(this->p_, flags));
return *this;
}

/// Forwards the requested graph optimization level to the compile API and throws if the call
/// reports an error. Returns *this so that setter calls can be chained.
inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLevel(
    GraphOptimizationLevel graph_optimization_level) {
  const auto& compile_api = GetCompileApi();
  auto* status = compile_api.ModelCompilationOptions_SetGraphOptimizationLevel(this->p_, graph_optimization_level);
  Ort::ThrowOnError(status);
  return *this;
}

namespace detail {

template <typename T>
Expand Down
19 changes: 18 additions & 1 deletion onnxruntime/core/session/compile_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode
}

ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags,
_In_ OrtModelCompilationOptions* ort_model_compile_options, size_t flags) {
_In_ OrtModelCompilationOptions* ort_model_compile_options, uint32_t flags) {
API_IMPL_BEGIN
#if !defined(ORT_MINIMAL_BUILD)
auto model_compile_options = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);
Expand All @@ -245,6 +245,22 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags,
API_IMPL_END
}

// C API entry point: applies the graph optimization level to the compilation options object.
// Returns nullptr on success, otherwise an error status. Builds that define ORT_MINIMAL_BUILD
// do not support the Compile API and always return ORT_NOT_IMPLEMENTED.
ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
                    _In_ OrtModelCompilationOptions* ort_model_compile_options,
                    _In_ GraphOptimizationLevel graph_optimization_level) {
  API_IMPL_BEGIN
#if defined(ORT_MINIMAL_BUILD)
  ORT_UNUSED_PARAMETER(ort_model_compile_options);
  ORT_UNUSED_PARAMETER(graph_optimization_level);
  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build");
#else
  // The opaque C handle is the internal ModelCompilationOptions object.
  auto* compile_options = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);
  ORT_API_RETURN_IF_STATUS_NOT_OK(compile_options->SetGraphOptimizationLevel(graph_optimization_level));
  return nullptr;
#endif  // defined(ORT_MINIMAL_BUILD)
  API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env,
_In_ const OrtModelCompilationOptions* ort_model_compile_options) {
API_IMPL_BEGIN
Expand Down Expand Up @@ -278,6 +294,7 @@ static constexpr OrtCompileApi ort_compile_api = {

&OrtCompileAPI::ModelCompilationOptions_SetFlags,
&OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation,
&OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
};

// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/session/compile_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel
bool embed_ep_context_in_model);
ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options);
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options,
size_t flags);
uint32_t flags);
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options,
_In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name);
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetGraphOptimizationLevel,
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ GraphOptimizationLevel graph_optimization_level);

} // namespace OrtCompileAPI
32 changes: 31 additions & 1 deletion onnxruntime/core/session/model_compilation_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment&
// Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions.
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK());
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionsDisableModelCompile, "0").IsOK());

session_options_.value.graph_optimization_level = TransformerLevel::Default; // L0: required transformers only
}

void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) {
Expand Down Expand Up @@ -135,7 +137,7 @@ Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_m
return Status::OK();
}

Status ModelCompilationOptions::SetFlags(size_t flags) {
Status ModelCompilationOptions::SetFlags(uint32_t flags) {
EpContextModelGenerationOptions& options = session_options_.value.ep_context_gen_options;
options.error_if_output_file_exists = flags & OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS;
options.action_if_no_compiled_nodes =
Expand Down Expand Up @@ -170,6 +172,34 @@ void ModelCompilationOptions::ResetInputModelSettings() {
input_model_data_size_ = 0;
}

// Translates the public GraphOptimizationLevel value into the internal TransformerLevel and stores it
// in the session options used for compilation. An unrecognized value leaves the current setting
// untouched and yields an INVALID_ARGUMENT status.
Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
  onnxruntime::TransformerLevel transformer_level{};

  switch (graph_optimization_level) {
    case ORT_DISABLE_ALL:
      // TransformerLevel::Default runs only the required transformers.
      transformer_level = onnxruntime::TransformerLevel::Default;
      break;
    case ORT_ENABLE_BASIC:
      transformer_level = onnxruntime::TransformerLevel::Level1;
      break;
    case ORT_ENABLE_EXTENDED:
      transformer_level = onnxruntime::TransformerLevel::Level2;
      break;
    case ORT_ENABLE_LAYOUT:
      transformer_level = onnxruntime::TransformerLevel::Level3;
      break;
    case ORT_ENABLE_ALL:
      transformer_level = onnxruntime::TransformerLevel::MaxLevel;
      break;
    default:
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "graph_optimization_level with value ",
                             static_cast<int>(graph_optimization_level), " is invalid. Valid values are: ",
                             "ORT_DISABLE_ALL (0), ORT_ENABLE_BASIC (1), ORT_ENABLE_EXTENDED (2), ",
                             "ORT_ENABLE_LAYOUT (3), and ORT_ENABLE_ALL (99).");
  }

  // Commit only after the input has been validated.
  session_options_.value.graph_optimization_level = transformer_level;
  return Status::OK();
}

Status ModelCompilationOptions::ResetOutputModelSettings() {
EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
ep_context_gen_options.output_model_file_path.clear();
Expand Down
9 changes: 8 additions & 1 deletion onnxruntime/core/session/model_compilation_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class ModelCompilationOptions {
/// </summary>
/// <param name="flags">unsigned integer set to the bitwise OR of enabled flags.</param>
/// <returns>Status indicating success or an error</returns>
Status SetFlags(size_t flags);
Status SetFlags(uint32_t flags);

/// <summary>
/// Returns a reference to the session options object.
Expand Down Expand Up @@ -129,6 +129,13 @@ class ModelCompilationOptions {
/// <returns>input model buffer's size in bytes</returns>
size_t GetInputModelDataSize() const;

/// <summary>
/// Sets the graph optimization level for the underlying session that compiles the model.
/// </summary>
/// <param name="graph_optimization_level">The optimization level</param>
/// <returns>Status indicating success or an error (e.g., an unrecognized optimization level)</returns>
Status SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);

/// <summary>
/// Checks if the compilation options described by this object are valid.
/// </summary>
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/python/onnxruntime_inference_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ def __init__(
external_initializers_file_path: str | os.PathLike | None = None,
external_initializers_size_threshold: int = 1024,
flags: int = C.OrtCompileApiFlags.NONE,
graph_optimization_level: C.GraphOptimizationLevel = C.GraphOptimizationLevel.ORT_DISABLE_ALL,
):
"""
Creates a ModelCompiler instance.
Expand All @@ -663,6 +664,8 @@ def __init__(
is None or empty. Initializers larger than this threshold are stored in the external initializers file.
:param flags: Additional boolean options to enable. Set this parameter to a bitwise OR of
flags in onnxruntime.OrtCompileApiFlags.
:param graph_optimization_level: The graph optimization level.
Defaults to onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL.
"""
input_model_path: str | os.PathLike | None = None
input_model_bytes: bytes | None = None
Expand Down Expand Up @@ -694,6 +697,7 @@ def __init__(
external_initializers_file_path,
external_initializers_size_threshold,
flags,
graph_optimization_level,
)
else:
self._model_compiler = C.ModelCompiler(
Expand All @@ -704,6 +708,7 @@ def __init__(
external_initializers_file_path,
external_initializers_size_threshold,
flags,
graph_optimization_level,
)

def compile_to_file(self, output_model_path: str | None = None):
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_model_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr<PyModelCompi
bool embed_compiled_data_into_model,
const std::string& external_initializers_file_path,
size_t external_initializers_size_threshold,
size_t flags) {
uint32_t flags,
GraphOptimizationLevel graph_optimization_level) {
auto model_compiler = std::make_unique<PyModelCompiler>(env, sess_options, PrivateConstructorTag{});
ModelCompilationOptions& compile_options = model_compiler->model_compile_options_;

Expand All @@ -43,6 +44,8 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr<PyModelCompi
ORT_RETURN_IF_ERROR(compile_options.SetFlags(flags));
}

ORT_RETURN_IF_ERROR(compile_options.SetGraphOptimizationLevel(graph_optimization_level));

out = std::move(model_compiler);
return Status::OK();
}
Expand Down
7 changes: 5 additions & 2 deletions onnxruntime/python/onnxruntime_pybind_model_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ class PyModelCompiler {
/// <param name="embed_compiled_data_into_model">True to embed compiled binary data into EPContext nodes.</param>
/// <param name="external_initializers_file_path">The file into which to store initializers for non-compiled
/// nodes.</param>
/// <param name="flags">Flags from OrtCompileApiFlags</param>
/// <param name="external_initializers_size_threshold">Ignored if 'external_initializers_file_path' is empty.
/// Initializers with a size greater than this threshold are dumped into the external file.</param>
/// <param name="flags">Flags from OrtCompileApiFlags</param>
/// <param name="graph_opt_level">Optimization level for graph transformations on the model.
/// Defaults to ORT_DISABLE_ALL to allow EP to get the original loaded model.</param>
/// <returns>A Status indicating error or success.</returns>
static onnxruntime::Status Create(/*out*/ std::unique_ptr<PyModelCompiler>& out,
onnxruntime::Environment& env,
Expand All @@ -46,7 +48,8 @@ class PyModelCompiler {
bool embed_compiled_data_into_model = false,
const std::string& external_initializers_file_path = {},
size_t external_initializers_size_threshold = 1024,
size_t flags = 0);
uint32_t flags = 0,
GraphOptimizationLevel graph_opt_level = GraphOptimizationLevel::ORT_DISABLE_ALL);

// Note: Creation should be done via Create(). This constructor is public so that it can be called from
// std::make_shared().
Expand Down
Loading
Loading