Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
033a887
Add func typedefs
adrianlizarraga Jul 9, 2025
fcdb5cf
Merge branch 'main' into adrianl/compile-api-output-stream
adrianlizarraga Jul 9, 2025
03eb5fa
stub apis
adrianlizarraga Jul 15, 2025
3310968
merge main
adrianlizarraga Jul 17, 2025
c3693de
new branch. add 2 streams first
adrianlizarraga Jul 19, 2025
a69d5f9
Move away from using Graph's graph_proto_ member
adrianlizarraga Jul 19, 2025
5743dcd
fix deref assignment
adrianlizarraga Jul 20, 2025
fd87e0c
Clean up
adrianlizarraga Jul 21, 2025
a40f463
Merge branch 'main' into adrianl/compile-api-output-stream
adrianlizarraga Jul 21, 2025
0dadf4d
Use std::filesystem::path in ModelCompilationOptions; fix memleak in …
adrianlizarraga Jul 21, 2025
d94cf44
fix unused variable warning (as error)
adrianlizarraga Jul 21, 2025
5bfbddb
Merge main and fix conflicts
adrianlizarraga Aug 28, 2025
69a4338
Update handler function signature to take in the ExternalDataInfo for…
adrianlizarraga Aug 28, 2025
90ade82
Add test that reuses external initializers from original model
adrianlizarraga Aug 29, 2025
c36afe5
Define new ExternalDataInfo constructor only for non-minimal builds
adrianlizarraga Aug 29, 2025
c07dc11
Merge branch 'main' into adrianl/compile-api-output-stream
adrianlizarraga Aug 29, 2025
4b83a2b
Fix unused variable warning (as error)
adrianlizarraga Aug 29, 2025
91acc8f
another unused variable
adrianlizarraga Aug 29, 2025
6e5629a
Merge branch 'main' into adrianl/compile-api-output-stream
adrianlizarraga Aug 29, 2025
9b092bf
clean up
adrianlizarraga Aug 29, 2025
049b9ad
Start adding csharp api funcs
adrianlizarraga Aug 29, 2025
8e00a06
Remove qnn_factory memleak fix (address in different PR)
adrianlizarraga Aug 29, 2025
11a6c74
Add ExternalInitializerInfo to C++ api
adrianlizarraga Aug 29, 2025
9ca882f
Add compile_to_stream py api
adrianlizarraga Aug 29, 2025
6d522d8
Python bindings and tests
adrianlizarraga Aug 30, 2025
af996bb
C# API for WriteBuffer delegate
adrianlizarraga Aug 31, 2025
9b27b31
c# api handle initializers
adrianlizarraga Aug 31, 2025
9607193
missing documentation in c#
adrianlizarraga Aug 31, 2025
e65710a
Add ExternalInitializerInfo C# class
adrianlizarraga Aug 31, 2025
c16b327
Full C# API for delegate that handles initializers
adrianlizarraga Sep 1, 2025
0b2f0e6
Update comment
adrianlizarraga Sep 2, 2025
83758d1
Merge branch 'main' into adrianl/compile-api-output-stream
adrianlizarraga Sep 2, 2025
c62ed23
Address review comments
adrianlizarraga Sep 2, 2025
a35e7b6
Address review comments
adrianlizarraga Sep 3, 2025
d906855
Remove unused variable
adrianlizarraga Sep 3, 2025
255c2df
Merge branch 'main' into adrianl/compile-api-output-stream
adrianlizarraga Sep 3, 2025
3db3117
Merge main conflicts
adrianlizarraga Sep 3, 2025
c7f98de
Merge main again
adrianlizarraga Sep 3, 2025
9031635
Address review comments for C#
adrianlizarraga Sep 3, 2025
abd0297
Rename functions in C and python
adrianlizarraga Sep 3, 2025
d5012fb
Merge branch 'main' into adrianl/compile-api-output-stream
adrianlizarraga Sep 3, 2025
0e0497a
Address comments
adrianlizarraga Sep 4, 2025
0a61f1f
Merge branch 'main' into adrianl/compile-api-output-stream
adrianlizarraga Sep 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add compile_to_stream py api
  • Loading branch information
adrianlizarraga committed Aug 29, 2025
commit 9ca882f25c5c02ac1de4d40d7624fe5401fb8e2f
10 changes: 9 additions & 1 deletion onnxruntime/python/onnxruntime_inference_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import typing
import warnings
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import Any

