Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Compile API: disable optimizations by default
  • Loading branch information
adrianlizarraga committed Jul 21, 2025
commit c46fd7a5bc6db25162f93e3909e9d670bdcd2337
16 changes: 16 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6900,6 +6900,9 @@ struct OrtCompileApi {
* ReleaseOrtModelCompilationsOptions must be called to free the OrtModelCompilationOptions after calling
* CompileModel.
*
* \note By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use
* ModelCompilationOptions_SetGraphOptimizationLevel to enable graph optimizations.
*
* \param[in] env OrtEnv object.
* \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions.
* \param[out] out The created OrtModelCompilationOptions instance.
Expand Down Expand Up @@ -7075,6 +7078,19 @@ struct OrtCompileApi {
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ const ORTCHAR_T* output_directory,
_In_ const ORTCHAR_T* model_name);

/** Set the graph optimization level.
*
* \param[in] model_compile_options The OrtModelCompilationOptions instance.
* \param[in] graph_optimization_level The graph optimization level.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(ModelCompilationOptions_SetGraphOptimizationLevel,
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ GraphOptimizationLevel graph_optimization_level);
};

/*
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,8 @@ struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory,
const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation
ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags

ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel
};

/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels.
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,13 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(size_t flags)
return *this;
}

inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLevel(
GraphOptimizationLevel graph_optimization_level) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetGraphOptimizationLevel(this->p_,
graph_optimization_level));
return *this;
}

namespace detail {

template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/qnn/qnn_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ struct QnnEpFactory : OrtEpFactory {
}

static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr,
const OrtMemoryDevice* memory_device,
const OrtMemoryDevice* /*memory_device*/,
const OrtKeyValuePairs* /*stream_options*/,
OrtSyncStreamImpl** ort_stream) noexcept {
auto& factory = *static_cast<QnnEpFactory*>(this_ptr);
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/core/session/compile_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,22 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags,
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
_In_ OrtModelCompilationOptions* ort_model_compile_options,
_In_ GraphOptimizationLevel graph_optimization_level) {
API_IMPL_BEGIN
#if !defined(ORT_MINIMAL_BUILD)
auto model_compile_options = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);
ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetGraphOptimizationLevel(graph_optimization_level));
return nullptr;
#else
ORT_UNUSED_PARAMETER(ort_model_compile_options);
ORT_UNUSED_PARAMETER(graph_optimization_level);
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build");
#endif // !defined(ORT_MINIMAL_BUILD)
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env,
_In_ const OrtModelCompilationOptions* ort_model_compile_options) {
API_IMPL_BEGIN
Expand Down Expand Up @@ -278,6 +294,7 @@ static constexpr OrtCompileApi ort_compile_api = {

&OrtCompileAPI::ModelCompilationOptions_SetFlags,
&OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation,
&OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
};

// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/compile_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,8 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOp
size_t flags);
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options,
_In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name);
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetGraphOptimizationLevel,
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ GraphOptimizationLevel graph_optimization_level);

} // namespace OrtCompileAPI
27 changes: 27 additions & 0 deletions onnxruntime/core/session/model_compilation_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment&
// Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions.
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK());
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionsDisableModelCompile, "0").IsOK());

session_options_.value.graph_optimization_level = TransformerLevel::Default; // L0: required transformers only
}

void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) {
Expand Down Expand Up @@ -170,6 +172,31 @@ void ModelCompilationOptions::ResetInputModelSettings() {
input_model_data_size_ = 0;
}

Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
switch (graph_optimization_level) {
case ORT_DISABLE_ALL:
session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Default;
break;
case ORT_ENABLE_BASIC:
session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level1;
break;
case ORT_ENABLE_EXTENDED:
session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level2;
break;
case ORT_ENABLE_LAYOUT:
session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level3;
break;
case ORT_ENABLE_ALL:
session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::MaxLevel;
break;
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "graph_optimization_level with value ",
static_cast<int>(graph_optimization_level), " is invalid");
}

return Status::OK();
}

