Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
new branch. add 2 streams first
  • Loading branch information
adrianlizarraga committed Jul 19, 2025
commit c3693deab5831f9dbc38dd4b1e94e92ec60b4b95
14 changes: 14 additions & 0 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
const std::filesystem::path& model_file_path,
const ModelSavingOptions& model_saving_options) const;

/// <summary>
/// Serialize the Graph to a onnx::GraphProto. Caller provides a function that determines where each initializer
/// is stored (i.e., either in an external file or within the model).
/// </summary>
/// <param name="handle_initializer_func">Function called for every initializer.</param>
/// <param name="state">Opaque user state passed to the handle_initializer_func.</param>
/// <param name="graph_proto">Output parameter set to the serialized onnx::GraphProto.</param>
/// <returns>A status indicating success or an error.</returns>
common::Status ToGraphProtoWithInitializerHandler(OrtHandleInitializerDataFunc handle_initializer_func,
void* state,
/*out*/ ONNX_NAMESPACE::GraphProto& graph_proto) const;

/** Gets the ISchemaRegistry instances being used with this Graph. */
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const;

Expand Down Expand Up @@ -1586,6 +1598,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
std::ostream& external_stream,
int64_t& external_offset) const;

Status ToGraphProtoWithInitializerHandlerImpl(OrtHandleInitializerDataFunc handle_initializer_func,
void* state, /*out*/ ONNX_NAMESPACE::GraphProto& output_graph_proto) const;
#endif

Version IrVersion() const noexcept {
Expand Down
5 changes: 4 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1159,8 +1159,11 @@ struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath
ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path,
size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile
ModelCompilationOptions& SetOutputModelHandleInitializerFunc(OrtHandleInitializerDataFunc handle_initializer_func,
void* state); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelHandleInitializerFunc
ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr,
size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
ModelCompilationOptions& SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelWriteFunc
ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory,
const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation
ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
Expand Down
15 changes: 15 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,15 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalI
return *this;
}

inline ModelCompilationOptions&
ModelCompilationOptions::SetOutputModelHandleInitializerFunc(OrtHandleInitializerDataFunc handle_initializer_func,
void* state) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelHandleInitializerFunc(this->p_,
handle_initializer_func,
state));
return *this;
}

inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer(
OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelBuffer(this->p_, allocator,
Expand All @@ -845,6 +854,12 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer(
return *this;
}

inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func,
void* state) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelWriteFunc(this->p_, write_func, state));
return *this;
}

inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
bool embed_ep_context_in_model) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode(
Expand Down
26 changes: 22 additions & 4 deletions onnxruntime/core/framework/ep_context_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@ ModelGenOptions::ModelGenOptions(const ConfigOptions& config_options) {
output_model_location = std::monostate{};
}

output_external_initializers_file_path = config_options.GetConfigOrDefault(
std::string external_initializers_file_path = config_options.GetConfigOrDefault(
kOrtSessionOptionsEpContextModelExternalInitializersFileName, "");
output_external_initializer_size_threshold = 0;
if (!external_initializers_file_path.empty()) {
ExternalInitializerFileInfo ext_info = {};
ext_info.file_path = external_initializers_file_path;
ext_info.size_threshold = 0;
initializers_location = std::move(ext_info);
}

embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1";
}

Expand All @@ -46,6 +52,18 @@ const OutStreamHolder* ModelGenOptions::TryGetOutputModelOutStream() const {
return std::get_if<OutStreamHolder>(&output_model_location);
}

bool ModelGenOptions::AreCpuInitializersEmbedded() const {
return std::holds_alternative<std::monostate>(initializers_location);
}

const ExternalInitializerFileInfo* ModelGenOptions::TryGetExternalInitializerFileInfo() const {
return std::get_if<ExternalInitializerFileInfo>(&initializers_location);
}

const InitializerHandler* ModelGenOptions::TryGetInitializerHandler() const {
return std::get_if<InitializerHandler>(&initializers_location);
}

// class OutStreamBuf

OutStreamBuf::OutStreamBuf(OutStreamHolder out_stream_holder) : out_stream_holder_(out_stream_holder) {
Expand Down Expand Up @@ -91,8 +109,8 @@ int OutStreamBuf::sync() {
Status status = Status::OK();

ORT_TRY {
status = ToStatus(out_stream_holder_.write_func(out_stream_holder_.stream_state,
ptr, num_bytes));
status = ToStatusAndRelease(out_stream_holder_.write_func(out_stream_holder_.stream_state,
ptr, num_bytes));
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
Expand Down
42 changes: 37 additions & 5 deletions onnxruntime/core/framework/ep_context_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,69 @@ struct BufferHolder {
/// Holds the opaque stream state and the write function that ORT calls to write out the output model.
/// </summary>
struct OutStreamHolder {
OrtOutStreamWriteFunc write_func = nullptr;
OrtWriteBufferFunc write_func = nullptr;
void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func.
};

struct ExternalInitializerFileInfo {
std::string file_path;
size_t size_threshold = 0;
};

struct InitializerHandler {
OrtHandleInitializerDataFunc handle_initializer_func = nullptr;
void* state = nullptr;
};

/// <summary>
/// Stores EPContext model generation options. Used in SessionOptions.
/// </summary>
struct ModelGenOptions {
// Action to take if the output model does not have compiled (EPContext) nodes.
enum class ActionIfNoCompiledNodes {
// Return OK() but don't generate an output model. Compiling via SessionOptions defaults to this behavior
// to maintain compatibility. The explicit compile API does *not* use this action.
kDontGenerateModel = 0,

// Generate an output model even if it doesn't have compiled nodes.
// The explicit Compile API defaults to this value.
kGenerateModel,

// Return an error if the model does not have compiled nodes.
// The explicit Compile API can be configured to this value.
kReturnError,
};

ModelGenOptions() = default;

// Initializes from string key/value pairs in session config options.
explicit ModelGenOptions(const ConfigOptions& config_options);

bool enable = false;
bool overwrite_existing_output_file = false;
bool error_if_output_file_exists = true;
bool error_if_no_compiled_nodes = false;
bool embed_ep_context_in_model = false;
ActionIfNoCompiledNodes action_if_no_compiled_nodes = ActionIfNoCompiledNodes::kDontGenerateModel;

std::variant<std::monostate, // Initial state (no output model location)
std::string, // output model path
BufferHolder, // buffer to save output model
OutStreamHolder> // Function to write the output model to a user's stream.
output_model_location{};

std::string output_external_initializers_file_path;
size_t output_external_initializer_size_threshold = 0;

bool HasOutputModelLocation() const;
const std::string* TryGetOutputModelPath() const;
const BufferHolder* TryGetOutputModelBuffer() const;
const OutStreamHolder* TryGetOutputModelOutStream() const;

std::variant<std::monostate, // Initial state (initializers embedded in ONNX model).
ExternalInitializerFileInfo, // Initializers saved in an external file
InitializerHandler> // Custom function called for every initializer to determine location.
initializers_location{};

bool AreCpuInitializersEmbedded() const;
const ExternalInitializerFileInfo* TryGetExternalInitializerFileInfo() const;
const InitializerHandler* TryGetInitializerHandler() const;
};

// Class that wraps the user's OrtOutStreamWriteFunc function to enable use with
Expand Down
Loading
Loading