Skip to content
Merged
19 changes: 13 additions & 6 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5644,14 +5644,10 @@ struct OrtApi {
*/
ORT_API2_STATUS(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name);

/** \brief Get the filepath to the model from which an OrtGraph is constructed.
*
* \note The model's filepath is empty if the filepath is unknown, such as when the model is loaded from bytes
* via CreateSessionFromArray.
/** \brief Returns the ONNX IR version.
*
* \param[in] graph The OrtGraph instance.
* \param[out] model_path Output parameter set to the model's null-terminated filepath.
* Set to an empty path string if unknown.
* \param[out] onnx_ir_version Output parameter set to the ONNX IR version.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
Expand Down Expand Up @@ -6469,6 +6465,17 @@ struct OrtApi {
_In_reads_(num_tensors) OrtValue* const* dst_tensors,
_In_opt_ OrtSyncStream* stream,
_In_ size_t num_tensors);

/** \brief Get ::OrtModelMetadata from an ::OrtGraph
*
* \param[in] graph
* \param[out] out Newly created ::OrtModelMetadata. Must be freed using OrtApi::ReleaseModelMetadata
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out);
};

/*
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2834,6 +2834,7 @@ struct GraphImpl : Ort::detail::Base<T> {
void SetOutputs(std::vector<ValueInfo>& outputs);
void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value
void AddNode(Node& node); // Graph takes ownership of Node
ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::Graph_GetModelMetadata
#endif // !defined(ORT_MINIMAL_BUILD)
};
} // namespace detail
Expand All @@ -2848,6 +2849,7 @@ struct Graph : detail::GraphImpl<OrtGraph> {
Graph();
#endif
};
using ConstGraph = detail::GraphImpl<Ort::detail::Unowned<const OrtGraph>>;

namespace detail {
template <typename T>
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -2798,6 +2798,13 @@ inline void GraphImpl<OrtGraph>::AddNode(Node& node) {
ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release()));
}

template <typename T>
inline ModelMetadata GraphImpl<T>::GetModelMetadata() const {
OrtModelMetadata* out;
ThrowOnError(GetApi().Graph_GetModelMetadata(this->p_, &out));
return ModelMetadata{out};
}

template <>
inline void ModelImpl<OrtModel>::AddGraph(Graph& graph) {
// Model takes ownership of `graph`
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/graph/abi_graph_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/framework/tensor_external_data_info.h"
#include "core/framework/onnxruntime_typeinfo.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/inference_session.h"

#define DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(external_type, internal_type, internal_api) \
external_type* ToExternal() { return static_cast<external_type*>(this); } \
Expand Down Expand Up @@ -301,6 +302,11 @@ struct OrtGraph {
/// <returns>The graph's name.</returns>
virtual const std::string& GetName() const = 0;

/// <summary>
/// Returns the model's metadata.
/// </summary>
/// <returns>The model metadata.</returns>
virtual std::unique_ptr<onnxruntime::ModelMetadata> GetModelMetadata() const = 0;
/// <summary>
/// Returns the model's path, which is empty if unknown.
/// </summary>
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "core/framework/onnxruntime_typeinfo.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/graph.h"
#include "core/graph/model.h"

namespace onnxruntime {

Expand Down Expand Up @@ -769,6 +770,22 @@

const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); }

std::unique_ptr<ModelMetadata> EpGraph::GetModelMetadata() const {
const auto& model = graph_viewer_.GetGraph().GetModel();

Check failure on line 774 in onnxruntime/core/graph/ep_api_types.cc

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'model': references must be initialized

Check failure on line 774 in onnxruntime/core/graph/ep_api_types.cc

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'GetModel': is not a member of 'onnxruntime::Graph'
auto model_metadata = std::make_unique<ModelMetadata>();

model_metadata->producer_name = model.ProducerName();

Check failure on line 777 in onnxruntime/core/graph/ep_api_types.cc

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'model': cannot be used before it is initialized
model_metadata->producer_version = model.ProducerVersion();
model_metadata->description = model.DocString();
model_metadata->graph_description = model.GraphDocString();
model_metadata->domain = model.Domain();
model_metadata->version = model.ModelVersion();
model_metadata->custom_metadata_map = model.MetaData();
model_metadata->graph_name = model.MainGraph().Name();

return model_metadata;
}

const ORTCHAR_T* EpGraph::GetModelPath() const {
return graph_viewer_.ModelPath().c_str();
}
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ struct EpGraph : public OrtGraph {
// Returns the graph's name.
const std::string& GetName() const override;

// Returns the graph's metadata
std::unique_ptr<ModelMetadata> GetModelMetadata() const override;

// Returns the model path.
const ORTCHAR_T* GetModelPath() const override;

Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/graph/model_editor_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include "core/framework/ort_value.h"
#include "core/graph/abi_graph_types.h"
#include "core/graph/onnx_protobuf.h"

#include "core/session/inference_session.h"
namespace onnxruntime {

/// <summary>
Expand Down Expand Up @@ -184,6 +184,9 @@ struct ModelEditorGraph : public OrtGraph {

const std::string& GetName() const override { return name; }

std::unique_ptr<ModelMetadata> GetModelMetadata() const override {
return std::make_unique<ModelMetadata>(model_metadata);
}
const ORTCHAR_T* GetModelPath() const override { return model_path.c_str(); }

int64_t GetOnnxIRVersion() const override {
Expand Down Expand Up @@ -241,6 +244,7 @@ struct ModelEditorGraph : public OrtGraph {
std::vector<std::unique_ptr<onnxruntime::ModelEditorNode>> nodes;
std::string name = "ModelEditorGraph";
std::filesystem::path model_path;
ModelMetadata model_metadata;
};

} // namespace onnxruntime
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2626,6 +2626,16 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetName, _In_ const OrtGraph* graph, _Outptr_
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out) {
API_IMPL_BEGIN
if (out == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL");
}
*out = reinterpret_cast<OrtModelMetadata*>(graph->GetModelMetadata().release());
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path) {
API_IMPL_BEGIN
if (model_path == nullptr) {
Expand Down Expand Up @@ -4095,6 +4105,8 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::ReleaseSyncStream,

&OrtApis::CopyTensors,

&OrtApis::Graph_GetModelMetadata,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i

// OrtGraph
ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name);
ORT_API_STATUS_IMPL(Graph_GetModelMetadata, _In_ const OrtGraph* graph, _Outptr_ OrtModelMetadata** out);
ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path);
ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version);
ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets);
Expand Down
17 changes: 16 additions & 1 deletion onnxruntime/test/ep_graph/test_ep_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,22 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
const ORTCHAR_T* api_model_path = nullptr;
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetModelPath(&api_graph, &api_model_path));
ASSERT_EQ(PathString(api_model_path), PathString(model_path.c_str()));

// Check the model metadata
Ort::AllocatorWithDefaultOptions default_allocator;
auto ort_cxx_graph = Ort::ConstGraph(&api_graph);
auto ort_cxx_model_metadat = ort_cxx_graph.GetModelMetadata();
auto& model = graph_viewer.GetGraph().GetModel();
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetProducerNameAllocated(default_allocator).get(), model.ProducerName().c_str()), 0);
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetGraphNameAllocated(default_allocator).get(), model.MainGraph().Name().c_str()), 0);
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetDomainAllocated(default_allocator).get(), model.Domain().c_str()), 0);
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetDescriptionAllocated(default_allocator).get(), model.DocString().c_str()), 0);
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.GetGraphDescriptionAllocated(default_allocator).get(), model.GraphDocString().c_str()), 0);
ASSERT_EQ(ort_cxx_model_metadat.GetVersion(), model.ModelVersion());
auto model_meta_data = model.MetaData();
for (auto& [k, v] : model_meta_data) {
ASSERT_EQ(std::strcmp(ort_cxx_model_metadat.LookupCustomMetadataMapAllocated(k.c_str(), default_allocator).get(), v.c_str()), 0)
<< " key=" << k << "; value=" << v;
}
// Check graph inputs.
const auto& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers();

Expand Down
Loading