from onnxruntime.capi import _pybind_state as C
Expand Down Expand Up @@ -733,6 +733,14 @@ def compile_to_bytes(self) -> bytes:
"""
return self._model_compiler.compile_to_bytes()

def compile_to_stream(self, write_function: Callable[[bytes], None]) -> None:
    """
    Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write function.

    The write function may be called more than once; each call receives a portion of the compiled
    model's bytes. An exception raised by the write function aborts the compilation and is reported
    as a 'Fail' exception that contains the original error message.

    Raises an 'InvalidArgument' exception if the compilation options are invalid.
    Raises a TypeError if write_function is not callable.

    :param write_function: A callable that accepts a bytes buffer to write.
    """
    # Fail fast with a clear message instead of relying on the native layer to reject the argument.
    if not callable(write_function):
        raise TypeError(f"write_function must be callable, got '{type(write_function).__name__}'")

    self._model_compiler.compile_to_stream(write_function)


class IOBinding:
"""
Expand Down
38 changes: 38 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_model_compiler.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]>
// Licensed under the MIT License.
#if !defined(ORT_MINIMAL_BUILD)
#include "python/onnxruntime_pybind_model_compiler.h"

#include <algorithm>
Expand Down Expand Up @@ -77,9 +78,46 @@ onnxruntime::Status PyModelCompiler::CompileToBytes(std::string& output_buffer)
return Status::OK();
}

/**
 * Adapter between ORT's C-style output-stream write callback and the user's Python write function.
 *
 * Wraps the raw buffer in a pybind11::bytes object and forwards it to the PyOutStreamWriteFunc that
 * stream_state points to. Any std::exception thrown while doing so is converted into a FAIL status
 * carrying the exception's message.
 *
 * @param stream_state Opaque state that holds a pointer to the user's Python function.
 * @param buffer The buffer to write out. Contains a portion of the compiled ONNX model's bytes.
 * @param buffer_num_bytes The number of bytes to write out.
 *
 * @return nullptr on success; otherwise an OrtStatus* describing the exception that was caught.
 */
static OrtStatus* ORT_API_CALL PyOutStreamWriteFuncWrapper(void* stream_state, const void* buffer,
                                                           size_t buffer_num_bytes) {
  // stream_state is the PyOutStreamWriteFunc* that CompileToOutStream registered with the options.
  auto* write_func = static_cast<PyOutStreamWriteFunc*>(stream_state);
  const char* data = static_cast<const char*>(buffer);
  OrtStatus* result = nullptr;

  // Invoke the Python callable; exceptions are turned into a status instead of propagating.
  ORT_TRY {
    pybind11::bytes chunk(data, buffer_num_bytes);
    (*write_func)(chunk);
  }
  ORT_CATCH(const std::exception& ex) {
    ORT_HANDLE_EXCEPTION([&]() {
      result = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, ex.what()));
    });
  }

  return result;
}

// Compiles the input model, streaming the serialized result through the caller's write functor.
// Returns the error from onnxruntime::CompileModel on failure, or Status::OK() on success.
onnxruntime::Status PyModelCompiler::CompileToOutStream(PyOutStreamWriteFunc& write_func) {
  // The functor travels through ORT as opaque state and is unpacked again by PyOutStreamWriteFuncWrapper.
  // NOTE(review): the options object is handed a raw pointer to the caller's functor; confirm the
  // options are not reused for another compile after write_func goes out of scope.
  void* stream_state = static_cast<void*>(&write_func);
  model_compile_options_.SetOutputModelWriteFunc(PyOutStreamWriteFuncWrapper, stream_state);

  ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(env_, model_compile_options_));
  return Status::OK();
}

// Stores a reference to the environment and builds the model compilation options from the
// environment and the given session options. The PrivateConstructorTag parameter is unnamed and unused.
PyModelCompiler::PyModelCompiler(onnxruntime::Environment& env, const PySessionOptions& sess_options,
PrivateConstructorTag)
: env_(env), model_compile_options_(env, sess_options) {
}
} // namespace python
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD)
18 changes: 16 additions & 2 deletions onnxruntime/python/onnxruntime_pybind_model_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// Licensed under the MIT License.
#pragma once