Status ModelCompilationOptions::ResetOutputModelSettings() {
EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
ep_context_gen_options.output_model_file_path.clear();
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/session/model_compilation_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ class ModelCompilationOptions {
/// <returns>input model buffer's size in bytes</returns>
size_t GetInputModelDataSize() const;

/// <summary>
/// Sets the graph optimization level for the underlying session that compiles the model.
/// </summary>
/// <param name="graph_optimization_level">The optimization level</param>
/// <returns></returns>
Status SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);

/// <summary>
/// Checks if the compilation options described by this object are valid.
/// </summary>
Expand Down
36 changes: 23 additions & 13 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,20 @@ static void CreateTestModel(test::GetTestModelFn graph_builder,
// Helper that checks that a compiled model has the expected number of EPContext nodes.
static void CheckEpContextNodeCounts(const onnxruntime::Model& ep_ctx_model,
int expected_ep_context_node_count,
int expected_other_node_count) {
int expected_other_node_count,
bool optimizations_disabled = false) {
int ep_context_node_count = 0;
int non_ep_context_node_count = 0;
auto& ctx_graph = ep_ctx_model.MainGraph();
for (auto& node : ctx_graph.Nodes()) {
if (node.OpType() == "EPContext") {
++ep_context_node_count;
// validate the fix for the partition issue relate to QDQ model
ASSERT_EQ(node.InputDefs().size(), 1);
// When optimizations are enabled, constant folding (of QuantizeLinear) will ensure all EPContext nodes
// have 1 input. When optimizations are off, 1 EPContext node will have 2 inputs, so don't check.
if (!optimizations_disabled) {
ASSERT_EQ(node.InputDefs().size(), 1);
}
} else {
++non_ep_context_node_count;
}
Expand All @@ -255,22 +260,26 @@ static void CheckEpContextNodeCounts(const onnxruntime::Model& ep_ctx_model,
// Helper to check that a compiled model (stored as a file) has the expected number of EPContext nodes.
static void CheckEpContextNodeCounts(const ORTCHAR_T* model_path,
int expected_ep_context_node_count,
int expected_other_node_count) {
int expected_other_node_count,
bool optimizations_disabled = false) {
std::shared_ptr<Model> ep_ctx_model;
ASSERT_STATUS_OK(Model::Load(ToPathString(model_path), ep_ctx_model, nullptr, DefaultLoggingManager().DefaultLogger()));
CheckEpContextNodeCounts(*ep_ctx_model, expected_ep_context_node_count, expected_other_node_count);
CheckEpContextNodeCounts(*ep_ctx_model, expected_ep_context_node_count, expected_other_node_count,
optimizations_disabled);
}

// Helper to check that a compiled model (stored in a buffer) has the expected number of EPContext nodes.
static void CheckEpContextNodeCounts(void* model_buffer, size_t model_buffer_size,
int expected_ep_context_node_count,
int expected_other_node_count) {
int expected_other_node_count,
bool optimizations_disabled = false) {
std::shared_ptr<Model> ep_ctx_model;
const ORTCHAR_T* output_model_path = ORT_TSTR("tmp_output_ctx_model.onnx");
ASSERT_STATUS_OK(onnxruntime::Model::LoadFromBytes(static_cast<int>(model_buffer_size),
model_buffer, output_model_path, ep_ctx_model,
nullptr, DefaultLoggingManager().DefaultLogger()));
CheckEpContextNodeCounts(*ep_ctx_model, expected_ep_context_node_count, expected_other_node_count);
CheckEpContextNodeCounts(*ep_ctx_model, expected_ep_context_node_count, expected_other_node_count,
optimizations_disabled);
std::filesystem::remove(output_model_path);
}

Expand Down Expand Up @@ -317,14 +326,15 @@ TEST_F(QnnHTPBackendTests, CompileApi_DisableEpCompile_ThenCompileExplicitly) {
Ort::ModelCompilationOptions compile_options(*ort_env, so);
compile_options.SetInputModelPath(input_model_file);
compile_options.SetOutputModelPath(output_model_file);
compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);

// Compile the model.
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage();

// Make sure the compiled model was generated and has the expected number of EPContext nodes.
ASSERT_TRUE(std::filesystem::exists(output_model_file));
CheckEpContextNodeCounts(output_model_file, 2, 2);
CheckEpContextNodeCounts(output_model_file, 2, 2, /*optimizations_disabled*/ false);

// Should be able to create a session with the compiled model and the original session options.
EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_file, so)));
Expand Down Expand Up @@ -362,7 +372,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelFromPath) {

// Make sure the compiled model was generated and has the expected number of EPContext nodes.
ASSERT_TRUE(std::filesystem::exists(output_model_file));
CheckEpContextNodeCounts(output_model_file, 2, 2);
CheckEpContextNodeCounts(output_model_file, 2, 2, /*optimizations_disabled*/ true);

// Should be able to create a session with the compiled model and the original session options.
EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_file, so)));
Expand Down Expand Up @@ -400,7 +410,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelAsBuffer_Embe

// Make sure the compiled model was generated and has the expected number of EPContext nodes.
ASSERT_TRUE(std::filesystem::exists(output_model_file));
CheckEpContextNodeCounts(output_model_file, 2, 2);
CheckEpContextNodeCounts(output_model_file, 2, 2, /*optimizations_disabled*/ true);

// Should be able to create a session with the compiled model and the original session options.
EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_file, so)));
Expand Down Expand Up @@ -443,7 +453,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer) {
ASSERT_TRUE(output_model_buffer_size > 0);

// Check that the compiled model has the expected number of EPContext nodes.
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2, /*optimizations_disabled*/ true);

{
// Should be able to create a session with the compiled model and the original session options.
Expand Down Expand Up @@ -492,7 +502,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB
ASSERT_TRUE(output_model_buffer_size > 0);

// Check that the compiled model has the expected number of EPContext nodes.
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2, /*optimizations_disabled*/ true);

// Should be able to create a session with the compiled model and the original session options.
EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options)));
Expand Down Expand Up @@ -527,7 +537,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB
ASSERT_TRUE(std::filesystem::exists(target_dir + bin_file_name)) << "expected context binary file should exist";

// Check that the compiled model has the expected number of EPContext nodes.
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2, /*optimizations_disabled*/ true);

// Add session option "ep.context_file_path" so that the session can use it to locate the [model_name]_qnn.bin file
std::string ctx_model = target_dir + model_name;
Expand Down Expand Up @@ -586,7 +596,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu
ASSERT_TRUE(std::filesystem::exists(output_initializers_file));

// Check that the compiled model has the expected number of EPContext nodes.
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2, /*optimizations_disabled*/ true);

// Should be able to create a session with the compiled model and the original session options.
EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, so)));
Expand Down
Loading