#if !defined(ORT_MINIMAL_BUILD)
#include <memory>
#include <string>
#include "core/common/status.h"
Expand All @@ -14,11 +13,19 @@ namespace onnxruntime {
class Environment;

namespace python {
// Type of the function provided by Python code that is called by ORT to write out the compiled model.
// It receives a bytes buffer holding a portion of the serialized model and returns nothing.
using PyOutStreamWriteFunc = std::function<void(const pybind11::bytes& buffer)>;

/// <summary>
/// Class exposed to Python that enables compiling ONNX models.
/// Internally wraps a onnxruntime::ModelCompilationOptions that stores and validates settings.
/// </summary>
class PyModelCompiler {
#if defined(ORT_MINIMAL_BUILD)
public:
bool not_defined_in_this_build{}; // Prevent empty class warning.
#else
private:
// private tag to pass to constructor to ensure that constructor cannot be directly called externally
struct PrivateConstructorTag {};
Expand Down Expand Up @@ -70,11 +77,18 @@ class PyModelCompiler {
/// <returns>A Status indicating error or success.</returns>
onnxruntime::Status CompileToBytes(std::string& output_buffer);

/// <summary>
/// Compiles the input model and writes the result into the provided output stream (write functor).
/// </summary>
/// <param name="write_func">Write functor that encapsulates the stream's state.</param>
/// <returns>A Status indicating error or success.</returns>
onnxruntime::Status CompileToOutStream(PyOutStreamWriteFunc& write_func);

private:
onnxruntime::Environment& env_;
onnxruntime::ModelCompilationOptions model_compile_options_;
std::string input_model_bytes_;
#endif // defined(ORT_MINIMAL_BUILD)
};
} // namespace python
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD)
19 changes: 14 additions & 5 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
#include <mutex>
#include "python/onnxruntime_pybind_exceptions.h"
#include "python/onnxruntime_pybind_mlvalue.h"
#include "python/onnxruntime_pybind_state_common.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "python/onnxruntime_pybind_model_compiler.h"
#endif // !defined(ORT_MINIMAL_BUILD)
#include "python/onnxruntime_pybind_state_common.h"

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API
Expand Down Expand Up @@ -2788,7 +2785,19 @@ including arg name, arg type (contains both type and shape).)pbdoc")
ORT_THROW("Compile API is not supported in this build.");
#endif
},
R"pbdoc(Compile an ONNX model into a buffer.)pbdoc");
R"pbdoc(Compile an ONNX model into a buffer.)pbdoc")
// Python binding for compile_to_stream: takes the user's write callable as a PyOutStreamWriteFunc.
.def(
"compile_to_stream",
[](PyModelCompiler* model_compiler, PyOutStreamWriteFunc& py_stream_write_func) {
#if !defined(ORT_MINIMAL_BUILD)
// The Status returned by CompileToOutStream is passed to OrtPybindThrowIfError.
OrtPybindThrowIfError(model_compiler->CompileToOutStream(py_stream_write_func));
#else
// The compile API is unavailable in minimal builds: silence unused-parameter warnings and throw.
ORT_UNUSED_PARAMETER(model_compiler);
ORT_UNUSED_PARAMETER(py_stream_write_func);
ORT_THROW("Compile API is not supported in this build.");
#endif
},
R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc");
}

bool InitArray() {
Expand Down
59 changes: 59 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python_compile_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,65 @@ def test_compile_from_buffer_to_buffer(self):
self.assertTrue(isinstance(output_model_bytes, bytes))
self.assertGreater(len(output_model_bytes), 0)

def test_compile_from_file_to_stream(self):
    """
    Tests compiling a model (from files) to an output stream using a custom write functor.

    The write functor appends every buffer it receives to a file. The test then checks that data
    was actually written and that the file size matches the number of bytes handed to the functor.
    """
    provider = None
    provider_options = {}
    if "QNNExecutionProvider" in available_providers:
        provider = "QNNExecutionProvider"
        provider_options["backend_type"] = "htp"

    input_model_path = get_name("nhwc_resize_scales_opset18.onnx")
    output_model_path = os.path.join(self._tmp_dir_path, "model.compiled.stream.onnx")

    session_options = onnxrt.SessionOptions()
    if provider:
        session_options.add_provider(provider, provider_options)

    model_compiler = onnxrt.ModelCompiler(
        session_options,
        input_model_path,
        embed_compiled_data_into_model=True,
        external_initializers_file_path=None,
    )

    num_bytes_written = 0

    with open(output_model_path, "wb") as output_fd:
        # User's custom write functor. Writes the model to a file.
        def my_write_func(buffer: bytes):
            nonlocal num_bytes_written
            self.assertGreater(len(buffer), 0)
            num_bytes_written += output_fd.write(buffer)

        model_compiler.compile_to_stream(my_write_func)

    # open(..., "wb") creates the file even if nothing is written, so checking that the path exists
    # proves nothing. Verify that the compiler really streamed data through the functor.
    self.assertGreater(num_bytes_written, 0)
    self.assertEqual(os.path.getsize(output_model_path), num_bytes_written)

def test_compile_to_stream_that_raises_exception(self):
    """
    Tests that an exception raised by the user's write functor is reported by ORT as a Fail
    exception whose message contains the original Python error message.
    """
    test_py_error_message = "My Python Error"

    # Write functor that always fails after checking it was given a non-empty buffer.
    def raising_write_func(chunk: bytes):
        self.assertGreater(len(chunk), 0)
        raise ValueError(test_py_error_message)

    model_path = get_name("nhwc_resize_scales_opset18.onnx")
    compiler = onnxrt.ModelCompiler(
        onnxrt.SessionOptions(),
        model_path,
        embed_compiled_data_into_model=True,
        external_initializers_file_path=None,
    )

    # Compilation must fail, and the resulting Fail exception must carry our message.
    with self.assertRaises(Fail) as raised:
        compiler.compile_to_stream(raising_write_func)
    self.assertIn(test_py_error_message, str(raised.exception))

def test_fail_load_uncompiled_model_and_then_compile(self):
"""
Tests compiling scenario:
Expand Down
Loading